Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Builds one wheel per (CUDA tag, CPU architecture) pair inside the matching
# nvidia/cuda devel container and, for `v*` tags, attaches all wheels to a
# draft GitHub release.
name: Build & Release Wheels

on:
  push:
    tags:
      - "v*"
  workflow_dispatch:

# A newer run for the same ref cancels the run that is still in progress.
concurrency:
  group: build-release-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-wheel:
    name: "wheel / ${{ matrix.cuda }} / cp312 / ${{ matrix.arch }}"
    runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    strategy:
      fail-fast: false
      matrix:
        cuda:
          - cu129
          - cu130
        arch:
          - x86_64
          - aarch64
    container:
      image: "nvidia/cuda:${{ matrix.cuda == 'cu129' && '12.9.0' || '13.0.0' }}-devel-ubuntu24.04"
    # `run` steps inside a job container default to `sh`, which does not
    # understand the `[[ ... ]]` test used by the "Compute version" step.
    # Selecting bash explicitly makes that test work; GitHub then invokes
    # bash with `-eo pipefail`.
    defaults:
      run:
        shell: bash

    steps:
      # NOTE(review): this step runs inside the CUDA container, so the
      # runner-image paths below are only removed if they exist in the
      # container filesystem - confirm whether host cleanup was intended.
      - name: Free disk space
        run: |
          rm -rf /opt/hostedtoolcache /usr/local/lib/android /usr/share/dotnet \
            /usr/local/share/boost /opt/ghc 2>/dev/null || true
          apt-get clean 2>/dev/null || true
          df -h / || true

      - name: Install git
        run: |
          apt-get update && apt-get install -y --no-install-recommends git \
            && rm -rf /var/lib/apt/lists/*

      - name: Checkout
        uses: actions/checkout@v5
        with:
          # Full history (and tags) so setuptools_scm can derive a version.
          fetch-depth: 0
          submodules: recursive

      - name: Configure git safe directory
        run: git config --global --add safe.directory "$GITHUB_WORKSPACE"

      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"

      - name: Install Python dependencies
        run: |
          python -m pip install --no-cache-dir --upgrade pip
          # The PyTorch wheel index is keyed by the same CUDA tag as the matrix.
          python -m pip install --no-cache-dir torch --index-url "https://download.pytorch.org/whl/${{ matrix.cuda }}"
          python -m pip install --no-cache-dir setuptools wheel "setuptools_scm>=6.0" build ninja

      - name: Compute version
        id: version
        run: |
          if [[ "$GITHUB_REF" == refs/tags/v* ]]; then
            # Tag builds: the version is the tag name without the leading "v".
            BASE="${GITHUB_REF#refs/tags/v}"
          else
            # Strip any local segment (+gXXX) so we get a clean base
            BASE=$(python -c "from setuptools_scm import get_version; print(get_version().split('+')[0])")
          fi
          # The CUDA tag becomes the local version segment, e.g. 0.2.0+cu129.
          echo "version=${BASE}+${{ matrix.cuda }}" >> "$GITHUB_OUTPUT"

      - name: Build fat-binary wheel
        env:
          CULA_BUILD_ALL_ARCHS: "1"
          SETUPTOOLS_SCM_PRETEND_VERSION: "${{ steps.version.outputs.version }}"
          NVCC_THREADS: "4"
          MAX_JOBS: "4"
        run: python -m build --wheel --no-isolation

      - name: Verify wheel
        run: |
          echo "Built wheel:"
          ls -lh dist/*.whl
          # Literal glob match on the "+<cuda tag>" suffix; avoids a pipeline
          # so the result does not depend on pipefail behaviour.
          shopt -s nullglob
          tagged=(dist/*"+${{ matrix.cuda }}"*.whl)
          if (( ${#tagged[@]} == 0 )); then
            echo "ERROR: wheel name missing +${{ matrix.cuda }} suffix"
            exit 1
          fi

      - name: Upload wheel artifact
        uses: actions/upload-artifact@v6
        with:
          name: wheel-${{ matrix.cuda }}-${{ matrix.arch }}
          path: dist/*.whl

  release:
    name: Create GitHub Release
    needs: [build-wheel]
    runs-on: ubuntu-latest
    # Manual (workflow_dispatch) runs on non-tag refs build wheels only.
    if: startsWith(github.ref, 'refs/tags/v')
    permissions:
      contents: write
    steps:
      # Without a name filter every artifact is downloaded into its own
      # sub-directory: artifacts/wheel-<cuda>-<arch>/.
      - name: Download all artifacts
        uses: actions/download-artifact@v6
        with:
          path: artifacts/

      - name: Create release
        uses: softprops/action-gh-release@v3
        with:
          files: |
            artifacts/wheel-*/*.whl
          generate_release_notes: true
          draft: true
          prerelease: ${{ contains(github.ref, 'rc') || contains(github.ref, 'beta') || contains(github.ref, 'alpha') }}
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ cuLA supports both **Hopper (SM90)** and **Blackwell (SM10X)** GPUs.

> **Note:** The PyTorch CUDA version must match your system CUDA Toolkit version. Check with `nvcc --version` and `python -c "import torch; print(torch.version.cuda)"`.

### Pre-built Wheels

Pre-built fat-binary wheels (SM90 + SM100 + SM103) are available on [GitHub Releases](https://github.com/inclusionAI/cuLA/releases):

```bash
pip install "cuda-linear-attention==<VERSION>+<CUDA_TAG>" -f https://github.com/inclusionAI/cuLA/releases/expanded_assets/<TAG>
```

Replace `<TAG>` with the release tag (e.g., `v0.2.0`), `<VERSION>` with the base version (e.g., `0.2.0`), and `<CUDA_TAG>` with your PyTorch CUDA build tag (e.g., `cu129` or `cu130`). Or download the `.whl` file directly from the [Releases page](https://github.com/inclusionAI/cuLA/releases) and install it with `pip install <filename>.whl`.

### Build from Source

**Clone cuLA & dependencies:**

```bash
Expand All @@ -47,6 +57,12 @@ pip install -e third_party/flash-linear-attention
pip install -e . --no-build-isolation
```

**Build fat wheel (SM90 + SM100 + SM103):**

```bash
CULA_BUILD_ALL_ARCHS=1 python -m build --wheel --no-isolation
```

## Quick Start

### KDA (Kimi Delta Attention) — Blackwell (SM10X)
Expand Down
6 changes: 6 additions & 0 deletions csrc/api/kda_sm100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,10 @@ ChunkKDAFwdRecompWU(
StaticPersistentTileScheduler::Params{tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr};

kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream());
}

// Python bindings for the SM100/SM103 kernels defined in this file.
// The module name comes from the TORCH_EXTENSION_NAME macro, which is not
// defined in this file (presumably supplied by the build - confirm).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "cuLA SM100/SM103 kernels";
// Each m.def maps a Python-visible function name to its C++ entry point.
m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra);
m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU);
}
5 changes: 5 additions & 0 deletions csrc/api/kda_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,8 @@ kda_fwd_prefill(

return {output, output_state};
}

// Python bindings for the SM90 kernels defined in this file.
// The module name comes from the TORCH_EXTENSION_NAME macro, which is not
// defined in this file (presumably supplied by the build - confirm).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "cuLA SM90 kernels";
// Exposes the C++ function kda_fwd_prefill to Python under the same name.
m.def("kda_fwd_prefill", &kda_fwd_prefill);
}
80 changes: 0 additions & 80 deletions csrc/api/pybind.cu

This file was deleted.

5 changes: 4 additions & 1 deletion cula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.0"
# Default version, used when no `cula._version` module can be imported.
__version__ = "0.1.0"
try:
    # Overrides the default with the value from `cula/_version.py` when that
    # module exists (presumably generated at build time - confirm).
    from cula._version import version as __version__
except ImportError:
    # Keep the default assigned above.
    pass

from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay

Expand Down
82 changes: 82 additions & 0 deletions cula/cudac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified interface to per-architecture CUDA extensions.

Downstream code can continue to use ``import cula.cudac as cula_cuda``
and call ``cula_cuda.kda_fwd_prefill(...)`` or
``cula_cuda.chunk_kda_fwd_intra_cuda(...)`` without knowing which
extension provides the function.
"""

import importlib
import sys
import threading
from types import ModuleType


class _CudacProxy(ModuleType):
    """Lazy module proxy that exposes functions from all built arch extensions.

    Nothing is imported when the proxy is created. The first lookup of a
    public attribute (or a ``dir()`` call) imports ``cula._cudac_sm100`` and
    ``cula._cudac_sm90`` and merges their public names into this module.
    Extensions are processed in that order, so if both export the same name
    the one from ``cula._cudac_sm90`` wins.
    """

    def __init__(self):
        super().__init__(__name__)
        # An (empty) __path__ is set so the module object looks like a package.
        self.__path__ = []
        self._modules_loaded = False
        self._funcs: dict[str, object] = {}
        self._lock = threading.Lock()

    def _load(self):
        """Import every available extension once and cache its public names.

        Raises:
            ImportError: if none of the extensions could be loaded. The
                message lists the error recorded for each extension.
        """
        # Fast path without taking the lock once loading has completed.
        if self._modules_loaded:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished loading.
            if self._modules_loaded:
                return
            loaded_any = False
            errors: dict[str, Exception] = {}
            for ext_name in ("cula._cudac_sm100", "cula._cudac_sm90"):
                try:
                    mod = importlib.import_module(ext_name)
                    for attr in dir(mod):
                        if not attr.startswith("_"):
                            self._funcs[attr] = getattr(mod, attr)
                    loaded_any = True
                except (ImportError, OSError, RuntimeError) as exc:
                    # Besides ImportError (extension not built), a native
                    # extension can fail during initialisation with OSError or
                    # RuntimeError. Record the failure and keep going so one
                    # broken extension does not hide a working one.
                    errors[ext_name] = exc
            if not loaded_any:
                details = "; ".join(f"{name}: {exc}" for name, exc in errors.items())
                raise ImportError(
                    "None of the cuLA CUDA extensions could be imported. "
                    f"Per-extension errors: [{details}]. "
                    "Please make sure cuLA is compiled correctly."
                )
            # Publish the names as real attributes so later lookups bypass
            # __getattr__ entirely.
            self.__dict__.update(self._funcs)
            self._modules_loaded = True

    def __getattr__(self, name: str):
        """Resolve a public attribute that was not found by normal lookup.

        Raises:
            AttributeError: for private/dunder names (never triggers loading)
                and for public names that no extension provides.
            ImportError: if no extension could be loaded.
        """
        if name.startswith("_"):
            raise AttributeError(name)
        self._load()
        try:
            return self._funcs[name]
        except KeyError:
            raise AttributeError(f"module 'cula.cudac' has no attribute '{name}'") from None

    def __dir__(self):
        """Return the public names collected from the loaded extensions."""
        self._load()
        return list(self._funcs.keys())


# Build the proxy, copy this module's import metadata onto it, then register
# it in sys.modules under this module's name so later imports get the proxy.
_proxy = _CudacProxy()
vars(_proxy).update(
    (key, globals().get(key))
    for key in ("__spec__", "__file__", "__package__", "__loader__")
)
sys.modules[__name__] = _proxy
Comment on lines +29 to +82

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using a custom ModuleType subclass and replacing sys.modules is a legacy pattern that can be replaced with standard PEP 562 module-level __getattr__ and __dir__ functions (supported in Python 3.7+). This is much cleaner, avoids subclassing overhead, and ensures standard module attributes are correctly preserved in __dir__.

Additionally, catching only ImportError during dynamic loading of C++ extensions is risky. Dynamic loading of PyTorch/CUDA extensions can frequently raise RuntimeError (e.g., due to CUDA driver/runtime mismatch) or OSError (e.g., due to missing shared libraries). Catching these exceptions ensures that if one extension fails to load, the other can still be loaded successfully.

# Module-level lazy-loading state.
# _modules_loaded: set to True once _load() has completed successfully.
# _funcs: public names collected from the imported extensions.
# _lock: serialises the loading work done inside _load().
_modules_loaded = False
_funcs = {}
_lock = threading.Lock()

def _load():
    """Import the arch extensions once and merge their public names.

    Tries ``cula._cudac_sm100`` and then ``cula._cudac_sm90``; every name
    that does not start with ``_`` is stored in ``_funcs`` and finally copied
    into this module's globals.

    Raises:
        ImportError: if neither extension could be imported. The message
            lists the error recorded for each extension.
    """
    global _modules_loaded
    # Fast path: skip the lock once loading has finished.
    if _modules_loaded:
        return
    with _lock:
        # Re-check under the lock in case another thread finished first.
        if _modules_loaded:
            return
        loaded_any = False
        errors = {}
        for ext_name in ("cula._cudac_sm100", "cula._cudac_sm90"):
            try:
                mod = importlib.import_module(ext_name)
                for attr in dir(mod):
                    if not attr.startswith("_"):
                        _funcs[attr] = getattr(mod, attr)
                loaded_any = True
            except (ImportError, RuntimeError, OSError) as exc:
                # Record the failure and continue with the next extension.
                errors[ext_name] = exc
        if not loaded_any:
            details = "; ".join(f"{name}: {exc}" for name, exc in errors.items())
            raise ImportError(
                "None of the cuLA CUDA extensions could be imported. "
                f"Per-extension errors: [{details}]. "
                "Please make sure cuLA is compiled correctly."
            )
        # Publish the names as real module globals; the flag is set last so
        # it only becomes True after _funcs is fully populated.
        globals().update(_funcs)
        _modules_loaded = True

def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) for names not found normally.

    Underscore-prefixed names are rejected without loading anything; any
    other name triggers ``_load()`` and is then looked up in ``_funcs``.
    """
    if not name.startswith("_"):
        _load()
        try:
            resolved = _funcs[name]
        except KeyError:
            raise AttributeError(f"module 'cula.cudac' has no attribute '{name}'") from None
        return resolved
    raise AttributeError(name)

def __dir__():
    """Module-level ``dir()`` hook (PEP 562).

    Loads the extensions, then returns the sorted module globals that are
    either public or dunder names (single-underscore names are hidden).
    """
    _load()
    visible = [key for key in globals().keys() if key.startswith("__") or not key.startswith("_")]
    visible.sort()
    return visible

5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ force-sort-within-sections = false
"cula/kda/blackwell_fused_fwd.py" = ["F821"]

[tool.setuptools_scm]
# write generated version into package for runtime access
write_to = "cula/_version.py"
# add a date-based local suffix when needed
local_scheme = "node-and-date"
# fallback for non-git sources
local_scheme = "no-local-version"
fallback_version = "0.1.0"
18 changes: 14 additions & 4 deletions scripts/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ cd "$REPO_ROOT"

# Parse args
ISOLATION_FLAG="--no-isolation"
if [[ "${1:-}" == "--isolated" ]]; then
ISOLATION_FLAG=""
echo "[build_wheel] Using isolated build environment"
else
for arg in "$@"; do
case "$arg" in
--isolated)
ISOLATION_FLAG=""
echo "[build_wheel] Using isolated build environment"
;;
--fat)
export CULA_BUILD_ALL_ARCHS=1
echo "[build_wheel] Fat binary: building for all SM architectures"
;;
esac
done
if [[ "$ISOLATION_FLAG" == "--no-isolation" ]]; then
echo "[build_wheel] Using current environment (--no-isolation)"
fi

Expand All @@ -33,6 +42,7 @@ rm -rf dist build *.egg-info
echo "[build_wheel] Python: $(python -V 2>&1)"
echo "[build_wheel] torch: $(python -c 'import torch; print(torch.__version__)' 2>/dev/null || echo 'not installed')"
echo "[build_wheel] CUDA: $(nvcc --version 2>/dev/null | grep 'release' | sed 's/.*release //' | sed 's/,.*//' || echo 'not found')"
echo "[build_wheel] Fat binary: ${CULA_BUILD_ALL_ARCHS:-0}"

# Build wheel
echo "[build_wheel] Building wheel..."
Expand Down
Loading