Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
21c3eeb
support 31B
WANDY666 Apr 30, 2026
99b790c
fix
WANDY666 May 6, 2026
4c30c73
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 6, 2026
15a5379
support moe
WANDY666 May 7, 2026
83f4983
support e4b (PLE and shared_kv)
WANDY666 May 9, 2026
d969a5f
support visual module
WANDY666 May 11, 2026
08f066d
optimize sliding window
WANDY666 May 12, 2026
7678de8
fix
WANDY666 May 12, 2026
63c658a
simplify
WANDY666 May 13, 2026
300e577
minor improvements
WANDY666 May 13, 2026
50822f0
fix
WANDY666 May 13, 2026
b4b13cc
fix attention cuda graph
WANDY666 May 13, 2026
f19074b
fused gelu gate up
WANDY666 May 14, 2026
5b61450
add out_dtype
WANDY666 May 14, 2026
c0ca212
minor improvements
WANDY666 May 14, 2026
9499a00
fix eos_token_ids
WANDY666 May 14, 2026
de7e220
for HF format
WANDY666 May 14, 2026
bfc59ff
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 14, 2026
109d27c
fix window_size
WANDY666 May 14, 2026
2ea258e
fix window_size
WANDY666 May 14, 2026
b297af5
fix
WANDY666 May 14, 2026
7a81e85
add reasoning_parser for gemma4
WANDY666 May 15, 2026
d619534
[fix]ple support cudagraph
WANDY666 May 16, 2026
c2578c0
fix PLE illegal memory access
WANDY666 May 18, 2026
d744cbc
support sliding_window_right
WANDY666 May 18, 2026
05a0db8
fix notes
WANDY666 May 18, 2026
6f1bd2e
tune in H200
WANDY666 May 19, 2026
90643db
fix
hiworldwzj May 19, 2026
a2b74ab
fix
hiworldwzj May 19, 2026
e606e05
fix
hiworldwzj May 19, 2026
7354da2
fix
hiworldwzj May 20, 2026
afa0194
fix
hiworldwzj May 20, 2026
46ce6af
fix
hiworldwzj May 20, 2026
0188c10
fix
hiworldwzj May 20, 2026
393ec69
fix
hiworldwzj May 20, 2026
e96c2b7
fix
WANDY666 May 20, 2026
f806326
fix
hiworldwzj May 20, 2026
c5b2b81
fix
hiworldwzj May 20, 2026
3bd46d7
fix
hiworldwzj May 20, 2026
91051f0
Merge branch 'support_gemma4' of https://github.com/ModelTC/LightLLM …
WANDY666 May 20, 2026
fb75045
fix
WANDY666 May 20, 2026
7c664c3
fix
hiworldwzj May 20, 2026
0d35e8b
fix
hiworldwzj May 20, 2026
74a4b1f
fix
hiworldwzj May 20, 2026
d2df0a0
fix
hiworldwzj May 20, 2026
3491641
fix
hiworldwzj May 20, 2026
c8812f2
fix
hiworldwzj May 20, 2026
8f160b5
fix
hiworldwzj May 20, 2026
ee92fee
fix
hiworldwzj May 21, 2026
131a163
fix
hiworldwzj May 21, 2026
6d7729f
fix
hiworldwzj May 21, 2026
87da477
fix
WANDY666 May 21, 2026
819497c
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 21, 2026
c57e062
format
WANDY666 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions lightllm/common/basemodel/attention/triton/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ def prefill_att(
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.use_sliding_window is False and att_control.use_att_sink is False
if att_control.use_alibi:
assert att_control.use_sliding_window is False, "alibi + sliding_window not supported"
assert att_control.tp_alibi is not None
return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:
return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func)
return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

def _alibi_prefill_att(
self,
Expand Down Expand Up @@ -59,9 +59,21 @@ def _alibi_prefill_att(
)
return out

def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty):
def _nomarl_prefill_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd

if att_control.use_sliding_window:
sliding_window = att_control.sliding_window
else:
sliding_window = (-1, -1)

out = alloc_func(q.shape, q.dtype)
context_attention_fwd(
q,
Expand All @@ -74,6 +86,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
self.infer_state.b_ready_cache_len,
self.infer_state.max_q_seq_len,
self.infer_state.req_manager.req_to_token_indexs,
sliding_window=sliding_window,
)
return out

Expand All @@ -94,17 +107,20 @@ def decode_att(
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
assert att_control.use_sliding_window is False and att_control.use_att_sink is False
if att_control.use_alibi:
assert att_control.use_sliding_window is False, "alibi + sliding_window not supported"
assert att_control.tp_alibi is not None
return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:
q_head_num = q.shape[1]
k_head_num = k.shape[1]
if q_head_num == k_head_num:
assert att_control.use_sliding_window is False, "sliding_window not supported in non-gqa attention yet"
return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num > k_head_num:
return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
return self._normal_decode_gqa_flash_decoding_att(
q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func
)
else:
raise NotImplementedError("error")

Expand Down Expand Up @@ -163,12 +179,18 @@ def _normal_decode_gqa_flash_decoding_att(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import (
gqa_token_decode_attention_flash_decoding,
)

if att_control.use_sliding_window:
sliding_window = att_control.sliding_window
else:
sliding_window = (-1, -1)

out = alloc_func(q.shape, q.dtype)

gqa_token_decode_attention_flash_decoding(
Expand All @@ -178,6 +200,7 @@ def _normal_decode_gqa_flash_decoding_att(
cache_v=v,
out=out,
alloc_tensor_func=alloc_func,
sliding_window=sliding_window,
)

return out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ def _context_attention_wrapper_run(
) -> torch.Tensor:
if torch.cuda.is_current_stream_capturing():
q = q.contiguous()
cache_kv = cache_kv.contiguous()
_q, _cache_kv = (
tensor_to_no_ref_tensor(q),
tensor_to_no_ref_tensor(cache_kv),
)
# cache_kv is None for layers that own no K/V slot (e.g. gemma4
# KV-shared layers, which read K/V from a prior layer's cache and
# ignore this arg in _context_attention_kernel). Skip the
# graph-input plumbing for it instead of crashing on None.
cache_kv = cache_kv.contiguous() if cache_kv is not None else None
_q = tensor_to_no_ref_tensor(q)
_cache_kv = tensor_to_no_ref_tensor(cache_kv) if cache_kv is not None else None
pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
pre_capture_graph.__exit__(None, None, None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
per_expert_scale_name: str = "",
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
self.w2_weight_name = down_proj_name
self.w3_weight_name = up_proj_name
self.e_score_correction_bias_name = e_score_correction_bias_name
# gemma4 的专家计算出的值都需要一个 scale 值,每个专家有自己独立的scale参数
# per_expert_scale_name 是专家的scale参数权重的名称, 为 "" 表示没有专家独立的scale参数
self.per_expert_scale_name = per_expert_scale_name
self.weight_prefix = weight_prefix
self.layer_num_ = layer_num
self.global_rank_ = get_global_rank()
Expand Down Expand Up @@ -145,6 +149,7 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
per_expert_scale=self.per_expert_scale,
)

def low_latency_dispatch(
Expand Down Expand Up @@ -261,25 +266,42 @@ def combine(

def load_hf_weights(self, weights):
# Load bias
if self.e_score_correction_bias_name in weights:
self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name])
self._load_e_score_correction_bias(weights)
self._load_per_expert_scale(weights)
self._load_weight(self.expert_idx_to_local_idx, weights)
if self.redundancy_expert_num > 0:
self._load_weight(self.redundancy_expert_idx_to_local_idx, weights)

def verify_load(self):
return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list)
weight_load_ok = all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list)
per_expert_scale_load_ok = (
True if self.per_expert_scale is None else getattr(self.per_expert_scale, "load_ok", False)
)
e_score_correction_bias_load_ok = (
True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False)
)
return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok

def _create_weight(self):
intermediate_size = self.split_inter_size
self.e_score_correction_bias = None
self.per_expert_scale = None
# Create e_score_correction_bias
if self.e_score_correction_bias_name:
self.e_score_correction_bias = torch.empty(
(self.n_routed_experts,),
dtype=self.data_type_,
device=f"cuda:{self.device_id_}",
)
self.e_score_correction_bias.load_ok = False

if self.per_expert_scale_name:
self.per_expert_scale = torch.empty(
(self.n_routed_experts,),
dtype=torch.float32,
device=f"cuda:{self.device_id_}",
)
self.per_expert_scale.load_ok = False

self.w13, w13_param_list = self.quant_method.create_moe_weight(
out_dims=[intermediate_size, intermediate_size],
Expand All @@ -299,6 +321,16 @@ def _create_weight(self):
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)

def _load_e_score_correction_bias(self, weights: Dict[str, torch.Tensor]):
if self.e_score_correction_bias_name and self.e_score_correction_bias_name in weights:
self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name])
self.e_score_correction_bias.load_ok = True

def _load_per_expert_scale(self, weights: Dict[str, torch.Tensor]):
if self.per_expert_scale_name and self.per_expert_scale_name in weights:
self.per_expert_scale.copy_(weights[self.per_expert_scale_name].to(self.per_expert_scale.dtype))
self.per_expert_scale.load_ok = True

def _get_expert_weight_list(self, weight_pack: WeightPack):
weight_list = []
for idx in range(self.local_n_routed_experts):
Expand All @@ -307,7 +339,6 @@ def _get_expert_weight_list(self, weight_pack: WeightPack):
return weight_list

def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]):

# Load each expert with TP slicing
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight


class Gemma4PackedFusedMoeWeight(FusedMoeWeight):
def load_hf_weights(self, weights):
# 将权重名称的格式对齐基类的统一加载格式。
gate_up_name = f"{self.weight_prefix}.gate_up_proj"
down_name = f"{self.weight_prefix}.down_proj"
assert not self.enable_ep_moe, "Gemma-4 packed MoE currently supports TP mode only."
moe_intermediate_size = self.moe_intermediate_size

if gate_up_name in weights:
gate_up_weight = weights[gate_up_name]

for expert_idx in range(self.n_routed_experts):
expert_gate_weight = gate_up_weight[expert_idx, :moe_intermediate_size, :]
expert_up_weight = gate_up_weight[expert_idx, moe_intermediate_size:, :]

weights[
f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
] = expert_gate_weight
weights[
f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"
] = expert_up_weight

del weights[gate_up_name]

if down_name in weights:
down_weight = weights[down_name]
for expert_idx in range(self.n_routed_experts):
expert_down_weight = down_weight[expert_idx, :, :]
weights[
f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
] = expert_down_weight
del weights[down_name]

super().load_hf_weights(weights)
return
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _select_experts(
topk_group: int,
num_expert_group: int,
scoring_func: str,
per_expert_scale: Optional[torch.Tensor] = None,
):
"""Select experts and return topk weights and ids."""
from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts
Expand All @@ -48,6 +49,8 @@ def _select_experts(
)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.redundancy_expert_num > 0:
redundancy_topk_ids_repair(
topk_ids=topk_ids,
Expand All @@ -69,7 +72,6 @@ def _fused_experts(
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
):

w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
use_fp8_w8a8 = self.quant_method.method_name != "none"
Expand Down Expand Up @@ -214,7 +216,14 @@ def masked_group_gemm(
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
return masked_group_gemm(
recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m
recv_x,
masked_m,
dtype,
w13_weight,
w13_scale,
w2_weight,
w2_scale,
expected_m=expected_m,
)

def prefilled_group_gemm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AWQMARLINW4A16QuantizationMethod,
)
from typing import Optional
from lightllm.utils.config_utils import ffn_use_tanh_approximate_gelu


class FuseMoeMarlin(FuseMoeTriton):
Expand Down Expand Up @@ -38,6 +39,8 @@ def _fused_experts(

self.quant_method: AWQMARLINW4A16QuantizationMethod = self.quant_method

activation = "silu" if not ffn_use_tanh_approximate_gelu() else "gelu"

fused_marlin_moe(
input_tensor,
w1_weight,
Expand All @@ -52,6 +55,7 @@ def _fused_experts(
quant_type_id=self.quant_method.vllm_quant_type.id,
apply_router_weight_on_input=False,
global_num_experts=-1,
activation=activation,
expert_map=None,
w1_zeros=w1_zero_point,
w2_zeros=w2_zero_point,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _select_experts(
topk_group: int,
num_expert_group: int,
scoring_func: str,
per_expert_scale: Optional[torch.Tensor] = None,
):
"""Select experts and return topk weights and ids."""
from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts
Expand All @@ -59,6 +60,8 @@ def _select_experts(
)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
Expand Down Expand Up @@ -125,6 +128,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
Expand All @@ -136,6 +140,7 @@ def __call__(
topk_group=topk_group,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
per_expert_scale=per_expert_scale,
)
output = self._fused_experts(
input_tensor=input_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@


def gqa_token_decode_attention_flash_decoding(
q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty
q: torch.Tensor,
infer_state,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
out=None,
alloc_tensor_func=torch.empty,
sliding_window=(-1, -1),
):
batch_size = infer_state.batch_size
q_head_num, head_dim = q.shape[1], q.shape[2]
Expand Down Expand Up @@ -39,12 +45,14 @@ def gqa_token_decode_attention_flash_decoding(
mid_out=mid_o,
mid_out_logsumexp=mid_o_logexpsum,
block_seq=BLOCK_SEQ,
sliding_window=sliding_window,
)
flash_decode_stage2(
mid_out=mid_o,
mid_out_logexpsum=mid_o_logexpsum,
B_Seqlen=infer_state.b_seq_len,
out=o_tensor.view(calcu_shape1),
block_seq=BLOCK_SEQ,
sliding_window=sliding_window,
)
return o_tensor
Loading
Loading