diff --git a/docker/Dockerfile b/docker/Dockerfile index 439ecddb34..4640e19b8f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ -ARG CUDA_VERSION=12.8.0 +ARG CUDA_VERSION=13.0.0 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.16.0 +ARG VLLM_VERSION=0.20.2 +ARG NIXL_REF=v1.1.0 ARG FLASH_MLA_REF=47c35a7 +ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 ARG ENABLE_CACHE=1 +ARG ENABLE_SM100=0 ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda @@ -44,13 +47,18 @@ WORKDIR /root COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip -RUN pip install -r /lightllm/requirements.txt --no-cache-dir -RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 \ + vllm==${VLLM_VERSION} +RUN pip install -r /lightllm/requirements.txt --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 +RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \ + git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ cd /root/FlashMLA && \ git checkout ${FLASH_MLA_REF} && \ git submodule update --init --recursive && \ - FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . + FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \ + pip install --no-cache-dir . RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \ RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \ set -e; \ ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \ - NVSHMEM_VERSION=3.3.9; \ - CUDA_ARCHS=90; \ - wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \ - && cd nvshmem \ - && rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \ - && cmake --build build --target install -j64; \ - DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \ - cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \ - cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \ + python -m pip install --upgrade --no-deps \ + "nvidia-nccl-cu13==2.30.4" \ + "nvidia-nvshmem-cu13==3.6.5"; \ + cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \ + pip install --no-build-isolation .; \ fi +RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \ + git submodule update --init --recursive && \ + pip install --no-build-isolation . + RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y pkg-config tmux net-tools && \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ - git clone https://github.com/ai-dynamo/nixl.git -b main && \ + git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \ cd nixl && \ rm -rf build && \ mkdir build && \ diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 355d6c65b3..bc1fd73da3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -18,21 +18,23 @@ set -euo pipefail # --no-nixl Disable NIXL (default: enabled) # --no-cache Disable cache (default: enabled) # --lite Disable DEEPEP, NIXL and cache in one shot -# --cuda-version CUDA version (default: 12.8.0) +# --cuda-version CUDA version (default: 13.0.0) # --image-prefix Image prefix (default: lightllm) # --image-tag Image tag (default: generated from enabled features) +# --enable-sm100 Enable SM100 support (default: disabled) # -h / --help Show help ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" cd "${ROOT_DIR}" IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}" -CUDA_VERSION="${CUDA_VERSION:-12.8.0}" +CUDA_VERSION="${CUDA_VERSION:-13.0.0}" IMAGE_TAG="${IMAGE_TAG:-}" ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}" ENABLE_NIXL="${ENABLE_NIXL:-1}" ENABLE_CACHE="${ENABLE_CACHE:-1}" +ENABLE_SM100="${ENABLE_SM100:-0}" print_help() { sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//' @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do --no-deepep) ENABLE_DEEPEP=0 ;; --no-nixl) ENABLE_NIXL=0 ;; --no-cache) ENABLE_CACHE=0 ;; + --enable-sm100) ENABLE_SM100=1 ;; --lite) ENABLE_DEEPEP=0 ENABLE_NIXL=0 @@ -78,13 +81,16 @@ done # - Other combos: composed from enabled feature names if [[ -z "${IMAGE_TAG}" ]]; then tag_parts=() + if [[ "${ENABLE_SM100}" -eq 1 ]]; then + tag_parts+=("sm100") + fi if [[ "${ENABLE_NIXL}" -eq 1 ]]; then tag_parts+=("nixl") fi if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then tag_parts+=("deepep") fi - if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then + if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then IMAGE_TAG="cuda${CUDA_VERSION}" else prefix="" @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ + --build-arg ENABLE_SM100="${ENABLE_SM100}" \ --progress=plain \ -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . - diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 10cd3b0864..896bc63bdd 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -4,6 +4,7 @@ from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend from .fa3.fp import Fa3AttBackend +from .fa4.fp import Fa4AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..dda6b4cb94 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,5 +1,5 @@ """Attention backend selection utilities.""" -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate from typing import Dict @@ -9,22 +9,30 @@ from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend from .fa3.fp import Fa3AttBackend +from .fa4.fp import Fa4AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend +from .paged_fa3.fp import PagedFa3AttBackend +from .paged_fa3.mla import PagedMlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend +from .paged_flashinfer.fp import PagedFlashInferAttBackend +from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend logger = init_logger(__name__) +_PAGE_ENABLED = get_page_size() > 1 + # Backend class mappings by data type data_type_to_backend = { "None": { - "triton": TritonAttBackend, - "fa3": Fa3AttBackend, - "flashinfer": FlashInferAttBackend, + "triton": TritonAttBackend, # triton backend supports arbitrary page size + "fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend, + "fa4": Fa4AttBackend, + "flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend, }, "int4kv": { "triton": Int4kvTritonAttBackend, @@ -47,8 +55,8 @@ mla_data_type_to_backend = { "None": { "triton": MlaTritonAttBackend, - "fa3": MlaFa3AttBackend, - "flashinfer": MlaFlashInferAttBackend, + "fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend, + "flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend, }, } diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..9568e4a892 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -66,7 +66,7 @@ def prefill_att( alloc_func=torch.empty, ) -> torch.Tensor: assert att_control.use_alibi is False - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, @@ -74,7 +74,7 @@ def prefill_att( alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/fa4/__init__.py b/lightllm/common/basemodel/attention/fa4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/fa4/fp.py b/lightllm/common/basemodel/attention/fa4/fp.py new file mode 100644 index 0000000000..91f191dc66 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa4/fp.py @@ -0,0 +1,139 @@ +import dataclasses +import torch + +from ..base_att import AttControl +from ..paged_fa3.fp import PagedFa3AttBackend, PagedFa3PrefillAttState, PagedFa3DecodeAttState +from lightllm.utils.fa4_utils import ( + ensure_fa4_available, + ensure_fa4_supported_gpu, + flash_attn_varlen_func, + sm90_fa4_paged_kv_tile_n, + unwrap_fa4_output, +) + + +class Fa4AttBackend(PagedFa3AttBackend): + def __init__(self, model): + ensure_fa4_available() + ensure_fa4_supported_gpu() + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state) -> "Fa4PrefillAttState": + return Fa4PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa4DecodeAttState": + return Fa4DecodeAttState(backend=self, infer_state=infer_state) + + +def _sm90_fa4_paged_kv_tile_n( + head_dim: int, + head_dim_v: int, + window_size: tuple[int, int], +) -> int | None: + return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size) + + +def _ensure_fa4_paged_kv_supported( + head_dim: int, + head_dim_v: int, + window_size: tuple[int, int], + page_size: int, +) -> None: + tile_n = _sm90_fa4_paged_kv_tile_n(head_dim, head_dim_v, window_size) + if tile_n is None or page_size == tile_n or tile_n >= 128: + return + + raise RuntimeError( + "FA4 SM90 paged KV requires page_size == tile_n for this shape; " + f"current page_size={page_size}, required_page_size={tile_n}, " + f"head_dim={head_dim}, head_dim_v={head_dim_v}, window_size={window_size}. " + "LightLLM's current FA4 wrapper uses token-granular KV pages, so this shape would need " + "the removed repack fallback to run. Please set the FA4 KV cache page size to " + f"{tile_n} tokens for this model/shape, or switch --llm_prefill_att_backend/" + "--llm_decode_att_backend to another backend." + ) + + +@dataclasses.dataclass +class Fa4PrefillAttState(PagedFa3PrefillAttState): + def _normal_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + import triton + + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + head_dim = q.shape[-1] + head_dim_v = v.shape[-1] + softmax_scale = 1.0 / (head_dim ** 0.5) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size) + + out = flash_attn_varlen_func( + q=q, + k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + cu_seqlens_q=self.cu_seqlens_q, + seqused_k=self.infer_state.b_seq_len.int(), + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) * self.backend.page_size, + page_table=self.page_table, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + learnable_sink=sink_weight, + softcap=0.0, + return_lse=False, + ) + return unwrap_fa4_output(out) + + +@dataclasses.dataclass +class Fa4DecodeAttState(PagedFa3DecodeAttState): + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + head_dim = q.shape[-1] + head_dim_v = v.shape[-1] + softmax_scale = 1.0 / (head_dim ** 0.5) + _ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size) + + out = flash_attn_varlen_func( + q=q, + k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + cu_seqlens_q=self.cu_seqlens_q, + seqused_k=self.b_att_seq_len.int(), + max_seqlen_q=self.decode_max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + page_table=self.page_table, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + learnable_sink=sink_weight, + softcap=0.0, + return_lse=False, + ) + return unwrap_fa4_output(out) diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..37478be76f 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -99,14 +99,14 @@ def prefill_att( and att_control.use_sliding_window is False and att_control.use_att_sink is False ) - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/paged_fa3/__init__.py b/lightllm/common/basemodel/attention/paged_fa3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/paged_fa3/fp.py b/lightllm/common/basemodel/attention/paged_fa3/fp.py new file mode 100644 index 0000000000..5c01538c42 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/fp.py @@ -0,0 +1,188 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + self.page_table = torch.empty( + (self.infer_state.batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + page_table_copy( + page_table=self.page_table, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) + + +@dataclasses.dataclass +class PagedFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) diff --git a/lightllm/common/basemodel/attention/paged_fa3/mla.py b/lightllm/common/basemodel/attention/paged_fa3/mla.py new file mode 100644 index 0000000000..2e33c05409 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/mla.py @@ -0,0 +1,174 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_varlen_func +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedMlaFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedMlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl, alloc_func=torch.empty + ): + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + assert att_control.mla_prefill + return flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=att_control.mla_prefill_dict["softmax_scale"], + causal=True, + return_softmax_lse=False, + ) + + +@dataclasses.dataclass +class PagedMlaFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + return self._mla_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl, alloc_func=torch.empty + ): + q_nope, q_rope = q + qk_rope_head_dim = 64 + kv_lora_rank = k.shape[-1] - qk_rope_head_dim + return flash_attn_with_kvcache( + q=q_rope, + k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank), + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=att_control.mla_decode_dict["softmax_scale"], + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + ) diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py b/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/fp.py b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py new file mode 100644 index 0000000000..b1807ca30b --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py @@ -0,0 +1,193 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state): + return PagedFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedFlashInferAttBackend = self.backend + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + kv_indices = torch.empty( + batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size), + dtype=torch.int32, + device=device, + ) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class PagedFlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedFlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_last_page_len_buffer = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + buffer_len = self.infer_state.batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + new_state.backend.page_size, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/mla.py b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py new file mode 100644 index 0000000000..c9ea38052f --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py @@ -0,0 +1,184 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedMlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.v_head_dim = model.v_head_dim + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def create_att_prefill_state(self, infer_state): + return PagedMlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedMlaFlashInferAttBackend = self.backend + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.v_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k_nope, k_rope = k + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class PagedMlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedMlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + buffer_len = batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + self.backend.page_size, + False, + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + new_state.backend.page_size, + False, + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + q_nope, q_rope = q + qk_rope_head_dim = 64 + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), + k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3b..23bf245af5 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -30,7 +30,7 @@ def prefill_att( assert att_control.tp_alibi is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._normal_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,7 +59,7 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + def _normal_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd out = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 67f830ba0d..c4a56dd4db 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -4,6 +4,7 @@ from lightllm.utils.backend_validator import _validate from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend +from lightllm.common.basemodel.attention_vit.fa4.fp import Fa4VitAttBackend from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend @@ -15,6 +16,7 @@ "triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend, + "fa4": Fa4VitAttBackend, "xformers": XformersVitAttBackend, } diff --git a/lightllm/common/basemodel/attention_vit/fa4/__init__.py b/lightllm/common/basemodel/attention_vit/fa4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention_vit/fa4/fp.py b/lightllm/common/basemodel/attention_vit/fa4/fp.py new file mode 100644 index 0000000000..a685fc4972 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/fa4/fp.py @@ -0,0 +1,38 @@ +import torch + +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.utils.fa4_utils import ( + ensure_fa4_available, + ensure_fa4_supported_gpu, + _flash_attn_fwd, +) + + +class Fa4VitAttBackend(BaseVitAttBackend): + def __init__(self): + ensure_fa4_available() + ensure_fa4_supported_gpu() + + @staticmethod + def _vit_att_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + head_dim = q.shape[-1] + return _flash_attn_fwd( + out=o, + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=head_dim ** -0.5, + causal=False, + return_lse=False, + ) diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..8bcf99b992 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -33,6 +33,7 @@ class BufNode: inner_tensor: torch.Tensor shape_key: Tuple[int, torch.dtype] storage_weak_ptr: int + free_use_count_bias: int = 0 shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict) def __del__(self): @@ -99,7 +100,8 @@ def alloc_tensor( # 回收可能消亡的 tensor for ptr in self.changed_ptr: t_buf_node = self.ptr_to_bufnode[ptr] - if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor): + free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor) + if self.use_count(ptr) <= free_use_count: self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node) self.changed_ptr.clear() @@ -131,6 +133,7 @@ def alloc_tensor( self.ptr_to_bufnode[storage_weak_ptr] = buf_node if shape not in buf_node.shape_to_tensor: buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape) + buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor)) mark_tensor = buf_node.shape_to_tensor[shape] ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断 ans.storage_weak_ptr = buf_node.storage_weak_ptr diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..8efd9e90f0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args +from lightllm.utils.device_utils import is_sm100_gpu from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger @@ -48,6 +49,7 @@ def __init__( self.quant_method = quant_method assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method) self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts self._init_config(network_config) @@ -66,6 +68,28 @@ def __init__( self.lock = threading.Lock() self._create_weight() + def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod: + if not self.enable_ep_moe: + return quant_method + + target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" + if quant_method.method_name == "none": + from lightllm.common.quantization.registry import QUANTMETHODS + + logger.info( + f"enable_ep_moe requires DeepGEMM MoE expert weights; " + f"auto-upgrading fused_moe quantization from `none` to `{target_method}`." + ) + quant_method = QUANTMETHODS.get(target_method) + + if quant_method.method_name != target_method: + raise ValueError( + f"enable_ep_moe currently requires `{target_method}` for fused_moe on this GPU, " + f"but got `{quant_method.method_name}`." + ) + + return quant_method + def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 @@ -130,6 +154,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + is_cuda_graph: bool = False, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,8 +170,12 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + is_cuda_graph=is_cuda_graph, ) + def use_sm100_mega_moe(self) -> bool: + return bool(getattr(self.fuse_moe_impl, "_use_sm100_fp4_moe", lambda: False)()) + def low_latency_dispatch( self, hidden_states: torch.Tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bdd86eb51e..6c64cb388b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -4,11 +4,15 @@ from lightllm.distributed import dist_group_manager from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, + deepgemm_grouped_fp8_fp4_nt_contiguous, + deepgemm_grouped_fp8_nt_contiguous, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -17,9 +21,132 @@ from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair +from lightllm.utils.device_utils import is_sm100_gpu class FuseMoeDeepGEMM(FuseMoeTriton): + def _get_ep_num_sms(self) -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + def _use_sm100_fp4_moe(self) -> bool: + return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp8fp4-b32" + + def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack): + cache_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + if getattr(self, "_mega_moe_weight_cache_key", None) != cache_key: + import deep_gemm + + self._mega_moe_weight_cache = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + self._mega_moe_weight_cache_key = cache_key + return self._mega_moe_weight_cache + + def _get_mega_moe_stats(self, num_local_experts: int, device: torch.device): + stats = getattr(self, "_mega_moe_stats", None) + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + self._mega_moe_stats = stats + return stats + + def _mega_moe( + self, + hidden_states: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor: + import deep_gemm + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + l1_weights, l2_weights = self._get_mega_moe_weights(w13, w2) + cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=cumulative_stats, + ) + return output + + def _sm100_fp4_cuda_graph_moe( + self, + hidden_states: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor: + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_buffer", None) + if buffer is None: + raise RuntimeError("SM100 CUDA graph MoE fallback requires dist_group_manager.ep_buffer") + + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + alignment = getattr(dist_group_manager, "ep_expert_alignment", 128) + cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device) + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( + qinput_tensor, + topk_idx=topk_ids, + topk_weights=topk_weights, + cumulative_local_expert_recv_stats=cumulative_stats, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, + expert_alignment=alignment, + do_cpu_sync=False, + do_handle_copy=False, + do_expand=True, + use_tma_aligned_col_major_sf=True, + ) + gemm_out = self.prefilled_group_gemm( + handle.psum_num_recv_tokens_per_expert, + recv_x, + recv_topk_idx, + recv_topk_weights, + w13, + w2, + hidden_states.dtype, + ) + combined_x, _, _ = buffer.combine(gemm_out, handle=handle, topk_weights=None) + return combined_x + def _select_experts( self, input_tensor: torch.Tensor, @@ -68,11 +195,21 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + is_cuda_graph: bool = False, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale + if self._use_sm100_fp4_moe(): + # DeepGEMM's official Mega MoE example is an eager fused path. For + # decode CUDA graph, use the official ElasticBuffer + grouped GEMM + # baseline instead of capturing Mega MoE's NVLink barrier kernel. + if is_cuda_graph and not is_prefill: + return self._sm100_fp4_cuda_graph_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long)) + return self._mega_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long)) + use_fp8_w8a8 = self.quant_method.method_name != "none" + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer output = fused_experts_impl( hidden_states=input_tensor, w1=w13_weight, @@ -80,7 +217,7 @@ def _fused_experts( topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, + buffer=buffer, is_prefill=is_prefill, use_fp8_w8a8=use_fp8_w8a8, use_fp8_all2all=use_fp8_w8a8, @@ -116,13 +253,13 @@ def low_latency_dispatch( ) topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() use_fp8_w8a8 = self.quant_method.method_name != "none" - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch( + topk_idx=topk_idx, + x=hidden_states, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + num_experts=self.total_expert_num_contain_redundancy, use_fp8=use_fp8_w8a8, async_finish=False, return_recv_hook=True, @@ -154,6 +291,17 @@ def select_experts_and_quant_input( scoring_func=scoring_func, ) w13_weight, w13_scale = w13.weight, w13.weight_scale + if self._use_sm100_fp4_moe(): + from deep_gemm.utils import per_token_cast_to_fp8 + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + return topk_weights, topk_idx.to(torch.long), qinput_tensor + block_size_k = 0 if w13_weight.ndim == 3: block_size_k = w13_weight.shape[2] // w13_scale.shape[2] @@ -169,38 +317,49 @@ def dispatch( overlap_event: Optional[Any] = None, ): buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + if self._use_sm100_fp4_moe(): + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( + qinput_tensor, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, + expert_alignment=128, + num_sms=self._get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=False, + do_handle_copy=False, + do_expand=True, + use_tma_aligned_col_major_sf=True, + ) + + def hook(): + event.current_stream_wait() + + return recv_x, recv_topk_idx, recv_topk_weights, handle.psum_num_recv_tokens_per_expert, handle, hook + + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, + num_sms=self._get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=True, + do_handle_copy=False, ) def hook(): event.current_stream_wait() - return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, @@ -233,6 +392,40 @@ def prefilled_group_gemm( _, K = recv_x[0].shape _, N, _ = w13_weight.shape block_size = self.quant_method.block_size + if self._use_sm100_fp4_moe(): + n = recv_x[0].shape[0] + l1_y = torch.empty((n, N), device=device, dtype=hidden_dtype) + deepgemm_grouped_fp8_fp4_nt_contiguous( + recv_x, + (w13_weight, w13_scale), + l1_y, + num_recv_tokens_per_expert_list, + use_psum_layout=True, + ) + silu_out = torch.empty((n, N // 2), device=device, dtype=hidden_dtype) + silu_and_mul_fwd(l1_y.view(-1, N), silu_out) + if recv_topk_weights is not None: + recv_topk_weights = recv_topk_weights.reshape(-1)[:n] + silu_out.mul_(recv_topk_weights.view(-1, 1)) + + from deep_gemm.utils import per_token_cast_to_fp8 + + qsilu_out = per_token_cast_to_fp8( + silu_out, + use_ue8m0=True, + gran_k=block_size, + use_packed_ue8m0=True, + ) + l2_y = torch.empty((n, K), device=device, dtype=hidden_dtype) + deepgemm_grouped_fp8_fp4_nt_contiguous( + qsilu_out, + (w2_weight, w2_scale), + l2_y, + num_recv_tokens_per_expert_list, + use_psum_layout=True, + ) + return l2_y + # scatter all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] @@ -272,7 +465,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -286,7 +479,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous( + deepgemm_grouped_fp8_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices ) # gather and local reduce @@ -310,7 +503,7 @@ def low_latency_combine( topk_weights: torch.Tensor, handle: Any, ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True ) return combined_x, hook @@ -326,8 +519,9 @@ def combine( gemm_out_b, handle, topk_weights=None, - async_finish=True, + num_sms=self._get_ep_num_sms(), previous_event=overlap_event, + async_with_compute_stream=True, allocate_on_comm_stream=True, ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 6391a10800..c8bd8f806d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -29,6 +29,7 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + is_cuda_graph: bool = False, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..6fb39662c1 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -91,6 +91,7 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + is_cuda_graph: bool = False, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -125,6 +126,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + is_cuda_graph: bool = False, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -145,5 +147,6 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + is_cuda_graph=is_cuda_graph, ) return output diff --git a/lightllm/common/basemodel/triton_kernel/fa3_utils.py b/lightllm/common/basemodel/triton_kernel/fa3_utils.py index 0a524b63b6..f9d1c9e9c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fa3_utils.py +++ b/lightllm/common/basemodel/triton_kernel/fa3_utils.py @@ -1,5 +1,6 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -37,6 +38,13 @@ def page_table_copy( assert page_table.dim() == 2, "page_table should be 2D" assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D" + page_size = get_page_size() + if page_size > 1: + max_seq_len_k = page_table.shape[1] * page_size + sampled = req_to_token_indexs[b_req_idx, :max_seq_len_k:page_size] + page_table.copy_(sampled // page_size) + return + max_seq_len_k = page_table.shape[1] batch_size = page_table.size(0) BLOCK_SIZE = 128 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..77705b1755 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -1,10 +1,7 @@ """Fused MoE kernel.""" -import os import torch import triton -import triton.language as tl from typing import Any, Callable, Dict, Optional, Tuple -import torch.distributed as dist from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -15,9 +12,11 @@ tma_align_input_scale, ) from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.triton_utils.autotuner import Autotuner -import numpy as np logger = init_logger(__name__) @@ -66,14 +65,14 @@ def fused_experts_impl( topk_weights: torch.Tensor, # [M, topk] topk_idx: torch.Tensor, # [M, topk] num_experts: int, - buffer: "Buffer", + buffer: Any, is_prefill: bool, use_fp8_w8a8: bool = False, use_fp8_all2all: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - previous_event: Optional["EventOverlap"] = None, + previous_event: Optional[Any] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -99,39 +98,27 @@ def fused_experts_impl( combined_x = None if is_prefill: qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False - ) - + allocate_on_comm_stream = previous_event is not None # normal dispatch # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] # recv_topk_idx [recive_num_tokens, topk_num] # recv_topk_weights [recive_num_tokens, topk_num] # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( (qinput_tensor, input_scale), topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=False, - allocate_on_comm_stream=False, + num_experts=num_experts, + num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(), expert_alignment=128, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + do_cpu_sync=True, + do_handle_copy=False, ) # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + all_tokens = sum(handle.num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) if all_tokens > 0: @@ -149,7 +136,7 @@ def fused_experts_impl( output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -169,7 +156,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) input_tensor[1] = tma_align_input_scale(input_tensor[1]) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -183,7 +170,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) + deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) @@ -202,13 +189,12 @@ def fused_experts_impl( gather_out, handle, topk_weights=None, - async_finish=False, previous_event=previous_event, - allocate_on_comm_stream=False, + allocate_on_comm_stream=allocate_on_comm_stream, ) else: # low latency dispatch - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts) recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch( hidden_states, @@ -228,7 +214,7 @@ def fused_experts_impl( return combined_x -def _deepgemm_grouped_fp8_nt_contiguous( +def deepgemm_grouped_fp8_nt_contiguous( input_tuple: Tuple[torch.Tensor, torch.Tensor], w_tuple: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, @@ -255,3 +241,22 @@ def _deepgemm_grouped_fp8_nt_masked( if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"): return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version") + + +def deepgemm_grouped_fp8_fp4_nt_contiguous( + input_tuple: Tuple[torch.Tensor, torch.Tensor], + w_tuple: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + grouped_layout: torch.Tensor, + use_psum_layout: bool = False, +): + if HAS_DEEPGEMM and hasattr(deep_gemm, "m_grouped_fp8_fp4_gemm_nt_contiguous"): + return deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + input_tuple, + w_tuple, + out, + grouped_layout, + use_psum_layout=use_psum_layout, + recipe=(1, 1, 32), + ) + raise RuntimeError("deep_gemm does not provide grouped fp8-fp4 NT contiguous GEMM kernel") diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..d50a0a230b 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -2,6 +2,7 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -33,6 +34,40 @@ def _fwd_kernel_repack_kv_index( return +@triton.jit +def _fwd_kernel_repack_page_kv_index_from_tokens( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + token_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = (start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)) * page_size + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) * page_size + token_data = tl.load( + req_to_token_indexs + token_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + page_data = token_data // page_size + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, page_data, mask=offs_seq < block_end_loc) + return + + @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] @@ -58,6 +93,34 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv return +@torch.no_grad() +def paged_repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): + page_size = get_page_size() + assert page_size > 1 + batch_size = req_index.shape[0] + # flashinfer requires out_kv_index to be zeroed before use + out_kv_index.zero_() + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + kv_index.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): output[start : start + sl] = req_to_token_indexs[b][:sl] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index d49c8d7e73..276bbf54bc 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -1,6 +1,7 @@ import torch import os import torch.distributed as dist +import triton from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager from typing import List, Union, Any @@ -10,6 +11,7 @@ from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io from .operator import Deepseek2MemOperator +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -30,7 +32,9 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0454c86628..2eea0dbbac 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -11,7 +11,7 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads @@ -38,6 +38,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.dtype = dtype # profile the max total token num if the size is None self.profile_size(mem_fraction) + page_size = get_page_size() + if page_size > 1: + self.size = (self.size // page_size) * page_size self.allocator = KvCacheAllocator(self.size) @@ -87,7 +90,9 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 分配,内部实际也没有管理,这个token是预留来对一些特殊的运行模式,如多dp下,overlap microbatch # 等模式下 padding 一些请求,使推理过程可以正常运行采用的,其索引值为size,存储在HOLD_TOKEN_MEMINDEX # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, 2 * head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): """ diff --git a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py index 30dc4d937c..6f04fefd21 100644 --- a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py +++ b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py @@ -26,7 +26,7 @@ def __init__( dtype=self.linear_config.conv_state_dtype, shape=self.linear_config.get_conv_state_shape(), layer_num=self.linear_config.linear_layer_num, - device="cpu", + device="cuda", size_first=True, ) self.ssm_state_cache = LayerCache( @@ -34,7 +34,7 @@ def __init__( dtype=self.linear_config.ssm_state_dtype, shape=self.linear_config.get_ssm_state_shape(), layer_num=self.linear_config.linear_layer_num, - device="cpu", + device="cuda", size_first=True, ) self.clear_to_init_state() diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..680416640e 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,6 +126,91 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["deepgemm-fp8fp4-b32"], platform="cuda") +class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 32 + self.ue8m0_pack_factor = 4 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = None + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp8fp4-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from deep_gemm.utils import per_token_cast_to_fp4 + import deep_gemm + + weight = weight.cuda(output.weight.device) + if weight.dim() == 2: + n, k = weight.shape + packed_weight, weight_scale = per_token_cast_to_fp4(weight, use_ue8m0=True, gran_k=self.block_size) + weight_scale = deep_gemm.transform_sf_into_required_layout(weight_scale, n, k, (1, self.block_size), None) + else: + num_groups, n, k = weight.shape + packed_weight = torch.empty((num_groups, n, k // 2), device=weight.device, dtype=torch.int8) + weight_scale = torch.empty((num_groups, n, k // self.block_size), device=weight.device, dtype=torch.float32) + for i in range(num_groups): + packed_weight[i], weight_scale[i] = per_token_cast_to_fp4( + weight[i], use_ue8m0=True, gran_k=self.block_size + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + weight_scale, n, k, (1, self.block_size), num_groups + ) + output.weight.copy_(packed_weight) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("deepgemm-fp8fp4-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + import deep_gemm + + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension" + assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size" + assert ( + in_dim % (self.block_size * self.ue8m0_pack_factor) == 0 + ), "SM100 FP4 scale layout requires input dimension divisible by 128" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id) + raw_weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float32).cuda( + device_id + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + raw_weight_scale, + out_dim, + in_dim, + (1, self.block_size), + num_experts if num_experts > 1 else None, + ) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..e6172f0539 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,13 +1,13 @@ import torch import collections +import triton from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig - from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional, TYPE_CHECKING from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache @@ -78,13 +78,32 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana def alloc(self): return self.req_list.alloc() + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + def alloc_mem_indices( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + return self.mem_manager.alloc(need_size) + def free(self, free_req_indexes: List[int], free_token_index): for req_index in free_req_indexes: self.req_list.free(req_index) if self.req_list.is_all_free(): logger.debug(f"freed all request size {self.req_list.can_alloc_size}") - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) def free_req(self, free_req_index: int): self.req_list.free(free_req_index) @@ -93,13 +112,73 @@ def free_req(self, free_req_index: int): return def free_token(self, free_token_index): - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) return def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def _expand_to_page_mem_indices(self, free_token_index): + page_size = get_page_size() + if page_size > 1: + if isinstance(free_token_index, list): + free_token_index = torch.tensor(free_token_index, dtype=torch.int32) + base_indices = free_token_index[free_token_index % page_size == 0] + if len(base_indices) == 0: + return free_token_index + page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device) + return (base_indices[:, None] + page_offsets[None, :]).reshape(-1) + + return free_token_index + + def _expand_by_page_size(self, b_token_len, page_size): + b_page_len = triton.cdiv(b_token_len, page_size) + need_pages_num = int(b_page_len.sum().item()) + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, p_token_len + + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() + if b_ready_cache_len is not None: + b_ready_cache_len = b_ready_cache_len.cpu() + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] + + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = int(need_new_page_mask.sum().item()) + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size) + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] + mask = ~need_new_page_mask + if mask.any(): + token_idxs[mask] = b_last_mem_index[mask] + 1 + return token_idxs + + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = triton.cdiv(need_tokens_array, page_size) + need_new_pages = need_pages_array.sum() + else: + need_new_pages = ((b_seq_len - 1) % page_size == 0).sum() + return need_new_pages * page_size + class ReqSamplingParamsManager: """ diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f01f1c87f7..e9a53762e0 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,7 +27,8 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils import ( get_env_start_args, - get_deepep_num_max_dispatch_tokens_per_rank, + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, get_redundancy_expert_num, ) from lightllm.utils.dist_utils import ( @@ -36,7 +37,7 @@ create_new_group_for_current_dp, create_dp_special_inter_group, ) -from lightllm.utils.device_utils import get_device_sm_count +from lightllm.utils.device_utils import get_device_sm_count, is_sm100_gpu from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) @@ -106,6 +107,11 @@ def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, as class DistributeGroupManager: def __init__(self): self.groups = [] + self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None + self.ep_expert_alignment = 128 def __len__(self): return len(self.groups) @@ -127,52 +133,107 @@ def get_default_group(self) -> CustomProcessGroup: def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] - def new_deepep_group(self, n_routed_experts, hidden_size): + def new_deepep_group( + self, + n_routed_experts, + hidden_size, + num_experts_per_tok: int = 1, + moe_intermediate_size: Optional[int] = None, + ): enable_ep_moe = get_env_start_args().enable_ep_moe - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + prefill_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + decode_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() if not enable_ep_moe: self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" - self._set_num_sms_for_deep_gemm() global_world_size = get_global_world_size() deepep_group = dist.new_group(list(range(global_world_size))) - low_latency_mode, num_rdma_bytes = True, 0 - if low_latency_mode: - self.ll_num_tokens, self.ll_hidden = num_max_dispatch_tokens_per_rank, hidden_size - self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.ll_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + self.ll_num_tokens = prefill_num_max_dispatch_tokens_per_rank + self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank + self.ll_hidden = hidden_size + self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + if is_sm100_gpu(): + if moe_intermediate_size is None: + raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config") + + import deep_gemm + + self.ep_expert_alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(self.ep_expert_alignment) + # Mega MoE is the eager fast path, while ElasticBuffer provides the official + # CUDA-graph-compatible baseline for decode. + self.ep_buffer = deep_ep.ElasticBuffer( + deepep_group, + num_max_tokens_per_rank=self.ll_decode_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, ) - self.ep_buffer = deep_ep.Buffer( + self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + deepep_group, + self.ll_num_experts, + self.ll_num_tokens, + num_experts_per_tok, + self.ll_hidden, + moe_intermediate_size, + ) + self._set_num_sms_for_deep_gemm(0) + logger.info("SM100 detected: use Mega MoE for eager path and ElasticBuffer for CUDA graph decode.") + return + + self.ep_buffer = deep_ep.ElasticBuffer( + deepep_group, + num_max_tokens_per_rank=self.ll_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, + ) + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + ) + self.ep_low_latency_buffer = deep_ep.Buffer( deepep_group, int(1e9), num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=(self.ll_num_experts // global_world_size if low_latency_mode else 1), + low_latency_mode=True, + num_qps_per_rank=(self.ll_num_experts // global_world_size), ) + theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok) + self._set_num_sms_for_deep_gemm(theoretical_sms) - def _set_num_sms_for_deep_gemm(self): + def _set_num_sms_for_deep_gemm(self, deepep_sms: int): try: try: from deep_gemm.jit_kernels.utils import set_num_sms except: from deep_gemm import set_num_sms - deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) device_sms = get_device_sm_count() - deep_ep.Buffer.set_num_sms(deepep_sms) - set_num_sms(device_sms - deepep_sms) + deepep_sms = max(0, min(deepep_sms, max(device_sms - 2, 0))) + self.ep_num_sms = deepep_sms + if self.ep_low_latency_buffer is not None: + deep_ep.Buffer.set_num_sms(deepep_sms - deepep_sms % 2) + set_num_sms(max(device_sms - deepep_sms, 2)) except BaseException as e: logger.warning(f"set num sms for deep_gemm failed: {e}") def clear_deepep_buffer(self): """ - prefill 之后需要clean 一下,ep buffer 才能正常执行 decode。 + Prefill after using ElasticBuffer may leave the legacy low-latency buffer dirty for decode. """ - if hasattr(self, "ep_buffer") and self.ep_buffer is not None: - self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts) + if self.ep_low_latency_buffer is not None: + self.ep_low_latency_buffer.clean_low_latency_buffer( + self.ll_decode_num_tokens, self.ll_hidden, self.ll_num_experts + ) def all_reduce( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..75688b4bd1 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -258,6 +258,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + is_cuda_graph=infer_state.is_cuda_graph, ) if self.n_shared_experts is not None: @@ -295,7 +296,7 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -421,7 +422,7 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -447,9 +448,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -486,8 +487,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 shared expert if self.n_shared_experts is not None: @@ -518,7 +518,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -533,7 +533,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..ea6620b4e4 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -48,7 +48,12 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..1e31306aea 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -25,7 +25,12 @@ def _init_config(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _init_to_get_yarn_rotary(self): rope_scaling = self.config.get("rope_scaling") diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..a3dcffb86f 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -104,6 +104,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + is_cuda_graph=infer_state.is_cuda_graph, ) ep_output = ep_output.view(token_num, hidden_dim) @@ -133,7 +134,7 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -245,7 +246,7 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -270,9 +271,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -308,8 +309,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -332,7 +332,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -347,7 +347,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..0d4b45bfe6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -27,4 +27,9 @@ def _init_custom(self): super()._init_custom() # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..0b15abe466 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -156,6 +156,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + is_cuda_graph=infer_state.is_cuda_graph, ) ep_output = ep_output.view(token_num, hidden_dim) ep_output.add_(shared_expert_out) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4a8ee80a46..e3c51f3617 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -12,7 +12,6 @@ ) from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -56,12 +55,6 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index fe6236cbe0..dcc7ebf4b5 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -388,7 +388,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--llm_prefill_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "flashinfer"], + choices=["auto", "triton", "fa3", "fa4", "flashinfer"], default=["auto"], help="""prefill attention kernel used in llm. auto: automatically select best backend based on GPU and available packages @@ -398,7 +398,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--llm_decode_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "flashinfer"], + choices=["auto", "triton", "fa3", "fa4", "flashinfer"], default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages @@ -408,7 +408,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--vit_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "sdpa", "xformers"], + choices=["auto", "triton", "fa3", "fa4", "sdpa", "xformers"], default=["auto"], help="""vit attention kernel used in vlm. auto: automatically select best backend based on GPU and available packages diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index e1182f2f77..e1bc127fd7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -4,13 +4,14 @@ import uuid import subprocess import signal +import math from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name -from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive +from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive, get_page_size, set_page_size from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active @@ -24,10 +25,27 @@ auto_set_max_req_total_len, ) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +def _auto_set_fa4_page_size(args, requested_backends): + if "fa4" not in requested_backends or "PAGE_SIZE" in os.environ: + return + + from lightllm.utils.fa4_utils import infer_fa4_page_size + + page_size = infer_fa4_page_size(args.model_dir) + if is_sm100_gpu(): + page_size = 128 + elif page_size is None: + return + + set_page_size(page_size) + logger.info(f"auto set PAGE_SIZE={page_size} for FA4 backend") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -192,7 +210,7 @@ def normal_or_p_d_start(args): assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1" if args.enable_ep_moe: - allowed_ep_att_backends = {"auto", "fa3", "triton"} + allowed_ep_att_backends = {"auto", "fa3", "fa4", "triton"} for backend in args.llm_prefill_att_backend: assert backend in allowed_ep_att_backends, ( "When --enable_ep_moe is enabled, --llm_prefill_att_backend must be one of " @@ -204,14 +222,35 @@ def normal_or_p_d_start(args): f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported." ) + llm_requested_backends = list(args.llm_prefill_att_backend) + list(args.llm_decode_att_backend) + requested_backends = llm_requested_backends + list(args.vit_att_backend) + if "fa4" in requested_backends: + _auto_set_fa4_page_size(args, llm_requested_backends) + # mtp params check if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None assert args.mtp_step > 0 + assert get_page_size() == 1, "page_size > 1 is not supported with MTP, please set PAGE_SIZE=1" else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # page_size > 1 compatibility check + if get_page_size() > 1: + assert args.run_mode not in ( + "prefill", + "decode", + ), "page_size > 1 is not supported with RPyC PD split mode, please set PAGE_SIZE=1" + assert args.run_mode not in ( + "nixl_prefill", + "nixl_decode", + ), "page_size > 1 is not supported with NIXL PD split mode, please set PAGE_SIZE=1" + assert ( + not args.enable_dp_prefill_balance + ), "page_size > 1 is not supported with DP prefill balance, please set PAGE_SIZE=1" + assert not args.enable_cpu_cache, "page_size > 1 is not supported with CPU cache, please set PAGE_SIZE=1" + if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -288,7 +327,10 @@ def normal_or_p_d_start(args): # linear att cache 参数自动设置 if args.linear_att_cache_size is None: # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 - args.linear_att_cache_size = args.running_max_req_size * 2 + default_cache_size = args.running_max_req_size * 2 + dp_size_in_node = max(1, args.dp // args.nnodes) + per_dp_cache_size = max(1, math.ceil(args.running_max_req_size / dp_size_in_node) * 2) + args.linear_att_cache_size = min(default_cache_size, per_dp_cache_size) if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index fe9cb6161a..cecbdb9872 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -135,13 +135,13 @@ class StartArgs: vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "flashinfer"]} ) vit_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "fa4", "sdpa", "xformers"]} ) llm_kv_type: str = field( default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 0000000000..6d49fc083f --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,538 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +import collections +from typing import Tuple, Dict, Set, List, Optional, Union +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) + + def _compute_key(self, tokens: torch.Tensor): + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else page_tokens.cpu().numpy().tobytes() + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + split_parent_node.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + split_parent_node.children[self._compute_key(self.token_id_key[prefix_len:])] = self + split_parent_node.ref_counter = self.ref_counter + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = self.token_id_key[prefix_len:] + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + child_key = child._compute_key(child.token_id_key) + assert child_key not in self.children.keys() + self.children[child_key] = child + child.parent = self + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[child_node._compute_key(child_node.token_id_key)] + child_node.parent = None + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return len(self.children) == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + from lightllm.common.kv_cache_mem_manager import MemoryManager + + self.mem_manager: MemoryManager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _align_prefix_len(self, prefix_len: int) -> int: + if self.page_size <= 1: + return prefix_len + if prefix_len % self.page_size == 0: + return prefix_len + if self._page_size_is_power_of_2: + return prefix_len & ~self._page_size_mask + return (prefix_len // self.page_size) * self.page_size + + def _get_page_aligned_key(self, key, value=None, free_truncated=False): + aligned_len = len(key) + if aligned_len == 0: + return None, None + if self.page_size > 1 and aligned_len % self.page_size != 0: + aligned_len = self._align_prefix_len(aligned_len) + if free_truncated and aligned_len < len(key) and self.mem_manager is not None and value is not None: + truncated_value = value[aligned_len:] + if len(truncated_value) > 0: + base = truncated_value[0] - truncated_value[0] % self.page_size + full_page = torch.arange( + base, base + self.page_size, dtype=truncated_value.dtype, device=truncated_value.device + ) + self.mem_manager.free(full_page) + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + if value is None: + value = key + + assert len(key) == len(value) + key, value = self._get_page_aligned_key(key, value, free_truncated=True) + if key is None: + return 0, None + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key, value)) + + ans_prefix_len = 0 + ans_node = None + + while len(handle_stack) != 0: + node, key, value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) + if len(ans_tuple) == 4: + (_prefix_len, new_node, new_key, new_value) = ans_tuple + ans_prefix_len += _prefix_len + handle_stack.append((new_node, new_key, new_value)) + else: + _prefix_len, ans_node = ans_tuple + ans_prefix_len += _prefix_len + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + assert ans_node is not None + + return ans_prefix_len, ans_node + + def _insert_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, value: torch.Tensor + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + child_key = node._compute_key(key) + if child_key in node.children.keys(): + child: TreeNode = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + if prefix_len == len(key): + if prefix_len == len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, child + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if child.is_leaf(): + self.evict_tree_set.add(child) + + return prefix_len, split_parent_node + else: + assert False, "can not run to here" + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, new_node + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + if update_refs: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key)) + + ans_node = None + + while len(handle_stack) != 0: + node, key = handle_stack.popleft() + ans_tuple = self._match_prefix_helper_no_recursion( + node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs + ) + if isinstance(ans_tuple, tuple): + new_node, new_key = ans_tuple + handle_stack.append((new_node, new_key)) + else: + ans_node = ans_tuple + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + return ans_node + + def _match_prefix_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + if len(key) == 0: + return node + + child_key = node._compute_key(key) + if child_key not in node.children.keys(): + return node + else: + child = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + return node + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return (child, key[prefix_len:]) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: + parent_node = child_node.parent + if ( + parent_node is None + or parent_node == self.root_node + or parent_node.ref_counter != 0 + or len(parent_node.children) != 1 + or child_node.ref_counter != 0 + ): + return None + + if child_node.is_leaf(): + self.evict_tree_set.discard(child_node) + + child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key]) + child_node.token_mem_index_value = torch.cat( + [parent_node.token_mem_index_value, child_node.token_mem_index_value] + ) + child_node.node_value_len = len(child_node.token_mem_index_value) + child_node.time_id = max(parent_node.time_id, child_node.time_id) + + grandparent_node = parent_node.parent + key_in_grandparent = grandparent_node._compute_key(parent_node.token_id_key) + grandparent_node.children[key_in_grandparent] = child_node + child_node.parent = grandparent_node + + parent_node.parent = None + + if child_node.is_leaf(): + self.evict_tree_set.add(child_node) + + return child_node + + def merge_unreferenced_nodes(self): + worklist = collections.deque( + [ + node + for node in self.evict_tree_set + if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node + ] + ) + + while worklist: + node = worklist.popleft() + if node.parent is None: + continue + merged_node = self._try_merge(node) + if merged_node: + worklist.append(merged_node) + + def assert_leafs_is_right(self): + for node in self.evict_tree_set: + if node.is_leaf() and node.ref_counter == 0: + a = node.token_mem_index_value.cuda() + assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a) + + def clear_tree_nodes(self): + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def add_node_ref_counter(self, node: TreeNode): + if node is None: + return + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + node = node.parent + + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]: + if node is None: + return None + + ans_list = [] + while node is not None: + ans_list.append(node.token_mem_index_value) + node = node.parent + + ans_list.reverse() + return torch.concat(ans_list, dim=0) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ + node_value_len: {node.node_value_len}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.allocator.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + self.evict(need_evict_token_num, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c1..927fc3abfb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -513,6 +513,7 @@ def __init__( self.shm_index = shm_index self.multimodal_params = multimodal_params self.vocab_size = vocab_size + self.last_kv_mem_index = -1 # 请求需要被暂停 self.wait_pause = False @@ -626,6 +627,7 @@ def _match_radix_cache(self): # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 self.shm_req.shm_cur_kv_len = self.cur_kv_len diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f0..97b835bf73 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,6 +9,7 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -31,6 +32,7 @@ from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node from lightllm.utils.envs_utils import ( get_env_start_args, + get_page_size, enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) @@ -199,7 +201,8 @@ def init_model(self, kvargs): linear_att_small_page_buffers=self.linear_att_cache_manager, ) else: - self.radix_cache = RadixCache( + radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache + self.radix_cache = radix_cache_class( unique_name=get_unique_server_name(), total_token_num=self.model.mem_manager.size, rank_in_node=self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 03ac4cfb05..f4fa8ae54c 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -86,8 +86,18 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() if padded_req_num > 0: @@ -140,6 +150,7 @@ def padded_prepare_decode_inputs( b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] args_mtp_step = get_env_start_args().mtp_step batch_multimodal_params = [] for req in req_objs: @@ -152,6 +163,7 @@ def padded_prepare_decode_inputs( total_token_num += seq_len b_mtp_index.append(0) batch_multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -187,13 +199,23 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1) g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + b_seq_len.shape[0] - padded_mem_indexes_num, b_seq_len[: len(b_last_mem_index)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0] - padded_mem_indexes_num, + b_seq_len[: len(b_last_mem_index)], + b_last_mem_index=b_last_mem_index, + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() if padded_mem_indexes_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 4eb8c7e1e6..a915564d78 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -64,8 +64,16 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0], b_seq_len, b_ready_cache_len + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() model_input = ModelInput( @@ -97,6 +105,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] multimodal_params = [] for req in req_objs: run_reqs.append(req) @@ -108,6 +117,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In total_token_num += seq_len b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -125,6 +135,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") if enable_diverse_mode_gqa_decode_fast_kernel(): b_shared_seq_len, b_mark_shared_group = build_diverse_shared_group_infos(run_reqs=run_reqs) @@ -135,8 +146,13 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..6eec51367a 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,7 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.utils.envs_utils import get_page_size from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -38,9 +39,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = len(self.cache_len_list) * (page_size - 1) if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 6c5fe90309..3a8b766c75 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -59,6 +59,47 @@ def _validate_fa3(): return True, None +def _validate_fa4(): + """Validate FA4 with ground truth.""" + from lightllm.utils.fa4_utils import flash_attn_varlen_func, is_fa4_supported_gpu, unwrap_fa4_output + + if not is_fa4_supported_gpu(): + return False, "FA4 requires Hopper/Blackwell-class GPU" + if flash_attn_varlen_func is None: + return False, "flash_attn_varlen_func is None" + + batch, heads, seq, dim = 1, 4, 8, 64 + q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + + expected = _compute_ground_truth(q, k, v) + + q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim) + k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim) + v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim) + cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda") + + out = flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq, + max_seqlen_k=seq, + softmax_scale=1.0 / (dim ** 0.5), + causal=True, + return_lse=False, + ) + out = unwrap_fa4_output(out).reshape(batch, seq, heads, dim).transpose(1, 2) + torch.cuda.synchronize() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + return True, None + + def _validate_flashinfer(): """Validate FlashInfer with ground truth.""" capability = torch.cuda.get_device_capability() @@ -240,6 +281,8 @@ def _run_in_subprocess(backend_name, pipe): try: if backend_name == "fa3": success, err = _validate_fa3() + elif backend_name == "fa4": + success, err = _validate_fa4() elif backend_name == "xformers": success, err = _validate_xformers() elif backend_name == "sdpa": diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..58bff90560 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -40,6 +40,11 @@ def get_device_sm_count(): return properties["multiprocessor_count"] +@lru_cache(maxsize=None) +def is_sm100_gpu(): + return torch.cuda.get_device_capability()[0] == 10 + + @lru_cache(maxsize=None) def get_device_sm_regs_num(): import triton diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index e11da07c8c..0f2548be8d 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 60.0 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 # 给flashinfer jit编译预留足够时间 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: @@ -84,6 +84,8 @@ def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> N input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) else: input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not flashinfer_all_reduce.should_use(input_tensor): + raise RuntimeError("FlashInferAllReduce unsupported for probe tensor") output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) dist.barrier() expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..db8e8c19d0 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -69,9 +69,22 @@ def enable_env_vars(args): @lru_cache(maxsize=None) -def get_deepep_num_max_dispatch_tokens_per_rank(): +def get_deepep_num_max_dispatch_tokens_per_rank_prefill(): + # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大。 + # 如果未显式配置,则默认至少覆盖当前进程的 `batch_max_tokens`,避免 DeepEP V2 在 autotune + # warmup 或大 prefill batch 时因为 buffer 上界过小而报错。 + configured = os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL", None) + if configured is not None: + return int(configured) + + batch_max_tokens = get_env_start_args().batch_max_tokens or 256 + return ((int(batch_max_tokens) + 7) // 8) * 8 + + +@lru_cache(maxsize=None) +def get_deepep_num_max_dispatch_tokens_per_rank_decode(): # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大,如果出现显存不足,可以尝试调小该值 - return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256)) + return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE", 256)) def get_lightllm_gunicorn_keep_alive(): @@ -152,6 +165,16 @@ def get_triton_autotune_level(): return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0)) +@lru_cache(maxsize=None) +def get_page_size(): + return int(os.getenv("PAGE_SIZE", 1)) + + +def set_page_size(page_size: int): + os.environ["PAGE_SIZE"] = str(page_size) + get_page_size.cache_clear() + + g_model_init_done = False diff --git a/lightllm/utils/fa4_utils.py b/lightllm/utils/fa4_utils.py new file mode 100644 index 0000000000..6d771df835 --- /dev/null +++ b/lightllm/utils/fa4_utils.py @@ -0,0 +1,82 @@ +import torch + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + from flash_attn.cute import flash_attn_varlen_func + from flash_attn.cute.interface import _flash_attn_fwd + + HAS_FA4 = True +except Exception: + flash_attn_varlen_func = None + _flash_attn_fwd = None + HAS_FA4 = False + logger.warning("flash-attn-4 is not installed") + + +def is_fa4_supported_gpu() -> bool: + if not torch.cuda.is_available(): + return False + major, _minor = torch.cuda.get_device_capability() + return major in (9, 10, 11, 12) + + +def ensure_fa4_available() -> None: + if not HAS_FA4: + raise ImportError( + "flash-attn-4 is unavailable. Install it first, e.g. `pip install flash-attn-4`, " + "or install from the local flash-attention repo." + ) + + +def ensure_fa4_supported_gpu() -> None: + if not torch.cuda.is_available(): + raise RuntimeError("FA4 backend requires CUDA, but CUDA is not available.") + major, minor = torch.cuda.get_device_capability() + if major not in (9, 10, 11, 12): + raise RuntimeError( + f"FA4 backend requires Hopper/Blackwell-class GPUs (SM90/SM100/SM110/SM120). " + f"Current device capability is {major}.{minor}." + ) + + +def sm90_fa4_paged_kv_tile_n(head_dim: int, head_dim_v: int, window_size: tuple[int, int] = (-1, -1)) -> int | None: + major, _minor = torch.cuda.get_device_capability() + if major != 9: + return None + + is_local = window_size != (-1, -1) + if head_dim <= 64: + return 128 + if head_dim <= 96: + return 128 if is_local else 144 + if head_dim <= 128: + return 128 + if head_dim <= 192: + return 96 if is_local else (128 if head_dim_v <= 128 else 112) + return 64 if is_local else 80 + + +def infer_fa4_page_size(model_dir: str) -> int | None: + from transformers.configuration_utils import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(model_dir) + llm_config = model_cfg.get("text_config", model_cfg) + + head_dim = llm_config.get("head_dim") + if head_dim is None: + head_dim = llm_config["hidden_size"] // llm_config["num_attention_heads"] + head_dim_v = llm_config.get("v_head_dim", head_dim) + + window_size = (-1, -1) + sliding_window = llm_config.get("sliding_window", None) + if sliding_window is not None and not llm_config.get("full_attention_interval", None): + window_size = (sliding_window - 1, sliding_window - 1) + + return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size) + + +def unwrap_fa4_output(output): + return output[0] if isinstance(output, tuple) else output diff --git a/requirements.txt b/requirements.txt index d37ae05690..20107fcd37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ mpmath==1.3.0 multiprocessing-logging==0.3.4 networkx==3.1 ninja==1.11.1 -numpy==1.25.1 +numpy==2.1.3 packaging==24.2 pip==23.0.1 pluggy==1.2.0 @@ -59,7 +59,7 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.9.1 +torch==2.11.0 tqdm==4.65.0 transformers==4.57.1 tokenizers==0.22.1 @@ -71,7 +71,7 @@ zstandard==0.23.0 safetensors==0.4.5 Pillow==10.4.0 tiktoken==0.7.0 -matplotlib==3.8.2 +matplotlib==3.10.0 psutil==5.9.4 prometheus_client==0.20.0 cchardet==2.1.7 @@ -81,19 +81,21 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.6.8.post1 -sgl-kernel==0.3.21 +flashinfer-cubin==0.6.8.post1 +sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 -cuda_bindings==12.9.0 +cuda_bindings==13.2.0 orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 -torchvision==0.24.1 +torchvision==0.26.0 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 -cupy-cuda12x==13.6.0 -nixl==0.8.0 -xformers==0.0.33.post2 +cupy-cuda13x==14.0.1 +nixl==1.1.0 +xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 +flash-attn-4[13]==4.0.0b13 diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py index 09009fc9e1..3f55bcab1e 100644 --- a/test/benchmark/service/benchmark_client.py +++ b/test/benchmark/service/benchmark_client.py @@ -27,6 +27,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_output_length(input_num: int, output_len: int) -> List[int]: min_len, max_len = 2, output_len * 2 mean = (min_len + max_len) * 0.5 @@ -162,7 +169,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index c8fd9d4de7..e47c38c4cf 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -39,12 +39,24 @@ import os import random import time +import urllib.parse +import urllib.request from typing import Dict, List, Optional, Tuple, Union import aiohttp import numpy as np from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +_STREAM_READ_BUFSIZE = 1 << 20 +_STREAM_MAX_LINE_SIZE = 1 << 20 +_DEFAULT_TRANSIENT_RETRIES = 2 +_TRANSIENT_STREAM_ERRORS = ( + aiohttp.ServerDisconnectedError, + aiohttp.ClientPayloadError, + aiohttp.ClientOSError, + asyncio.TimeoutError, +) + def seed_all(seed: int) -> None: if not seed: @@ -58,6 +70,85 @@ def get_tokenizer(tokenizer_name: str) -> Union[PreTrainedTokenizer, PreTrainedT return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + +def get_models_url(completions_url: str) -> str: + parsed = urllib.parse.urlsplit(completions_url) + path = parsed.path.rstrip("/") + for suffix in ("/chat/completions", "/completions"): + if path.endswith(suffix): + path = path[: -len(suffix)] + "/models" + return urllib.parse.urlunsplit(parsed._replace(path=path, query="", fragment="")) + return urllib.parse.urlunsplit(parsed._replace(path="/v1/models", query="", fragment="")) + + +def fetch_served_model_names(completions_url: str, timeout_s: int = 10) -> List[str]: + models_url = get_models_url(completions_url) + request = urllib.request.Request(models_url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(request, timeout=timeout_s) as response: + payload = json.loads(response.read().decode("utf-8")) + return [item["id"] for item in payload.get("data", []) if item.get("id")] + + +def resolve_model_name( + completions_url: str, + requested_model_name: str, + explicit_model_name: bool, +) -> Tuple[str, Optional[str]]: + normalized_name = normalize_model_name(requested_model_name) + if normalized_name != requested_model_name: + note = f"Normalized model name from `{requested_model_name}` to `{normalized_name}`." + else: + note = None + + try: + served_model_names = fetch_served_model_names(completions_url) + except Exception as exc: + if note is not None: + note = f"{note} Failed to query served models: {exc}." + return normalized_name, note + + if requested_model_name in served_model_names: + return requested_model_name, note + if normalized_name in served_model_names: + if normalized_name != requested_model_name: + return normalized_name, ( + f"Normalized model name from `{requested_model_name}` to `{normalized_name}` " "to match `/v1/models`." + ) + return normalized_name, note + + requested_basename = os.path.basename(normalized_name) + basename_matches = [ + served_name + for served_name in served_model_names + if os.path.basename(normalize_model_name(served_name)) == requested_basename + ] + if len(basename_matches) == 1: + matched_name = basename_matches[0] + return matched_name, ( + f"Resolved model name `{requested_model_name}` to served model `{matched_name}` " "via `/v1/models`." + ) + + if not explicit_model_name and len(served_model_names) == 1: + matched_name = served_model_names[0] + return matched_name, ( + f"Using the only served model `{matched_name}` returned by `/v1/models` " + f"instead of `{requested_model_name}`." + ) + + if note is not None: + note = ( + f"{note} Available served models: {', '.join(served_model_names) or '(none)'}. " + f"Using `{normalized_name}`." + ) + return normalized_name, note + + def gen_random_token_ids(tokenizer, n: int, rng: random.Random) -> List[int]: vocab = tokenizer.vocab_size return [rng.randint(0, vocab - 1) for _ in range(n)] @@ -104,10 +195,13 @@ def append_turn_input( async def stream_one_turn( session: aiohttp.ClientSession, + tokenizer, url: str, model_name: str, prompt: str, + prompt_token_len: int, max_new_tokens: int, + max_retries: int = _DEFAULT_TRANSIENT_RETRIES, ) -> Optional[Dict]: """Send one streaming completion request, return per-turn stats: { @@ -116,6 +210,8 @@ async def stream_one_turn( "prompt_tokens": int, "completion_tokens": int, "cached_tokens": int, + "cached_tokens_reported": bool, + "usage_estimated": bool, "generated_text": str, } Returns None on failure.""" @@ -130,74 +226,111 @@ async def stream_one_turn( } headers = {"Content-Type": "application/json"} - start_time = time.time() - first_token_time: Optional[float] = None - last_token_time: Optional[float] = None - decode_times: List[float] = [] - generated_text_parts: List[str] = [] - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - - try: - async with session.post(url, headers=headers, json=payload) as response: - if response.status != 200: - err = await response.text() - print(f"\n[turn failed] status={response.status} body={err[:200]}") - return None - - async for raw in response.content: - line = raw.strip() - if not line or not line.startswith(b"data:"): - continue - data_str = line[len(b"data:"):].strip() - if data_str == b"[DONE]": - break - try: - chunk = json.loads(data_str) - except Exception: - continue - - # Final usage-only chunk: choices == [] and usage present - usage = chunk.get("usage") - choices = chunk.get("choices") or [] - if usage is not None and not choices: - prompt_tokens = usage.get("prompt_tokens", prompt_tokens) - completion_tokens = usage.get("completion_tokens", completion_tokens) - details = usage.get("prompt_tokens_details") or {} - cached_tokens = details.get("cached_tokens", cached_tokens) - continue - - # Token-bearing chunk - if not choices: - continue - text_piece = choices[0].get("text", "") - if text_piece == "" and choices[0].get("finish_reason") is None: - continue - - now = time.time() - if first_token_time is None: - first_token_time = now - else: - decode_times.append(now - last_token_time) - last_token_time = now - if text_piece: - generated_text_parts.append(text_piece) - except Exception as e: - print(f"\n[turn exception] {e}") - return None - - if first_token_time is None: - return None - - return { - "ttft": first_token_time - start_time, - "decode_times": decode_times, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "cached_tokens": cached_tokens, - "generated_text": "".join(generated_text_parts), - } + for attempt in range(max_retries + 1): + start_time = time.time() + first_token_time: Optional[float] = None + last_token_time: Optional[float] = None + decode_times: List[float] = [] + generated_text_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cached_tokens_reported = False + + try: + async with session.post(url, headers=headers, json=payload) as response: + if response.status != 200: + err = await response.text() + if response.status >= 500 and attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + print(f"\n[turn failed] status={response.status} body={err[:200]}") + return None + + async for raw in response.content: + line = raw.strip() + if not line or not line.startswith(b"data:"): + continue + data_str = line[len(b"data:") :].strip() + if data_str == b"[DONE]": + break + try: + chunk = json.loads(data_str) + except Exception: + continue + + # Final usage-only chunk: choices == [] and usage present + usage = chunk.get("usage") + choices = chunk.get("choices") or [] + if usage is not None and not choices: + prompt_tokens = usage.get("prompt_tokens", prompt_tokens) + completion_tokens = usage.get("completion_tokens", completion_tokens) + details = usage.get("prompt_tokens_details") + if isinstance(details, dict) and details.get("cached_tokens") is not None: + cached_tokens = details["cached_tokens"] + cached_tokens_reported = True + continue + + # Token-bearing chunk + if not choices: + continue + text_piece = choices[0].get("text", "") + if text_piece == "" and choices[0].get("finish_reason") is None: + continue + + now = time.time() + if first_token_time is None: + first_token_time = now + else: + decode_times.append(now - last_token_time) + last_token_time = now + if text_piece: + generated_text_parts.append(text_piece) + except _TRANSIENT_STREAM_ERRORS as e: + if first_token_time is None and attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + + if first_token_time is not None: + generated_text = "".join(generated_text_parts) + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + print(f"\n[turn warning] {e}; keeping partial turn with estimated usage " f"(attempt={attempt + 1})") + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens or prompt_token_len, + "completion_tokens": completion_tokens or estimated_completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": completion_tokens == 0 or prompt_tokens == 0, + "generated_text": generated_text, + } + + print(f"\n[turn exception] {e}") + return None + except Exception as e: + print(f"\n[turn exception] {e}") + return None + + if first_token_time is None: + if attempt < max_retries: + await asyncio.sleep(0.2 * (attempt + 1)) + continue + return None + + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": False, + "generated_text": "".join(generated_text_parts), + } + + return None async def run_session( @@ -219,17 +352,13 @@ async def run_session( """Run a single multi-turn dialogue session. Returns a list of per-turn stat dicts (same schema as stream_one_turn output).""" rng = random.Random(base_seed + session_id) - prompt, prompt_len = gen_session_initial_prompt( - tokenizer, start_input_len, base_seed + session_id - ) + prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id) per_turn: List[Dict] = [] turn_idx = 0 while turn_idx < max_turns and prompt_len < max_input_len: turn_output_len = rng.randint(min_output_len, output_len) - result = await stream_one_turn( - session, url, model_name, prompt, turn_output_len - ) + result = await stream_one_turn(session, tokenizer, url, model_name, prompt, prompt_len, turn_output_len) if result is None: break per_turn.append(result) @@ -271,7 +400,7 @@ async def run_concurrency_level( ) -> Dict: """Run one concurrency level. Returns the aggregated stats dict.""" timeout = aiohttp.ClientTimeout(total=request_timeout_s) - connector = aiohttp.TCPConnector(limit=max(concurrency * 2, 32)) + connector = aiohttp.TCPConnector(limit=max(concurrency * 2, 32), enable_cleanup_closed=True) progress_state = { "concurrency": concurrency, "finished_turns": 0, @@ -279,7 +408,12 @@ async def run_concurrency_level( } wall_start = time.time() - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + async with aiohttp.ClientSession( + connector=connector, + timeout=timeout, + read_bufsize=_STREAM_READ_BUFSIZE, + max_line_size=_STREAM_MAX_LINE_SIZE, + ) as session: tasks = [ asyncio.create_task( run_session( @@ -345,13 +479,14 @@ def summarize( prompt_tokens = sum(t["prompt_tokens"] for t in turns) completion_tokens = sum(t["completion_tokens"] for t in turns) cached_tokens = sum(t["cached_tokens"] for t in turns) + cached_tokens_reported_turns = sum(1 for t in turns if t.get("cached_tokens_reported")) + usage_estimated_turns = sum(1 for t in turns if t.get("usage_estimated")) total_tokens = prompt_tokens + completion_tokens qps = len(turns) / wall_time tpm_total = total_tokens / wall_time * 60.0 tpm_prompt = prompt_tokens / wall_time * 60.0 tpm_completion = completion_tokens / wall_time * 60.0 - cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 out["QPS"] = round(qps, 4) out["TPM_total"] = round(tpm_total, 2) @@ -360,7 +495,18 @@ def summarize( out["total_prompt_tokens"] = prompt_tokens out["total_completion_tokens"] = completion_tokens out["total_cached_prompt_tokens"] = cached_tokens - out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + out["cached_tokens_reported_turns"] = cached_tokens_reported_turns + out["usage_estimated_turns"] = usage_estimated_turns + if cached_tokens_reported_turns > 0: + cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 + out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + else: + out["cache_hit_ratio"] = None + out["cache_hit_ratio_note"] = ( + "Server did not return usage.prompt_tokens_details.cached_tokens. " + "For vLLM OpenAI-compatible APIs, start the server with " + "--enable-prompt-tokens-details to expose cache-hit stats." + ) out["avg_prompt_tokens_per_turn"] = round(prompt_tokens / len(turns), 2) out["avg_completion_tokens_per_turn"] = round(completion_tokens / len(turns), 2) @@ -382,8 +528,10 @@ def summarize( def print_summary(summary: Dict) -> None: print("=" * 80) - print(f"Concurrency = {summary['concurrency']} sessions = {summary['num_sessions']} " - f"total_turns = {summary['total_turns']} wall_time = {summary['wall_time_s']}s") + print( + f"Concurrency = {summary['concurrency']} sessions = {summary['num_sessions']} " + f"total_turns = {summary['total_turns']} wall_time = {summary['wall_time_s']}s" + ) if "error" in summary: print(f" ERROR: {summary['error']}") return @@ -391,19 +539,31 @@ def print_summary(summary: Dict) -> None: print(f" TPM (total) : {summary['TPM_total']}") print(f" TPM (prompt) : {summary['TPM_prompt']}") print(f" TPM (completion) : {summary['TPM_completion']}") - print(f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " - f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})") + if summary["cache_hit_ratio"] is None: + print(" Cache hit ratio : n/a") + print(f" Cache hit note : {summary['cache_hit_ratio_note']}") + else: + print( + f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " + f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" + ) + if summary.get("usage_estimated_turns"): + print(f" Usage estimated : {summary['usage_estimated_turns']} turns") print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}") print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}") ttft = summary["TTFT_ms"] tpot = summary["TPOT_ms"] - print(f" TTFT ms mean={ttft['mean']} P50={ttft.get('P50')} P90={ttft.get('P90')} " - f"P95={ttft.get('P95')} P99={ttft.get('P99')}") + print( + f" TTFT ms mean={ttft['mean']} P50={ttft.get('P50')} P90={ttft.get('P90')} " + f"P95={ttft.get('P95')} P99={ttft.get('P99')}" + ) if tpot.get("mean") is None: print(f" TPOT ms (n/a: {tpot.get('note')})") else: - print(f" TPOT ms mean={tpot['mean']} P50={tpot.get('P50')} P90={tpot.get('P90')} " - f"P95={tpot.get('P95')} P99={tpot.get('P99')}") + print( + f" TPOT ms mean={tpot['mean']} P50={tpot.get('P50')} P90={tpot.get('P90')} " + f"P95={tpot.get('P95')} P99={tpot.get('P99')}" + ) def main() -> None: @@ -411,9 +571,9 @@ def main() -> None: parser.add_argument( "--url", type=str, - default="http://127.0.0.1:8088/v1/completions", + default="http://127.0.0.1:8000/v1/completions", help="Streaming OpenAI completion endpoint. The benchmark relies on " - "the final SSE `usage` chunk to obtain cached_tokens.", + "the final SSE `usage` chunk to obtain cached_tokens.", ) parser.add_argument("--tokenizer_path", type=str, required=True) parser.add_argument( @@ -428,22 +588,29 @@ def main() -> None: default="1,4,8,16,32,64,128,256", help="Comma-separated list of concurrency levels to sweep.", ) - parser.add_argument("--start_input_len", type=int, default=32768, - help="Initial prompt length in tokens per session.") - parser.add_argument("--max_input_len", type=int, default=163840, - help="Stop a session when its prompt exceeds this length.") - parser.add_argument("--turn_input_increment", type=int, default=2048, - help="Maximum new 'user' tokens sampled after each turn, on top " - "of the model's generated text.") - parser.add_argument("--min_turn_input_increment", type=int, default=512, - help="Minimum new 'user' tokens sampled after each turn.") - parser.add_argument("--output_len", type=int, default=512, - help="Maximum max_new_tokens sampled per turn.") - parser.add_argument("--min_output_len", type=int, default=128, - help="Minimum max_new_tokens sampled per turn.") - parser.add_argument("--max_turns", type=int, default=64, - help="Hard cap on turns per session. The session also stops once " - "prompt length reaches --max_input_len.") + parser.add_argument( + "--start_input_len", type=int, default=32768, help="Initial prompt length in tokens per session." + ) + parser.add_argument( + "--max_input_len", type=int, default=163840, help="Stop a session when its prompt exceeds this length." + ) + parser.add_argument( + "--turn_input_increment", + type=int, + default=2048, + help="Maximum new 'user' tokens sampled after each turn, on top " "of the model's generated text.", + ) + parser.add_argument( + "--min_turn_input_increment", type=int, default=512, help="Minimum new 'user' tokens sampled after each turn." + ) + parser.add_argument("--output_len", type=int, default=512, help="Maximum max_new_tokens sampled per turn.") + parser.add_argument("--min_output_len", type=int, default=128, help="Minimum max_new_tokens sampled per turn.") + parser.add_argument( + "--max_turns", + type=int, + default=64, + help="Hard cap on turns per session. The session also stops once " "prompt length reaches --max_input_len.", + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--request_timeout_s", type=int, default=3600) parser.add_argument( @@ -451,7 +618,7 @@ def main() -> None: type=str, default="", help="If set, append the per-concurrency summary dict to this JSON file. " - "If the file already exists and is non-empty, it is read and printed.", + "If the file already exists and is non-empty, it is read and printed.", ) args = parser.parse_args() @@ -471,12 +638,19 @@ def main() -> None: return seed_all(args.seed) - model_name = args.model_name or args.tokenizer_path + requested_model_name = args.model_name or args.tokenizer_path + model_name, model_name_note = resolve_model_name( + args.url, + requested_model_name, + explicit_model_name=args.model_name is not None, + ) tokenizer = get_tokenizer(args.tokenizer_path) concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] print(f"URL : {args.url}") print(f"Model : {model_name}") + if model_name_note: + print(f"Model note : {model_name_note}") print(f"Concurrency levels : {concurrency_levels}") print(f"start_input_len : {args.start_input_len}") print(f"max_input_len : {args.max_input_len}") @@ -517,6 +691,7 @@ def main() -> None: "config": { "url": args.url, "model_name": model_name, + "requested_model_name": requested_model_name, "tokenizer_path": args.tokenizer_path, "concurrency_levels": concurrency_levels, "start_input_len": args.start_input_len, diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 8249ae2c49..3249ebcbda 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -31,6 +31,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]: lens = [] lens = np.random.randint( @@ -429,7 +436,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py index b5184d3caa..0bab0ae540 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py +++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py @@ -1,7 +1,8 @@ import torch import pytest from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index, paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -41,3 +42,49 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) assert torch.allclose(output.float(), ref.float()) + + +@pytest.mark.parametrize( + "batch, max_seq_len, page_size", + [ + (1, 16, 4), + (8, 32, 4), + (16, 128, 8), + ], +) +def test_paged_repack_kv_index(batch, max_seq_len, page_size, monkeypatch): + def repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, output, page_size): + for b, sl, start in zip(b_req_idx, b_page_len, b_start_loc): + output[start : start + sl] = req_to_token_indexs[b][: sl * page_size : page_size] // page_size + + BATCH, MAX_SEQ_LEN = batch, max_seq_len + max_page_len = (MAX_SEQ_LEN + page_size - 1) // page_size + total_token_len = 2 * MAX_SEQ_LEN + total_page_len = (total_token_len + page_size - 1) // page_size + + req_to_token_indexs = torch.empty((2 * BATCH, total_token_len), dtype=torch.int32, device="cuda") + page_offsets = torch.arange(page_size, dtype=torch.int32, device="cuda") + for row in range(2 * BATCH): + page_ids = torch.arange(row * total_page_len, (row + 1) * total_page_len, dtype=torch.int32, device="cuda") + req_to_token_indexs[row] = (page_ids[:, None] * page_size + page_offsets[None, :]).reshape(-1)[:total_token_len] + + b_req_idx = torch.randperm(BATCH, device="cuda", dtype=torch.int32) + b_seq_len = torch.randint(1, MAX_SEQ_LEN + 1, (BATCH,), device="cuda", dtype=torch.int32) + b_page_len = (b_seq_len + page_size - 1) // page_size + b_start_loc = torch.cat( + [torch.zeros((1,), dtype=torch.int32, device="cuda"), b_page_len[:-1].cumsum(dim=0, dtype=torch.int32)] + ) + + output = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda") + ref = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda") + + monkeypatch.setenv("PAGE_SIZE", str(page_size)) + get_page_size.cache_clear() + try: + repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, ref, page_size) + paged_repack_kv_index(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_page_len, output) + finally: + monkeypatch.delenv("PAGE_SIZE", raising=False) + get_page_size.cache_clear() + + assert torch.equal(output, ref)