diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml new file mode 100644 index 00000000..ae15a184 --- /dev/null +++ b/.github/workflows/build-release.yml @@ -0,0 +1,114 @@ +name: Build & Release Wheels + +on: + push: + tags: + - "v*" + workflow_dispatch: + +concurrency: + group: build-release-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-wheel: + name: "wheel / ${{ matrix.cuda }} / cp312 / ${{ matrix.arch }}" + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + strategy: + fail-fast: false + matrix: + cuda: + - cu129 + - cu130 + arch: + - x86_64 + - aarch64 + container: + image: "nvidia/cuda:${{ matrix.cuda == 'cu129' && '12.9.0' || '13.0.0' }}-devel-ubuntu24.04" + + steps: + - name: Free disk space + run: | + rm -rf /opt/hostedtoolcache /usr/local/lib/android /usr/share/dotnet \ + /usr/local/share/boost /opt/ghc 2>/dev/null || true + apt-get clean 2>/dev/null || true + df -h / || true + + - name: Install git + run: | + apt-get update && apt-get install -y --no-install-recommends git \ + && rm -rf /var/lib/apt/lists/* + + - name: Checkout + uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: recursive + + - name: Configure git safe directory + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install Python dependencies + run: | + python -m pip install --no-cache-dir --upgrade pip + python -m pip install --no-cache-dir torch --index-url ${{ matrix.cuda == 'cu129' && 'https://download.pytorch.org/whl/cu129' || 'https://download.pytorch.org/whl/cu130' }} + python -m pip install --no-cache-dir setuptools wheel "setuptools_scm>=6.0" build ninja + + - name: Compute version + id: version + run: | + if [[ "$GITHUB_REF" == refs/tags/v* ]]; then + BASE="${GITHUB_REF#refs/tags/v}" + else + # Strip any local segment (+gXXX) so we get a clean base + BASE=$(python -c "from setuptools_scm import get_version; print(get_version().split('+')[0])") + fi + echo "version=${BASE}+${{ matrix.cuda }}" >> "$GITHUB_OUTPUT" + + - name: Build fat-binary wheel + env: + CULA_BUILD_ALL_ARCHS: "1" + SETUPTOOLS_SCM_PRETEND_VERSION: "${{ steps.version.outputs.version }}" + NVCC_THREADS: "4" + MAX_JOBS: "4" + run: python -m build --wheel --no-isolation + + - name: Verify wheel + run: | + echo "Built wheel:" + ls -lh dist/*.whl + ls dist/*.whl | grep -q "+${{ matrix.cuda }}" \ + || { echo "ERROR: wheel name missing +${{ matrix.cuda }} suffix"; exit 1; } + + - name: Upload wheel artifact + uses: actions/upload-artifact@v6 + with: + name: wheel-${{ matrix.cuda }}-${{ matrix.arch }} + path: dist/*.whl + + release: + name: Create GitHub Release + needs: [build-wheel] + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/v') + permissions: + contents: write + steps: + - name: Download all artifacts + uses: actions/download-artifact@v6 + with: + path: artifacts/ + + - name: Create release + uses: softprops/action-gh-release@v3 + with: + files: | + artifacts/wheel-*/*.whl + generate_release_notes: true + draft: true + prerelease: ${{ contains(github.ref, 'rc') || contains(github.ref, 'beta') || contains(github.ref, 'alpha') }} diff --git a/README.md b/README.md index 7bed61e2..7502e302 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,16 @@ cuLA supports both **Hopper (SM90)** and **Blackwell (SM10X)** GPUs. > **Note:** The PyTorch CUDA version must match your system CUDA Toolkit version. Check with `nvcc --version` and `python -c "import torch; print(torch.version.cuda)"`. +### Pre-built Wheels + +Pre-built fat-binary wheels (SM90 + SM100 + SM103) are available on [GitHub Releases](https://github.com/inclusionAI/cuLA/releases): + + pip install "cuda-linear-attention==+" -f https://github.com/inclusionAI/cuLA/releases/expanded_assets/ + +Replace `` with the release tag (e.g., `v0.2.0`), `` with the base version (e.g., `0.2.0`), and `` with your PyTorch CUDA build tag (e.g., `cu129` or `cu130`). Or download the `.whl` file directly from the [Releases page](https://github.com/inclusionAI/cuLA/releases) and install it with `pip install .whl`. + +### Build from Source + **Clone cuLA & dependencies:** ```bash @@ -47,6 +57,12 @@ pip install -e third_party/flash-linear-attention pip install -e . --no-build-isolation ``` +**Build fat wheel (SM90 + SM100 + SM103):** + +```bash +CULA_BUILD_ALL_ARCHS=1 python -m build --wheel --no-isolation +``` + ## Quick Start ### KDA (Kimi Delta Attention) — Blackwell (SM10X) diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index 7edca370..020d90ca 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -188,4 +188,10 @@ ChunkKDAFwdRecompWU( StaticPersistentTileScheduler::Params{tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr}; kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "cuLA SM100/SM103 kernels"; + m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra); + m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU); } \ No newline at end of file diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 9e016eb1..d80df7cc 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -191,3 +191,8 @@ kda_fwd_prefill( return {output, output_state}; } + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "cuLA SM90 kernels"; + m.def("kda_fwd_prefill", &kda_fwd_prefill); +} diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu deleted file mode 100644 index d14a41c5..00000000 --- a/csrc/api/pybind.cu +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2025-2026 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) -void -ChunkKDAFwdIntra( - at::Tensor q, - at::Tensor k, - at::Tensor g, - at::Tensor beta, - at::Tensor cu_seqlens, - at::Tensor chunk_indices, - at::Tensor Aqk_out, - at::Tensor Akk_out, - at::Tensor tile_counter, - float scale, - int chunk_size, - bool use_tf32_inverse, - bool unified_gref); -void -ChunkKDAFwdRecompWU( - at::Tensor k, - at::Tensor v, - at::Tensor beta, - at::Tensor A, - at::Tensor g, - at::Tensor cu_seqlens, - at::Tensor chunk_indices, - at::Tensor w_out, - at::Tensor u_out, - at::Tensor kg_out, - int chunk_size, - std::optional q, - std::optional qg_out); -#endif - -#if defined(CULA_SM90A_ENABLED) -std::tuple> -kda_fwd_prefill( - std::optional output_, - std::optional output_state_, - torch::Tensor const& q, - torch::Tensor const& k, - torch::Tensor const& v, - std::optional input_state_, - std::optional alpha_, - std::optional beta_, - torch::Tensor const& cu_seqlens, - torch::Tensor workspace_buffer, - float scale, - bool output_final_state, - bool safe_gate); -#endif - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "cuLA"; -#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) - m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra); - m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU); -#endif -#if defined(CULA_SM90A_ENABLED) - m.def("kda_fwd_prefill", &kda_fwd_prefill); -#endif -} diff --git a/cula/__init__.py b/cula/__init__.py index 7272e289..6e13aa13 100644 --- a/cula/__init__.py +++ b/cula/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0" +try: + from cula._version import version as __version__ +except ImportError: + __version__ = "0.1.0" from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay diff --git a/cula/cudac.py b/cula/cudac.py new file mode 100644 index 00000000..1bbaf108 --- /dev/null +++ b/cula/cudac.py @@ -0,0 +1,82 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified interface to per-architecture CUDA extensions. + +Downstream code can continue to use ``import cula.cudac as cula_cuda`` +and call ``cula_cuda.kda_fwd_prefill(...)`` or +``cula_cuda.chunk_kda_fwd_intra_cuda(...)`` without knowing which +extension provides the function. +""" + +import importlib +import sys +import threading +from types import ModuleType + + +class _CudacProxy(ModuleType): + """Lazy proxy that exposes functions from all built arch extensions.""" + + def __init__(self): + super().__init__(__name__) + self.__path__ = [] + self._modules_loaded = False + self._funcs: dict[str, object] = {} + self._lock = threading.Lock() + + def _load(self): + if self._modules_loaded: + return + with self._lock: + if self._modules_loaded: + return + loaded_any = False + errors: dict[str, Exception] = {} + for ext_name in ("cula._cudac_sm100", "cula._cudac_sm90"): + try: + mod = importlib.import_module(ext_name) + for attr in dir(mod): + if not attr.startswith("_"): + self._funcs[attr] = getattr(mod, attr) + loaded_any = True + except ImportError as exc: + errors[ext_name] = exc + if not loaded_any: + details = "; ".join(f"{name}: {exc}" for name, exc in errors.items()) + raise ImportError( + "None of the cuLA CUDA extensions could be imported. " + f"Per-extension errors: [{details}]. " + "Please make sure cuLA is compiled correctly." + ) + self.__dict__.update(self._funcs) + self._modules_loaded = True + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + self._load() + try: + return self._funcs[name] + except KeyError: + raise AttributeError(f"module 'cula.cudac' has no attribute '{name}'") from None + + def __dir__(self): + self._load() + return list(self._funcs.keys()) + + +_proxy = _CudacProxy() +_proxy.__dict__.update({k: globals().get(k) for k in ("__spec__", "__file__", "__package__", "__loader__")}) +sys.modules[__name__] = _proxy diff --git a/pyproject.toml b/pyproject.toml index ef1a531b..fe93e562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,9 +84,6 @@ force-sort-within-sections = false "cula/kda/blackwell_fused_fwd.py" = ["F821"] [tool.setuptools_scm] -# write generated version into package for runtime access write_to = "cula/_version.py" -# add a date-based local suffix when needed -local_scheme = "node-and-date" -# fallback for non-git sources +local_scheme = "no-local-version" fallback_version = "0.1.0" diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh index 42b35665..79ac3305 100755 --- a/scripts/build_wheel.sh +++ b/scripts/build_wheel.sh @@ -18,10 +18,19 @@ cd "$REPO_ROOT" # Parse args ISOLATION_FLAG="--no-isolation" -if [[ "${1:-}" == "--isolated" ]]; then - ISOLATION_FLAG="" - echo "[build_wheel] Using isolated build environment" -else +for arg in "$@"; do + case "$arg" in + --isolated) + ISOLATION_FLAG="" + echo "[build_wheel] Using isolated build environment" + ;; + --fat) + export CULA_BUILD_ALL_ARCHS=1 + echo "[build_wheel] Fat binary: building for all SM architectures" + ;; + esac +done +if [[ "$ISOLATION_FLAG" == "--no-isolation" ]]; then echo "[build_wheel] Using current environment (--no-isolation)" fi @@ -33,6 +42,7 @@ rm -rf dist build *.egg-info echo "[build_wheel] Python: $(python -V 2>&1)" echo "[build_wheel] torch: $(python -c 'import torch; print(torch.__version__)' 2>/dev/null || echo 'not installed')" echo "[build_wheel] CUDA: $(nvcc --version 2>/dev/null | grep 'release' | sed 's/.*release //' | sed 's/,.*//' || echo 'not found')" +echo "[build_wheel] Fat binary: ${CULA_BUILD_ALL_ARCHS:-0}" # Build wheel echo "[build_wheel] Building wheel..." diff --git a/setup.py b/setup.py index f7b11b95..78c61e5c 100644 --- a/setup.py +++ b/setup.py @@ -46,13 +46,15 @@ def detect_gpu_archs() -> tuple[bool, bool, bool]: def resolve_disable_flag(env_name: str, detected: bool) -> bool: """ Resolve whether to disable a given SM target. + - If CULA_BUILD_ALL_ARCHS is set, all targets are enabled unconditionally. - If the environment variable is explicitly set, honour it. - Otherwise, disable the target when no matching GPU is detected. """ + if os.getenv("CULA_BUILD_ALL_ARCHS", "0") == "1": + return False env_val = os.getenv(env_name) if env_val is not None: return env_val.lower() in ["true", "1", "y", "yes"] - # Auto-detect: disable if no matching device found disable = not detected if disable: print(f" No matching GPU detected; auto-setting {env_name}=1 (disable). Set {env_name}=0 to override.") @@ -66,7 +68,11 @@ def get_features_args(): USE_FAST_MATH = os.getenv("CULA_USE_FAST_MATH", "1") == "1" -print("Detecting GPU architectures...") +if os.getenv("CULA_BUILD_ALL_ARCHS", "0") == "1": + print("CULA_BUILD_ALL_ARCHS=1: enabling all SM targets (sm90a, sm100a, sm103a)") +else: + print("Detecting GPU architectures...") + _has_sm100, _has_sm103, _has_sm90 = detect_gpu_archs() DISABLE_SM100 = resolve_disable_flag("CULA_DISABLE_SM100", _has_sm100) DISABLE_SM103 = resolve_disable_flag("CULA_DISABLE_SM103", _has_sm103) @@ -111,26 +117,6 @@ def assert_blackwell_build_env() -> None: ) -def get_arch_flags(): - major, minor = get_nvcc_version() - print(f"Compiling using NVCC {major}.{minor}") - - # Validate Blackwell build environment - assert_blackwell_build_env() - - arch_flags = [] - if not DISABLE_SM100: - arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) - arch_flags.extend(["-DCULA_SM100_ENABLED"]) - if not DISABLE_SM103: - arch_flags.extend(["-gencode", "arch=compute_103a,code=sm_103a"]) - arch_flags.extend(["-DCULA_SM103_ENABLED"]) - if not DISABLE_SM90: - arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) - arch_flags.extend(["-DCULA_SM90A_ENABLED"]) - return arch_flags - - def get_nvcc_thread_args(): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return ["--threads", nvcc_threads] @@ -145,61 +131,84 @@ def get_nvcc_thread_args(): else: cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] -cuda_sources = [ - "csrc/api/pybind.cu", +nvcc_common_args = [ + "-O3", + "-std=c++20", + "-DNDEBUG", + # "-D_USE_MATH_DEFINES", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-lineinfo", + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + "-diag-suppress=3189", ] + +include_dirs = [ + Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "kerutils" / "include", + Path(this_dir) / "csrc" / "cutlass" / "include", + Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", +] + +major, minor = get_nvcc_version() +print(f"Compiling using NVCC {major}.{minor}") +assert_blackwell_build_env() + +ext_modules = [] + if not DISABLE_SM100 or not DISABLE_SM103: - cuda_sources.extend( - [ - "csrc/api/kda_sm100.cu", - "csrc/kda/sm100/kda_fwd_sm100.cu", - ] - ) -if not DISABLE_SM90: - cuda_sources.extend( - [ - "csrc/api/kda_sm90.cu", - "csrc/kda/sm90/kda_fwd_sm90.cu", - "csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu", - ] + sm100_arch_flags = [] + if not DISABLE_SM100: + sm100_arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) + if not DISABLE_SM103: + sm100_arch_flags.extend(["-gencode", "arch=compute_103a,code=sm_103a"]) + + ext_modules.append( + CUDAExtension( + name="cula._cudac_sm100", + sources=[ + "csrc/api/kda_sm100.cu", + "csrc/kda/sm100/kda_fwd_sm100.cu", + ], + extra_compile_args={ + "cxx": cxx_args + get_features_args(), + "nvcc": nvcc_common_args + + get_features_args() + + sm100_arch_flags + + get_nvcc_thread_args() + + (["--use_fast_math"] if USE_FAST_MATH else []), + }, + include_dirs=include_dirs, + ) ) -ext_modules = [] -ext_modules.append( - CUDAExtension( - name="cula.cudac", - sources=cuda_sources, - extra_compile_args={ - "cxx": cxx_args + get_features_args(), - "nvcc": [ - "-O3", - "-std=c++20", - "-DNDEBUG", - # "-D_USE_MATH_DEFINES", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "-lineinfo", - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", - "-diag-suppress=3189", # suppress the warning of torch in C++ 20 - ] - + get_features_args() - + get_arch_flags() - + get_nvcc_thread_args() - + (["--use_fast_math"] if USE_FAST_MATH else []), - }, - include_dirs=[ - Path(this_dir) / "csrc", - Path(this_dir) / "csrc" / "kerutils" / "include", - Path(this_dir) / "csrc" / "cutlass" / "include", - Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", - ], +if not DISABLE_SM90: + sm90_arch_flags = ["-gencode", "arch=compute_90a,code=sm_90a", "-DCULA_SM90A_ENABLED"] + + ext_modules.append( + CUDAExtension( + name="cula._cudac_sm90", + sources=[ + "csrc/api/kda_sm90.cu", + "csrc/kda/sm90/kda_fwd_sm90.cu", + "csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu", + ], + extra_compile_args={ + "cxx": cxx_args + get_features_args(), + "nvcc": nvcc_common_args + + get_features_args() + + sm90_arch_flags + + get_nvcc_thread_args() + + (["--use_fast_math"] if USE_FAST_MATH else []), + }, + include_dirs=include_dirs, + ) ) -) setup( name="cuda-linear-attention", diff --git a/tests/conftest.py b/tests/conftest.py index f144c10b..a9338aca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import re + import pytest import torch @@ -56,9 +57,5 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) continue callspec = getattr(item, "callspec", None) - if ( - callspec is not None - and callspec.params.get("disable_recompute") - and "kda_fast_norecomp" not in item.keywords - ): + if callspec is not None and callspec.params.get("disable_recompute") and "kda_fast_norecomp" not in item.keywords: item.add_marker(skip_fast_norecomp)