Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__pycache__/
.pyc
.codex
build
dist
*.egg-info
Expand Down
30 changes: 30 additions & 0 deletions lightllm/common/basemodel/triton_kernel/mtp_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import triton
import triton.language as tl
import torch
Expand Down Expand Up @@ -93,10 +94,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids(
req_to_next_token_ids_stride,
all_next_token_ids,
all_next_token_ids_stride,
req_to_next_token_probs,
req_to_next_token_probs_stride,
all_next_token_probs,
all_next_token_probs_stride,
mtp_accept_len,
b_req_mtp_start_loc,
b_req_idx,
mtp_step,
HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):

Expand All @@ -106,6 +112,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids(
cur_req_idx = tl.load(b_req_idx + req_start_loc)
offset = tl.arange(0, BLOCK_SIZE)

if HAS_HAS_NEXT_TOKEN_PROBS:
cur_next_token_probs = tl.load(
all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset,
mask=offset < mtp_step,
other=0.0,
)
tl.store(
req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset,
cur_next_token_probs,
mask=offset < mtp_step,
)
scatter_next_token_ids = tl.load(
all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset,
mask=offset < mtp_step,
Expand All @@ -125,23 +142,36 @@ def mtp_scatter_next_token_ids(
all_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
mtp_accept_len: torch.Tensor,
req_to_next_token_probs: Optional[torch.Tensor] = None,
all_next_token_probs: Optional[torch.Tensor] = None,
):
max_mtp_step = req_to_next_token_ids.shape[1]
BLOCK_SIZE = 16
assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}"
num_reqs = b_req_mtp_start_loc.shape[0]
mtp_step = all_next_token_ids.shape[1]
if req_to_next_token_probs is not None:
assert all_next_token_probs is not None
assert all_next_token_probs.shape == all_next_token_ids.shape

HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None

grid = (num_reqs,)
num_warps = 1
_fwd_kernel_mtp_scatter_next_token_ids[grid](
req_to_next_token_ids=req_to_next_token_ids,
req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
all_next_token_ids=all_next_token_ids,
all_next_token_ids_stride=all_next_token_ids.stride(0),
req_to_next_token_probs=req_to_next_token_probs,
req_to_next_token_probs_stride=req_to_next_token_probs.stride(0),
all_next_token_probs=all_next_token_probs,
all_next_token_probs_stride=all_next_token_probs.stride(0),
mtp_accept_len=mtp_accept_len,
b_req_mtp_start_loc=b_req_mtp_start_loc,
b_req_idx=b_req_idx,
mtp_step=mtp_step,
HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=1,
Expand Down
14 changes: 13 additions & 1 deletion lightllm/common/req_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

Expand Down Expand Up @@ -116,6 +116,15 @@ def __init__(self, max_request_num):
dtype=torch.int64,
device="cuda",
)
if enable_dynamic_mtp_verify():
self.req_to_next_token_probs = torch.zeros(
(max_request_num + 1, 16),
dtype=torch.float32,
device="cuda",
)
else:
self.req_to_next_token_probs = None

self.req_to_exponential_decay_length_penalty = torch.zeros(
max_request_num + 1, dtype=torch.float32, device="cuda"
)
Expand All @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req):

shm_param = req.sampling_param.shm_param
self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token())
if enable_dynamic_mtp_verify():
self.req_to_next_token_probs[req.req_idx].fill_(0.0)
self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0)
self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)
Expand Down
5 changes: 1 addition & 4 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,7 @@ def __init__(
# mtp_step 用来记录一个请求 draft模型每步需要生成的token数量
# 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量
self.mtp_step: int = get_env_start_args().mtp_step
# current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step)
# 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值
# 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度
self.current_mtp_step: int = self.mtp_step

if self.mtp_step > 0:
self.decode_need_token_num = self._mtp_decode_need_token_num
else:
Expand Down
19 changes: 12 additions & 7 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import threading
import torch.distributed as dist
import collections
from typing import List, Tuple, Callable, Optional
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
Expand Down Expand Up @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio(

return

def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]):
def _update_mtp_verify_token_num(
self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None
):
if self.is_master_in_dp:
for req in decode_reqs:
# 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token
# 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step
# current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。
assert req.current_mtp_step >= 0
req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step)
if dynamic_mtp_run_reqs is None:
for req in decode_reqs:
assert req.mtp_step > 0
req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step)
else:
counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs])
for req in decode_reqs:
req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1)
return

def _gen_argmax_token_ids(self, model_output: ModelOutput):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify
from .control_state import ControlState
from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner

logger = init_logger(__name__)

Expand All @@ -49,6 +50,9 @@ def __init__(self) -> None:
self.prefill = self.prefill_normal
self.decode = self.decode_normal

if self.enable_dynamic_mtp:
self.dynamic_mtp_planner = DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step)

self.classed_req_strict_prefill = False
return

Expand Down Expand Up @@ -236,17 +240,32 @@ def decode_mtp(
model_input, run_reqs = prepare_decode_inputs(decode_reqs)

with torch.cuda.stream(g_infer_context.get_overlap_stream()):
b_mtp_index_cpu = model_input.b_mtp_index

if self.enable_dynamic_mtp:
dynamic_batch_size = self.dynamic_mtp_planner.get_dynamic_batch_size(
req_num=len(decode_reqs),
original_batch_size=model_input.batch_size,
)
trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input
model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size)
# selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。

selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="selected_run_reqs",
gpu_tensor=selected_run_reqs,
)
trans_dynamic_model_input_event = torch.cuda.Event()
trans_dynamic_model_input_event.record()

start_time_event = torch.cuda.Event(enable_timing=True)
start_time_event.record()

model_output = self.model.forward(model_input)
next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
# verify the next_token_ids
b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list(
key="b_req_mtp_start_loc",
data=b_req_mtp_start_loc,
dtype=torch.int32,
).cuda(non_blocking=True)

get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc
b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs))
# b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs)
mtp_accept_len, accepted_index = self._verify_mtp_v2(
new_next_token_ids=next_token_ids,
b_req_idx=model_input.b_req_idx,
Expand All @@ -257,7 +276,7 @@ def decode_mtp(
gpu_tensor=accepted_index,
)

verify_event = torch.cuda.Event()
verify_event = torch.cuda.Event(enable_timing=True)
verify_event.record()

if self.enable_dynamic_mtp:
Expand All @@ -277,22 +296,29 @@ def decode_mtp(

# dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size
if self.enable_dynamic_mtp:
draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0])
draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(
self.mtp_step, model_input.b_mtp_index.shape[0]
)
dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor)
# 异步拷贝回 CPU Pin Memory
dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu
)

dynamic_mtp_event = torch.cuda.Event()
dynamic_mtp_event.record()
draft_probs_list = [e.view(-1, 1) for e in draft_probs_list]
draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list
all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1]
else:
all_next_token_probs = None

mtp_scatter_next_token_ids(
req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
b_req_mtp_start_loc=b_req_mtp_start_loc,
all_next_token_ids=all_next_token_ids,
b_req_idx=model_input.b_req_idx,
mtp_accept_len=mtp_accept_len,
req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs,
all_next_token_probs=all_next_token_probs,
)

next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
Expand All @@ -315,22 +341,42 @@ def decode_mtp(

# 第二阶段
event_pack.notify_post_handle_and_wait_pre_post_handle()
self._update_mtp_verify_token_num(decode_reqs=decode_reqs)

if self.enable_dynamic_mtp:
trans_dynamic_model_input_event.synchronize()
selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy()
run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1]

if self.enable_dynamic_mtp:
self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs)
else:
self._update_mtp_verify_token_num(decode_reqs=decode_reqs)

verify_event.synchronize()
accepted_index_cpu_numpy = accepted_index_cpu.numpy()
verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1]
if self.enable_dynamic_mtp:
dynamic_mtp_event.synchronize()
self._update_dynamic_mtp_size_cpu_part(
run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu
)
update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)

# 第三阶段
event_pack.notify_forward_and_wait_post_handle()
sync_event.synchronize()

if self.enable_dynamic_mtp:
# 更新 动态verify token 数据到 planner 中去
self._update_dynamic_mtp_size_statics(
run_reqs=run_reqs,
dynamic_sizes_cpu=dynamic_sizes_cpu,
accepted_index_cpu=accepted_index_cpu,
)
# 更新单token的速度信息到 planner 中去
per_token_cost_ms = start_time_event.elapsed_time(verify_event) / (mtp_accept_len_cpu.sum().item())

self.dynamic_mtp_planner.update_req_num_speed_statics(
req_num=len(decode_reqs),
dynamic_batch_size=dynamic_batch_size,
per_token_cost_ms=per_token_cost_ms,
)

# 处理需要释放的内存索引
need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0]
if additional_mem_indexes_cpu is not None:
Expand Down Expand Up @@ -365,17 +411,22 @@ def _compute_dynamic_mtp_size_gpu_part(
dynamic_mtp_sizes = valid_steps.sum(dim=0)
return dynamic_mtp_sizes

def _update_dynamic_mtp_size_cpu_part(
def _update_dynamic_mtp_size_statics(
self,
run_reqs: List[InferReq],
dynamic_sizes_cpu: torch.Tensor,
accepted_index_cpu: torch.Tensor,
):
id_to_verify_len = {}

assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0]
for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()):
if int(accepted) == 1:
req.current_mtp_step = int(new_size)
assert req.current_mtp_step <= req.mtp_step
assert int(new_size) <= req.mtp_step
id_to_verify_len[req.req_idx] = int(new_size) + 1

self.dynamic_mtp_planner.update_req_verify_len_statics(verify_lens=list(id_to_verify_len.values()))
return

def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor):
# spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。
Expand Down
Loading