Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
ARG CUDA_VERSION=12.8.0
ARG CUDA_VERSION=13.0.0
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04

ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.16.0
ARG VLLM_VERSION=0.20.2
ARG NIXL_REF=v1.1.0
ARG FLASH_MLA_REF=47c35a7
ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f
ARG TARGETPLATFORM
ARG ENABLE_DEEPEP=1
ARG ENABLE_NIXL=1
ARG ENABLE_CACHE=1
ARG ENABLE_SM100=0

ENV PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda
Expand Down Expand Up @@ -44,13 +47,18 @@ WORKDIR /root

COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -U pip
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
RUN pip install --no-cache-dir \
--extra-index-url https://download.pytorch.org/whl/cu130 \
vllm==${VLLM_VERSION}
RUN pip install -r /lightllm/requirements.txt --no-cache-dir \
--extra-index-url https://download.pytorch.org/whl/cu130
RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \
git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
cd /root/FlashMLA && \
git checkout ${FLASH_MLA_REF} && \
git submodule update --init --recursive && \
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \
pip install --no-cache-dir .

RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*

Expand Down Expand Up @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \
RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \
set -e; \
ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \
NVSHMEM_VERSION=3.3.9; \
CUDA_ARCHS=90; \
wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \
&& cd nvshmem \
&& rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \
&& cmake --build build --target install -j64; \
DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \
cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \
python -m pip install --upgrade --no-deps \
"nvidia-nccl-cu13==2.30.4" \
"nvidia-nvshmem-cu13==3.6.5"; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \
pip install --no-build-isolation .; \
fi

RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \
cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \
git submodule update --init --recursive && \
pip install --no-build-isolation .

RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \
DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \
Expand Down Expand Up @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
apt-get update && apt-get install -y pkg-config tmux net-tools && \
cd /usr/local/src; \
pip install --upgrade meson pybind11 patchelf; \
git clone https://github.com/ai-dynamo/nixl.git -b main && \
git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \
cd nixl && \
rm -rf build && \
mkdir build && \
Expand Down
14 changes: 10 additions & 4 deletions docker/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ set -euo pipefail
# --no-nixl Disable NIXL (default: enabled)
# --no-cache Disable cache (default: enabled)
# --lite Disable DEEPEP, NIXL and cache in one shot
# --cuda-version <ver> CUDA version (default: 12.8.0)
# --cuda-version <ver> CUDA version (default: 13.0.0)
# --image-prefix <name> Image prefix (default: lightllm)
# --image-tag <tag> Image tag (default: generated from enabled features)
# --enable-sm100 Enable SM100 support (default: disabled)
# -h / --help Show help

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "${ROOT_DIR}"

IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}"
CUDA_VERSION="${CUDA_VERSION:-12.8.0}"
CUDA_VERSION="${CUDA_VERSION:-13.0.0}"
IMAGE_TAG="${IMAGE_TAG:-}"

ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}"
ENABLE_NIXL="${ENABLE_NIXL:-1}"
ENABLE_CACHE="${ENABLE_CACHE:-1}"
ENABLE_SM100="${ENABLE_SM100:-0}"

print_help() {
sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//'
Expand All @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do
--no-deepep) ENABLE_DEEPEP=0 ;;
--no-nixl) ENABLE_NIXL=0 ;;
--no-cache) ENABLE_CACHE=0 ;;
--enable-sm100) ENABLE_SM100=1 ;;
--lite)
ENABLE_DEEPEP=0
ENABLE_NIXL=0
Expand Down Expand Up @@ -78,13 +81,16 @@ done
# - Other combos: composed from enabled feature names
if [[ -z "${IMAGE_TAG}" ]]; then
tag_parts=()
if [[ "${ENABLE_SM100}" -eq 1 ]]; then
tag_parts+=("sm100")
fi
if [[ "${ENABLE_NIXL}" -eq 1 ]]; then
tag_parts+=("nixl")
fi
if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then
tag_parts+=("deepep")
fi
if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
IMAGE_TAG="cuda${CUDA_VERSION}"
else
prefix=""
Expand All @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \
--build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \
--build-arg ENABLE_NIXL="${ENABLE_NIXL}" \
--build-arg ENABLE_CACHE="${ENABLE_CACHE}" \
--build-arg ENABLE_SM100="${ENABLE_SM100}" \
--progress=plain \
-t "${IMAGE_PREFIX}:${IMAGE_TAG}" .

1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .triton.int8kv import Int8kvTritonAttBackend
from .triton.mla import MlaTritonAttBackend
from .fa3.fp import Fa3AttBackend
from .fa4.fp import Fa4AttBackend
from .fa3.fp8 import Fp8Fa3AttBackend
from .fa3.mla import MlaFa3AttBackend
from .flashinfer.fp8 import Fp8FlashInferAttBackend
Expand Down
20 changes: 14 additions & 6 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Attention backend selection utilities."""
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.backend_validator import validate
from typing import Dict
Expand All @@ -9,22 +9,30 @@
from .triton.int8kv import Int8kvTritonAttBackend
from .triton.mla import MlaTritonAttBackend
from .fa3.fp import Fa3AttBackend
from .fa4.fp import Fa4AttBackend
from .fa3.fp8 import Fp8Fa3AttBackend
from .fa3.mla import MlaFa3AttBackend
from .paged_fa3.fp import PagedFa3AttBackend
from .paged_fa3.mla import PagedMlaFa3AttBackend
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
from .paged_flashinfer.fp import PagedFlashInferAttBackend
from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend

logger = init_logger(__name__)

_PAGE_ENABLED = get_page_size() > 1

# Backend class mappings by data type
data_type_to_backend = {
"None": {
"triton": TritonAttBackend,
"fa3": Fa3AttBackend,
"flashinfer": FlashInferAttBackend,
"triton": TritonAttBackend, # triton backend supports arbitrary page size
"fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend,
"fa4": Fa4AttBackend,
"flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend,
},
"int4kv": {
"triton": Int4kvTritonAttBackend,
Expand All @@ -47,8 +55,8 @@
mla_data_type_to_backend = {
"None": {
"triton": MlaTritonAttBackend,
"fa3": MlaFa3AttBackend,
"flashinfer": MlaFlashInferAttBackend,
"fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend,
"flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend,
},
}

Expand Down
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ def prefill_att(
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.use_alibi is False
return self._nomarl_prefill_att(
return self._normal_prefill_att(
q=q,
k=k,
v=v,
att_control=att_control,
alloc_func=alloc_func,
)

def _nomarl_prefill_att(
def _normal_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
) -> torch.Tensor:
self.backend: Fa3AttBackend = self.backend # for typing
Expand Down
Empty file.
139 changes: 139 additions & 0 deletions lightllm/common/basemodel/attention/fa4/fp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import dataclasses
import torch

from ..base_att import AttControl
from ..paged_fa3.fp import PagedFa3AttBackend, PagedFa3PrefillAttState, PagedFa3DecodeAttState
from lightllm.utils.fa4_utils import (
ensure_fa4_available,
ensure_fa4_supported_gpu,
flash_attn_varlen_func,
sm90_fa4_paged_kv_tile_n,
unwrap_fa4_output,
)


class Fa4AttBackend(PagedFa3AttBackend):
def __init__(self, model):
ensure_fa4_available()
ensure_fa4_supported_gpu()
super().__init__(model=model)

def create_att_prefill_state(self, infer_state) -> "Fa4PrefillAttState":
return Fa4PrefillAttState(backend=self, infer_state=infer_state)

def create_att_decode_state(self, infer_state) -> "Fa4DecodeAttState":
return Fa4DecodeAttState(backend=self, infer_state=infer_state)


def _sm90_fa4_paged_kv_tile_n(
head_dim: int,
head_dim_v: int,
window_size: tuple[int, int],
) -> int | None:
return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size)


def _ensure_fa4_paged_kv_supported(
head_dim: int,
head_dim_v: int,
window_size: tuple[int, int],
page_size: int,
) -> None:
tile_n = _sm90_fa4_paged_kv_tile_n(head_dim, head_dim_v, window_size)
if tile_n is None or page_size == tile_n or tile_n >= 128:
return

raise RuntimeError(
"FA4 SM90 paged KV requires page_size == tile_n for this shape; "
f"current page_size={page_size}, required_page_size={tile_n}, "
f"head_dim={head_dim}, head_dim_v={head_dim_v}, window_size={window_size}. "
"LightLLM's current FA4 wrapper uses token-granular KV pages, so this shape would need "
"the removed repack fallback to run. Please set the FA4 KV cache page size to "
f"{tile_n} tokens for this model/shape, or switch --llm_prefill_att_backend/"
"--llm_decode_att_backend to another backend."
)


@dataclasses.dataclass
class Fa4PrefillAttState(PagedFa3PrefillAttState):
def _normal_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
) -> torch.Tensor:
import triton

if att_control.use_sliding_window:
window_size = att_control.sliding_window
else:
window_size = (-1, -1)

if att_control.use_att_sink:
sink_weight = att_control.sink_weight
else:
sink_weight = None

head_dim = q.shape[-1]
head_dim_v = v.shape[-1]
softmax_scale = 1.0 / (head_dim ** 0.5)
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size)

out = flash_attn_varlen_func(
q=q,
k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
cu_seqlens_q=self.cu_seqlens_q,
seqused_k=self.infer_state.b_seq_len.int(),
max_seqlen_q=self.infer_state.max_q_seq_len,
max_seqlen_k=triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) * self.backend.page_size,
page_table=self.page_table,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
learnable_sink=sink_weight,
softcap=0.0,
return_lse=False,
)
return unwrap_fa4_output(out)


@dataclasses.dataclass
class Fa4DecodeAttState(PagedFa3DecodeAttState):
def _normal_decode_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl,
alloc_func=torch.empty,
):
if att_control.use_sliding_window:
window_size = att_control.sliding_window
else:
window_size = (-1, -1)

if att_control.use_att_sink:
sink_weight = att_control.sink_weight
else:
sink_weight = None

head_dim = q.shape[-1]
head_dim_v = v.shape[-1]
softmax_scale = 1.0 / (head_dim ** 0.5)
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size)

out = flash_attn_varlen_func(
q=q,
k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
cu_seqlens_q=self.cu_seqlens_q,
seqused_k=self.b_att_seq_len.int(),
max_seqlen_q=self.decode_max_q_seq_len,
max_seqlen_k=self.infer_state.max_kv_seq_len,
page_table=self.page_table,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
learnable_sink=sink_weight,
softcap=0.0,
return_lse=False,
)
return unwrap_fa4_output(out)
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def prefill_att(
and att_control.use_sliding_window is False
and att_control.use_att_sink is False
)
return self._nomarl_prefill_att(
return self._normal_prefill_att(
q=q,
k=k,
v=v,
alloc_func=alloc_func,
)

def _nomarl_prefill_att(
def _normal_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty
) -> torch.Tensor:
self.backend: FlashInferAttBackend = self.backend # for typing
Expand Down
Empty file.
Loading
Loading