diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index b4cf2fb4462..a3eb59f4685 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -405,12 +405,15 @@ async def chat_completion_stream_generator( output_speculate_metrics = res["metrics"].get("speculate_metrics", None) delta_message = DeltaMessage( - reasoning_content="", + reasoning_content=output["reasoning_content"], prompt_token_ids=None, - tool_calls=None, + tool_calls=output["tool_calls"], completion_token_ids=None, ) + if output["tool_calls"] is not None: + tool_called[idx] = True + if response_processor.enable_multimodal_content(): delta_message.multimodal_content = output["multipart"] else: @@ -419,15 +422,8 @@ async def chat_completion_stream_generator( if output.get("audio_content", None) is not None: delta_message.audio_content = output["audio_content"] - if not res["finished"] and output["enable_parser"]: - delta_message_output = output["delta_message"] - if delta_message_output is None: - continue - delta_message.content = delta_message_output.content or "" - delta_message.reasoning_content = delta_message_output.reasoning_content or "" - if delta_message_output.tool_calls: - delta_message.tool_calls = delta_message_output.tool_calls - tool_called[idx] = True + if output["skipped"]: + continue choice = ChatCompletionResponseStreamChoice( index=idx, @@ -758,7 +754,7 @@ async def _create_chat_completion_choice( message = ChatMessage( role="assistant", reasoning_content=output.get("reasoning_content"), - tool_calls=output.get("tool_call"), + tool_calls=output.get("tool_calls"), prompt_token_ids=prompt_token_ids if request.return_token_ids else None, completion_token_ids=completion_token_ids if request.return_token_ids else None, prompt_tokens=prompt_tokens if request.return_token_ids else None, @@ -790,7 +786,7 @@ async def _create_chat_completion_choice( finish_reason = "stop" if previous_num_tokens != max_tokens: finish_reason = "stop" - if output.get("tool_call", None): + if output.get("tool_calls"): finish_reason = "tool_calls" else: finish_reason = "length" diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 03c074db841..4c4ccad3fac 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -407,7 +407,7 @@ async def _process_echo_logic(self, request, idx, res_outputs): def calc_finish_reason(self, max_tokens, token_num, output, tool_called): if max_tokens is None or token_num != max_tokens: - if tool_called or output.get("tool_call"): + if tool_called or output.get("tool_calls"): return "tool_calls" else: return "stop" @@ -554,9 +554,9 @@ async def completion_stream_generator( text=output["text"], prompt_token_ids=None, completion_token_ids=output.get("token_ids") if request.return_token_ids else None, - tool_calls=None, + tool_calls=output["tool_calls"], completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, - reasoning_content="", + reasoning_content=output["reasoning_content"], arrival_time=arrival_time, logprobs=logprobs_res, prompt_logprobs=( @@ -565,15 +565,12 @@ async def completion_stream_generator( draft_logprobs=draft_logprobs_res, speculate_metrics=output_speculate_metrics, ) - if not res["finished"] and output["enable_parser"]: - delta_message_output = output["delta_message"] - if delta_message_output is None: - continue - 
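With this change the stream generators stop unpacking a nested `delta_message`: every post-processed output dict is expected to already carry `text`, `reasoning_content`, `tool_calls`, and `skipped`, and the serving layer only has to mark the choice as tool-calling or drop buffered chunks. A minimal standalone sketch of that consumption contract, using plain dicts in place of the real engine responses (the chunk contents and the `get_weather` tool name below are fabricated for illustration):

```python
def consume_stream(outputs):
    """Mirror the new generator logic: track tool calls, drop buffered chunks."""
    deltas, tool_called = [], False
    for output in outputs:
        if output["tool_calls"] is not None:   # the tool parser emitted a call delta
            tool_called = True
        if output["skipped"]:                  # parser is still buffering; emit nothing
            continue
        deltas.append((output["text"], output["reasoning_content"], output["tool_calls"]))
    return deltas, tool_called


chunks = [
    {"text": "", "reasoning_content": "thinking...", "tool_calls": None, "skipped": False},
    {"text": "", "reasoning_content": "", "tool_calls": None, "skipped": True},
    {"text": "", "reasoning_content": "", "skipped": False,
     "tool_calls": [{"index": 0, "type": "function",
                     "function": {"name": "get_weather", "arguments": "{}"}}]},
]

deltas, tool_called = consume_stream(chunks)
print(tool_called)   # True -> finish_reason will later be "tool_calls"
print(len(deltas))   # 2    -> the buffered chunk was skipped
```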
delta_message.text = delta_message_output.content or "" - delta_message.reasoning_content = delta_message_output.reasoning_content or "" - if delta_message_output.tool_calls: - delta_message.tool_calls = delta_message_output.tool_calls - tool_called[idx] = True + + if output["tool_calls"] is not None: + tool_called[idx] = True + + if output["skipped"]: + continue choices.append(delta_message) @@ -740,7 +737,7 @@ def request_output_to_completion_response( else None ), reasoning_content=output.get("reasoning_content"), - tool_calls=output.get("tool_call", None), + tool_calls=output.get("tool_calls"), logprobs=aggregated_logprobs, draft_logprobs=aggregated_draft_logprobs, prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), diff --git a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py index a4df47ac99d..c9b8d250f74 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py @@ -14,8 +14,11 @@ # limitations under the License. """ +from fastdeploy.plugins import load_tool_parser_plugins + from .abstract_tool_parser import ToolParser, ToolParserManager from .ernie_45_vl_thinking_tool_parser import Ernie45VLThinkingToolParser from .ernie_x1_tool_parser import ErnieX1ToolParser __all__ = ["ToolParser", "ToolParserManager", "ErnieX1ToolParser", "Ernie45VLThinkingToolParser"] +load_tool_parser_plugins() diff --git a/fastdeploy/input/ernie4_5_processor.py b/fastdeploy/input/ernie4_5_processor.py index eb8bbde6ad5..8a5fa4ce883 100644 --- a/fastdeploy/input/ernie4_5_processor.py +++ b/fastdeploy/input/ernie4_5_processor.py @@ -341,7 +341,7 @@ def process_response_dict_normal(self, response_dict, **kwargs): tool_parser = self.tool_parser_obj(self.tokenizer) tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict) if tool_call_info.tools_called: - response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls + response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls response_dict["outputs"]["text"] = tool_call_info.content response_dict["outputs"]["completion_tokens"] = full_text data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}") @@ -369,7 +369,11 @@ def process_response_dict_streaming(self, response_dict, **kwargs): if token_ids[-1] == self.tokenizer.eos_token_id: token_ids = token_ids[:-1] delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id) + response_dict["outputs"]["text"] = delta_text response_dict["outputs"]["completion_tokens"] = delta_text + response_dict["outputs"]["skipped"] = False + response_dict["outputs"]["tool_calls"] = None + response_dict["outputs"]["reasoning_content"] = "" if self.reasoning_parser: reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming( previous_texts, @@ -380,19 +384,15 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids, self.model_status_dict[req_id], ) - response_dict["outputs"]["enable_parser"] = True - response_dict["outputs"]["delta_message"] = reasoning_delta_message - reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None - reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else [] - response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens) - response_dict["outputs"]["reasoning_content"] = reasoning_content - response_dict["outputs"]["text"] = ( - 
reasoning_delta_message.content or "" - if reasoning_delta_message and hasattr(reasoning_delta_message, "content") - else "" - ) - else: - response_dict["outputs"]["text"] = delta_text + if reasoning_delta_message: + reasoning_content = reasoning_delta_message.reasoning_content + reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else [] + response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens) + response_dict["outputs"]["reasoning_content"] = reasoning_content or "" + response_dict["outputs"]["text"] = reasoning_delta_message.content or "" + else: + if not is_end: + response_dict["outputs"]["skipped"] = True if self.tool_parser_obj: response_dict["outputs"]["enable_parser"] = True if req_id not in self.tool_parser_dict: @@ -407,8 +407,13 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids, response_dict, ) - if tool_call_delta_message is None or tool_call_delta_message.tool_calls: - response_dict["outputs"]["delta_message"] = tool_call_delta_message + if tool_call_delta_message: + if tool_call_delta_message.tool_calls: + response_dict["outputs"]["text"] = tool_call_delta_message.content + response_dict["outputs"]["tool_calls"] = tool_call_delta_message.tool_calls + else: + if not is_end: + response_dict["outputs"]["skipped"] = True if is_end: data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}") diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index f9ef108da6a..c71e71e8b97 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -442,7 +442,7 @@ def process_response_dict_normal(self, response_dict, **kwargs): tool_parser = self.tool_parser_obj(self.tokenizer) tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict) if tool_call_info.tools_called: - response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls + response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls response_dict["outputs"]["text"] = tool_call_info.content data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}") del self.decode_status[req_id] @@ -469,7 +469,11 @@ def process_response_dict_streaming(self, response_dict, **kwargs): if token_ids[-1] in self.eos_token_ids: token_ids = token_ids[:-1] delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id) + response_dict["outputs"]["text"] = delta_text response_dict["outputs"]["completion_tokens"] = delta_text + response_dict["outputs"]["skipped"] = False + response_dict["outputs"]["tool_calls"] = None + response_dict["outputs"]["reasoning_content"] = "" if self.reasoning_parser: response_dict["outputs"]["enable_parser"] = True reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming( @@ -481,16 +485,21 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids, self.model_status_dict[req_id], ) - response_dict["outputs"]["delta_message"] = reasoning_delta_message - reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None - reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else [] - response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens) + if reasoning_delta_message: + reasoning_content = reasoning_delta_message.reasoning_content + reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else [] + 
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens) + response_dict["outputs"]["reasoning_content"] = reasoning_content or "" + response_dict["outputs"]["text"] = reasoning_delta_message.content or "" + else: + if not is_end: + response_dict["outputs"]["skipped"] = True if self.tool_parser_obj: response_dict["outputs"]["enable_parser"] = True if req_id not in self.tool_parser_dict: self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer) tool_parser = self.tool_parser_dict[req_id] - tool_call = tool_parser.extract_tool_calls_streaming( + tool_call_delta_message = tool_parser.extract_tool_calls_streaming( previous_texts, previous_texts + delta_text, delta_text, @@ -499,9 +508,14 @@ def process_response_dict_streaming(self, response_dict, **kwargs): token_ids, response_dict, ) - if tool_call is None or tool_call.tool_calls: - response_dict["outputs"]["delta_message"] = tool_call - response_dict["outputs"]["text"] = delta_text + if tool_call_delta_message: + if tool_call_delta_message.tool_calls: + response_dict["outputs"]["text"] = tool_call_delta_message.content + response_dict["outputs"]["tool_calls"] = tool_call_delta_message.tool_calls + else: + if not is_end: + response_dict["outputs"]["skipped"] = True + if is_end: data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}") del self.decode_status[req_id] diff --git a/fastdeploy/plugins/__init__.py b/fastdeploy/plugins/__init__.py index 08c2922968a..96e30e5a56b 100644 --- a/fastdeploy/plugins/__init__.py +++ b/fastdeploy/plugins/__init__.py @@ -19,6 +19,7 @@ from .model_runner import load_model_runner_plugins from .reasoning_parser import load_reasoning_parser_plugins from .token_processor import load_token_processor_plugins +from .tool_parser import load_tool_parser_plugins __all__ = [ "load_model_register_plugins", @@ -26,4 +27,5 @@ "load_input_processor_plugins", "load_reasoning_parser_plugins", "load_token_processor_plugins", + "load_tool_parser_plugins", ] diff --git a/fastdeploy/plugins/tool_parser/__init__.py b/fastdeploy/plugins/tool_parser/__init__.py new file mode 100644 index 00000000000..19d8f82efe2 --- /dev/null +++ b/fastdeploy/plugins/tool_parser/__init__.py @@ -0,0 +1,34 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from fastdeploy.plugins.utils import load_plugins_by_group + +# make sure one process only loads plugins once +plugins_loaded = False +PLUGINS_GROUP = "fastdeploy.tool_parser_plugins" + + +def load_tool_parser_plugins(): + """load_tool_parser_plugins""" + global plugins_loaded + if plugins_loaded: + return + plugins_loaded = True + + plugins = load_plugins_by_group(group=PLUGINS_GROUP) + # general plugins, we only need to execute the loaded functions + for func in plugins.values(): + func() diff --git a/tests/entrypoints/openai/test_finish_reason.py b/tests/entrypoints/openai/test_finish_reason.py index f5b318b52ac..d7b1c151aae 100644 --- a/tests/entrypoints/openai/test_finish_reason.py +++ b/tests/entrypoints/openai/test_finish_reason.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -14,954 +14,648 @@ # limitations under the License. """ -import asyncio -import inspect -import itertools -import time -import traceback -import uuid -from collections.abc import Iterable -from typing import List, Optional +import json +from typing import Any, Dict, List +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, Mock, patch import numpy as np -import fastdeploy.envs as envs -import fastdeploy.metrics.trace as tracing -from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.entrypoints.openai.protocol import ( - CompletionLogprobs, + ChatCompletionRequest, CompletionRequest, CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - CompletionTokenUsageInfo, - ErrorInfo, - ErrorResponse, - PromptTokenUsageInfo, UsageInfo, ) -from fastdeploy.trace.constants import LoggingEventName -from fastdeploy.trace.trace_logger import print as trace_print -from fastdeploy.utils import ( - ErrorCode, - ErrorType, - ParameterError, - api_server_logger, - clamp_prompt_logprobs, - get_host_ip, -) -from fastdeploy.worker.output import ( - Logprob, - LogprobsLists, - LogprobsTensors, - PromptLogprobs, -) +from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat +from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion +from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor +from fastdeploy.utils import data_processor_logger + + +class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase): + async def asyncSetUp(self): + with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None): + self.multi_modal_processor = Ernie4_5_VLProcessor("model_path") + self.multi_modal_processor.tokenizer = MagicMock() + self.multi_modal_processor.tokenizer.eos_token_id = 102 + self.multi_modal_processor.tokenizer.pad_token_id = 0 + self.multi_modal_processor.eos_token_ids = [102] + self.multi_modal_processor.eos_token_id_len = 1 + self.multi_modal_processor.generation_config = MagicMock() + self.multi_modal_processor.decode_status = {} + self.multi_modal_processor.tool_parser_dict = {} + self.multi_modal_processor.ernie4_5_processor = MagicMock() + self.multi_modal_processor.ernie4_5_processor.request2ids.return_value = { + "input_ids": np.array([101, 9012, 3456, 102]) + } + 
self.multi_modal_processor.ernie4_5_processor.text2ids.return_value = { + "input_ids": np.array([101, 1234, 5678, 102]) + } + self.multi_modal_processor._apply_default_parameters = lambda x: x + self.multi_modal_processor.update_stop_seq = Mock(return_value=([], [])) + self.multi_modal_processor.update_bad_words = Mock(return_value=[]) + self.multi_modal_processor._check_mm_limits = Mock() + self.multi_modal_processor.append_completion_tokens = Mock() + self.multi_modal_processor.pack_outputs = lambda x: x + self.multi_modal_processor.reasoning_parser = None + self.multi_modal_processor.model_status_dict = {} + + self.engine_client = Mock() + self.engine_client.connection_initialized = False + self.engine_client.connection_manager = AsyncMock() + self.engine_client.semaphore = Mock() + self.engine_client.semaphore.acquire = AsyncMock() + self.engine_client.semaphore.release = Mock() + self.engine_client.is_master = True + self.engine_client.check_model_weight_status = Mock(return_value=False) + self.engine_client.enable_mm = True + self.engine_client.enable_prefix_caching = False + self.engine_client.max_model_len = 20 + self.engine_client.data_processor = self.multi_modal_processor + + async def mock_add_data(current_req_dict): + if current_req_dict.get("max_tokens") is None: + current_req_dict["max_tokens"] = self.engine_client.max_model_len - 1 + current_req_dict["max_tokens"] = min( + self.engine_client.max_model_len - 4, max(0, current_req_dict.get("max_tokens")) + ) + + self.engine_client.add_requests = AsyncMock(side_effect=mock_add_data) + + self.chat_serving = OpenAIServingChat( + engine_client=self.engine_client, + models=None, + pid=123, + ips=None, + max_waiting_time=30, + chat_template="default", + enable_mm_output=True, + tokenizer_base_url=None, + ) + self.completion_serving = OpenAIServingCompletion( + engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30 + ) -NONES = itertools.repeat(None) + def _generate_inference_response( + self, request_id: str, output_token_num: int, tool_call: Any = None + ) -> List[Dict]: + outputs = { + "text": "这是一张风景图"[:output_token_num], + "token_ids": list(range(output_token_num)), + "reasoning_content": "推理过程", + "num_image_tokens": 0, + "num_cached_tokens": 0, + "top_logprobs": None, + "draft_top_logprobs": None, + "tool_call": None, + } + if tool_call: + outputs["tool_calls"] = [ + {"index": 0, "type": "function", "function": {"name": tool_call["name"], "arguments": json.dumps({})}} + ] -class OpenAIServingCompletion: - def __init__(self, engine_client, models, pid, ips, max_waiting_time): - self.engine_client = engine_client - self.models = models - self.pid = pid - self.max_waiting_time = max_waiting_time - if ips is not None: - if isinstance(ips, list): - self.master_ip = ips[0] + return [ + { + "request_id": request_id, + "outputs": outputs, + "metrics": {"request_start_time": 0.1}, + "finished": True, + "error_msg": None, + "output_token_ids": output_token_num, + } + ] + + def _generate_stream_inference_response( + self, request_id: str, total_token_num: int, tool_call: Any = None + ) -> List[Dict]: + stream_responses = [] + for i in range(total_token_num): + metrics = {} + if i == 0: + metrics["first_token_time"] = 0.1 + metrics["inference_start_time"] = 0.1 else: - self.master_ip = ips.split(",")[0] - self.is_master_ip = get_host_ip() == self.master_ip - else: - self.master_ip = "0.0.0.0" - self.is_master_ip = True - self._is_process_response_dict_async = None - api_server_logger.info(f"master ip: 
{self.master_ip}") - - def _check_master(self): - return self.engine_client.is_master or self.is_master_ip - - async def create_completion(self, request: CompletionRequest): - """ - Create a completion for the given prompt. - """ - tracing.trace_set_thread_info("API Server") - if not self._check_master(): - err_msg = ( - f"Only master node can accept completion request, please send request to master node: {self.master_ip}" - ) - api_server_logger.error(err_msg) - return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.INTERNAL_ERROR)) - if self.models: - is_supported, request.model = self.models.is_supported_model(request.model) - if not is_supported: - err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default" - api_server_logger.error(err_msg) - return ErrorResponse( - error=ErrorInfo(message=err_msg, type=ErrorType.INTERNAL_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT) + metrics["engine_recv_latest_token_time"] = 0.1 * (i + 1) + metrics["first_token_time"] = None + + if i == total_token_num - 1: + metrics["request_start_time"] = 0.1 + + outputs = { + "text": chr(ord("a") + i), + "token_ids": [i + 1], + "top_logprobs": None, + "draft_top_logprobs": None, + "reasoning_token_num": 0, + "skipped": False, + "reasoning_content": "", + "tool_calls": None, + } + + if tool_call and isinstance(tool_call, dict) and i == total_token_num - 2: + outputs["tool_calls"] = [ + { + "index": 0, + "type": "function", + "function": {"name": tool_call["name"], "arguments": json.dumps({})}, + } + ] + + frame = [ + { + "request_id": f"{request_id}_0", + "error_code": 200, + "outputs": outputs, + "metrics": metrics, + "finished": (i == total_token_num - 1), + "error_msg": None, + } + ] + stream_responses.append(frame) + return stream_responses + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor") + @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger") + async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class, mock_api_logger): + test_cases = [ + { + "name": "用户传max_tokens=5,生成数=5→length", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + max_tokens=5, + return_token_ids=True, + ), + "output_token_num": 5, + "tool_call": [], + "expected_finish_reason": "length", + }, + { + "name": "用户未传max_tokens,生成数=10→stop", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + return_token_ids=True, + ), + "output_token_num": 10, + "tool_call": [], + "expected_finish_reason": "stop", + }, + { + "name": "用户未传max_tokens,生成数=16→length", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + return_token_ids=True, + ), + "output_token_num": 16, + "tool_call": [], + "expected_finish_reason": "length", + }, + { + "name": "用户传max_tokens,生成数=10→stop", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + max_tokens=50, + return_token_ids=True, + ), + "output_token_num": 10, + "tool_call": [], + "expected_finish_reason": "stop", + }, + { + "name": "生成数 0, "prompt_token_ids should not be an empty list" - if isinstance(request.prompt_token_ids[0], list): - request_prompt_ids = request.prompt_token_ids - elif isinstance(request.prompt_token_ids[0], int): 
- request_prompt_ids = [request.prompt_token_ids] - else: - raise ValueError( - "If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]" - ) - # reset `prompt_token_ids` to avoid data processor directly using it; let data processor fill it - request.prompt_token_ids = None - else: - if isinstance(request.prompt, str): - request_prompts = [request.prompt] - elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt): - request_prompt_ids = [request.prompt] - elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt): - request_prompts = request.prompt - elif isinstance(request.prompt, list): - for item in request.prompt: - if isinstance(item, list) and all(isinstance(x, int) for x in item): - continue - else: - raise ValueError("If prompt is a list, each item type must be one of: str, list[int]") - request_prompt_ids = request.prompt - else: - raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]") - except Exception as e: - error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}" - api_server_logger.error(error_msg) - return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) - - if request_prompt_ids is not None: - request_prompts = request_prompt_ids - - num_choices = len(request_prompts) * (1 if request.n is None else request.n) - api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}") - prompt_batched_token_ids = [] - prompt_tokens_list = [] - max_tokens_list = [] - try: - if self.max_waiting_time < 0: - await self.engine_client.semaphore.acquire() - else: - await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time) - except Exception as e: - error_msg = ( - f"OpenAIServingCompletion waiting error: {e}, {str(traceback.format_exc())}, " - f"max waiting time: {self.max_waiting_time}" - ) - api_server_logger.error(error_msg) - return ErrorResponse( - error=ErrorInfo(message=error_msg, code=ErrorCode.TIMEOUT, type=ErrorType.TIMEOUT_ERROR) - ) - try: - try: - for idx, prompt in enumerate(request_prompts): - request_id_idx = f"{request_id}_{idx}" - if not envs.ENABLE_V1_DATA_PROCESSOR: - current_req_dict = request.to_dict_for_infer(request_id_idx, prompt) - else: - current_req_dict = Request.from_generic_request(request, request_id=f"{request_id}_0") - current_req_dict["metrics"]["arrival_time"] = time.time() - prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict) # tokenize - if isinstance(prompt_token_ids, np.ndarray): - prompt_token_ids = prompt_token_ids.tolist() - prompt_tokens_list.append(current_req_dict.get("prompt_tokens")) - prompt_batched_token_ids.append(prompt_token_ids) - max_tokens_list.append(current_req_dict.get("max_tokens")) - del current_req_dict - except ParameterError as e: - api_server_logger.error(f"OpenAIServingCompletion format error: {e}, {e.message}") - self.engine_client.semaphore.release() - return ErrorResponse( - error=ErrorInfo(code="400", message=str(e.message), type="invalid_request", param=e.param) + result = await self.chat_serving.chat_completion_full_generator( + request=case["request"], + request_id="test_chat", + model_name="ernie4.5-vl", + prompt_token_ids=processed_req["prompt_token_ids"], + prompt_tokens="描述这张图片", + max_tokens=processed_req["max_tokens"], ) - except Exception as e: - error_msg = f"OpenAIServingCompletion format error: 
{e}, {str(traceback.format_exc())}" - api_server_logger.error(error_msg) - self.engine_client.semaphore.release() - return ErrorResponse( - error=ErrorInfo(message=str(e), code=ErrorCode.INVALID_VALUE, type=ErrorType.INVALID_REQUEST_ERROR) + self.assertEqual( + result.choices[0].finish_reason, case["expected_finish_reason"], f"场景 {case['name']} 失败" ) - if request.stream: - return self.completion_stream_generator( - request=request, - num_choices=num_choices, - request_id=request_id, - created_time=created_time, - model_name=request.model, - prompt_batched_token_ids=prompt_batched_token_ids, - prompt_tokens_list=prompt_tokens_list, - max_tokens_list=max_tokens_list, + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger") + async def test_completion_full_max_tokens(self, mock_api_logger, mock_data_logger): + test_cases = [ + { + "name": "用户传max_tokens=6,生成数=6→length", + "request": CompletionRequest( + request_id="test_completion", + model="ernie4.5-vl", + prompt="描述这张图片:xxx", + stream=False, + max_tokens=6, + return_token_ids=True, + ), + "output_token_num": 6, + "expected_finish_reason": "length", + }, + { + "name": "用户未传max_tokens,生成数=12→stop", + "request": CompletionRequest( + request_id="test_completion", + model="ernie4.5-vl", + prompt="描述这张图片:xxx", + stream=False, + return_token_ids=True, + ), + "output_token_num": 12, + "expected_finish_reason": "stop", + }, + { + "name": "用户传max_tokens=20(修正为16),生成数=16→length", + "request": CompletionRequest( + request_id="test_completion", + model="ernie4.5-vl", + prompt="描述这张图片:xxx", + stream=False, + max_tokens=20, + return_token_ids=True, + ), + "output_token_num": 16, + "expected_finish_reason": "length", + }, + ] + + mock_dealer = Mock() + self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock())) + + for case in test_cases: + with self.subTest(case=case["name"]): + request_dict = { + "prompt": case["request"].prompt, + "request_id": "test_completion", + "multimodal_data": {"image": ["xxx"]}, + "max_tokens": case["request"].max_tokens, + } + await self.engine_client.add_requests(request_dict) + processed_req = self.multi_modal_processor.process_request_dict( + request_dict, self.engine_client.max_model_len + ) + self.engine_client.data_processor.process_response_dict = ( + lambda data, stream, include_stop_str_in_output: data + ) + mock_response_queue = AsyncMock() + mock_response_queue.get.side_effect = lambda: [ + { + "request_id": "test_completion_0", + "error_code": 200, + "outputs": { + "text": "这是一张风景图"[: case["output_token_num"]], + "token_ids": list(range(case["output_token_num"])), + "top_logprobs": None, + "draft_top_logprobs": None, + }, + "metrics": {"request_start_time": 0.1}, + "finished": True, + "error_msg": None, + "output_token_ids": case["output_token_num"], + } + ] + self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + result = await self.completion_serving.completion_full_generator( + request=case["request"], + num_choices=1, + request_id="test_completion", + created_time=1699999999, + model_name="ernie4.5-vl", + prompt_batched_token_ids=[processed_req["prompt_token_ids"]], + prompt_tokens_list=[case["request"].prompt], + max_tokens_list=[processed_req["max_tokens"]], ) - else: - try: - return await self.completion_full_generator( - request=request, - num_choices=num_choices, - request_id=request_id, - created_time=created_time, - model_name=request.model, 
- prompt_batched_token_ids=prompt_batched_token_ids, - prompt_tokens_list=prompt_tokens_list, - max_tokens_list=max_tokens_list, - ) - except Exception as e: - error_msg = ( - f"OpenAIServingCompletion completion_full_generator error: {e}, {str(traceback.format_exc())}" - ) - api_server_logger.error(error_msg) - return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) - - except Exception as e: - error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}" - api_server_logger.error(error_msg) - return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) - - async def completion_full_generator( - self, - request: CompletionRequest, - num_choices: int, - request_id: str, - created_time: int, - model_name: str, - prompt_batched_token_ids: list(), - prompt_tokens_list: list(), - max_tokens_list: list(), - ): - """ - Process the full completion request with multiple choices. - """ - dealer = None - try: - request_ids = [f"{request_id}_{i}" for i in range(num_choices)] - # create dealer - dealer, response_queue = await self.engine_client.connection_manager.get_connection( - request_id, num_choices - ) - for rid in request_ids: - dealer.write([b"", rid.encode("utf-8")]) - - valid_results = [dict()] * num_choices - output_tokens = [0] * num_choices - aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] - aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] - aggregated_token_ids = [[] for _ in range(num_choices)] - aggregated_prompt_logprobs_tensors = [None] * num_choices - completion_batched_token_ids = [[] for _ in range(num_choices)] - aggregated_speculate_metrics = [None] * num_choices - current_waiting_time = 0 - while num_choices > 0: - if self.engine_client.check_model_weight_status(): - return ErrorResponse( - error=ErrorInfo( - message="Model weight cleared", - code=ErrorCode.INVALID_VALUE, - type=ErrorType.INVALID_REQUEST_ERROR, - ) - ) - try: - response = await asyncio.wait_for(response_queue.get(), timeout=10) - current_waiting_time = 0 - except asyncio.TimeoutError: - current_waiting_time += 10 - if current_waiting_time == 300: - status, msg = self.engine_client.check_health( - time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT - ) - if not status: - raise ValueError(f"Engine is not healthy: {msg}") - else: - current_waiting_time = 0 - await asyncio.sleep(0.1) - continue - - for data in response: - rid = int(data["request_id"].split("_")[-1]) - if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) - - output = data["outputs"] - output_top_logprobs = output.get("top_logprobs") or None - output_draft_top_logprobs = output.get("draft_top_logprobs") or None - if output_top_logprobs is not None: - aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) - aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) - aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) - - # draft logprobs - if request.include_draft_logprobs and output_draft_top_logprobs is not None: - aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) - aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) - aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) - - output_prompt_logprobs_tensors = data.get("prompt_logprobs") or None - if output_prompt_logprobs_tensors is not None: - aggregated_prompt_logprobs_tensors[rid] = output_prompt_logprobs_tensors - - 
aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) - await self._call_process_response_dict(data, request, stream=False) - output_tokens[rid] += len(data["outputs"]["token_ids"]) - completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"]) - - output_speculate_metrics = data["metrics"].get("speculate_metrics", None) - if output_speculate_metrics is not None: - aggregated_speculate_metrics[rid] = output_speculate_metrics - - if data.get("finished", False): - trace_carrier = data.get("trace_carrier") - if trace_carrier: - tracing.trace_set_proc_propagate_context(request_id, trace_carrier) - start_time = data["metrics"]["engine_recv_latest_token_time"] - tracing.trace_report_span( - tracing.TraceSpanName.POSTPROCESSING, - request_id, - int(start_time * 1e9), - int(time.time() * 1e9), - thread_finish_flag=True, - ) - if "trace_carrier" in data: - del data["trace_carrier"] - data["output_token_ids"] = output_tokens[rid] - data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] - data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] - data["outputs"]["token_ids"] = aggregated_token_ids[rid] - data["prompt_logprobs_tensors"] = aggregated_prompt_logprobs_tensors[rid] - data["speculate_metrics"] = aggregated_speculate_metrics[rid] - valid_results[rid] = data - num_choices -= 1 - break - res = self.request_output_to_completion_response( - final_res_batch=valid_results, - request=request, - request_id=request_id, - created_time=created_time, - model_name=model_name, - prompt_batched_token_ids=prompt_batched_token_ids, - completion_batched_token_ids=completion_batched_token_ids, - prompt_tokens_list=prompt_tokens_list, - max_tokens_list=max_tokens_list, - ) - api_server_logger.info(f"Completion response: {res.model_dump_json()}") - return res - except Exception as e: - api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True) - finally: - tracing.trace_req_finish(request_id) - trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", "")) - self.engine_client.semaphore.release() - if dealer is not None: - await self.engine_client.connection_manager.cleanup_request(request_id) - - def _echo_back_prompt(self, request, idx): - """ - The echo pre-process of the smallest unit - """ - if isinstance(request.prompt, str): - prompt_text = request.prompt - elif isinstance(request.prompt, list): - if all(isinstance(item, str) for item in request.prompt): - prompt_text = request.prompt[idx] - elif all(isinstance(item, int) for item in request.prompt): - prompt_text = self.engine_client.data_processor.tokenizer.decode(request.prompt) - else: - prompt_text = self.engine_client.data_processor.tokenizer.decode(request.prompt[idx]) - return prompt_text - - async def _process_echo_logic(self, request, idx, res_outputs): - """ - Process the echo logic and return the modified text. 
- """ - if request.echo and res_outputs.get("send_idx", -1) == 0: - prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n)) - res_outputs["text"] = prompt_text + (res_outputs["text"] or "") - return res_outputs - - def calc_finish_reason(self, max_tokens, token_num, output, tool_called): - if max_tokens is None or token_num != max_tokens: - if tool_called or output.get("tool_call"): - return "tool_calls" - else: - return "stop" - else: - return "length" - - async def completion_stream_generator( - self, - request: CompletionRequest, - num_choices: int, - request_id: str, - created_time: int, - model_name: str, - prompt_batched_token_ids: list(), - prompt_tokens_list: list(), - max_tokens_list: list(), - ): - """ - Process the stream completion request. - """ - try: - dealer, response_queue = await self.engine_client.connection_manager.get_connection( - request_id, num_choices - ) + self.assertIsInstance(result, CompletionResponse) + self.assertEqual(result.choices[0].finish_reason, case["expected_finish_reason"]) + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor") + @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger") + async def test_chat_stream_max_tokens(self, mock_api_logger, mock_processor_class, mock_data_logger): + test_cases = [ + { + "name": "流式-生成数=8(等于max_tokens)→length", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=True, + max_tokens=8, + return_token_ids=True, + ), + "total_token_num": 8, + "tool_call": None, + "expected_finish_reason": "length", + }, + { + "name": "流式-生成数=6(小于max_tokens)+tool_call→tool_calls", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=True, + max_tokens=10, + return_token_ids=True, + ), + "total_token_num": 3, + "tool_call": {"name": "test_tool"}, + "expected_finish_reason": "tool_calls", + }, + { + "name": "流式-生成数=7(小于max_tokens)无tool_call→stop", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=True, + max_tokens=10, + return_token_ids=True, + ), + "total_token_num": 7, + "tool_call": None, + "expected_finish_reason": "stop", + }, + ] - for i in range(num_choices): - req_id = f"{request_id}_{i}" - dealer.write([b"", req_id.encode("utf-8")]) # 发送多路请求 - output_tokens = [0] * num_choices - num_cache_tokens = [0] * num_choices - num_image_tokens = [0] * num_choices - inference_start_time = [0] * num_choices - reasoning_tokens = [0] * num_choices - first_iteration = [True] * num_choices - tool_called = [False] * num_choices - max_streaming_response_tokens = ( - request.max_streaming_response_tokens - if request.max_streaming_response_tokens is not None - else (request.suffix or {}).get("max_streaming_response_tokens", 1) - ) # dierctly passed & passed in suffix - max_streaming_response_tokens = max(1, max_streaming_response_tokens) - choices = [] - chunk = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - ) - current_waiting_time = 0 - while num_choices > 0: - if self.engine_client.check_model_weight_status(): - raise ValueError("Engine is clearing model weight") - try: - response = await asyncio.wait_for(response_queue.get(), timeout=10) - current_waiting_time = 0 - except asyncio.TimeoutError: - current_waiting_time += 10 - if current_waiting_time == 300: - 
status, msg = self.engine_client.check_health( - time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT - ) - if not status: - raise ValueError(f"Engine is not healthy: {msg}") - else: - current_waiting_time = 0 - await asyncio.sleep(0.1) - continue + mock_dealer = Mock() + self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock())) + mock_processor_instance = Mock() + mock_processor_instance.enable_multimodal_content.return_value = False + + async def mock_process_response_chat_async(response, stream, include_stop_str_in_output): + if isinstance(response, list): for res in response: - idx = int(res["request_id"].split("_")[-1]) - if res.get("error_code", 200) != 200: - raise ValueError("{}".format(res["error_msg"])) - prompt_logprobs_res: Optional[PromptLogprobs] = None - if first_iteration[idx]: - prompt_logprobs_tensors = res.get("prompt_logprobs", None) - if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: - num_prompt_logprobs = ( - request.prompt_logprobs - if request.prompt_logprobs != -1 - else self.engine_client.ori_vocab_size - ) - prompt_logprobs_res = self._build_prompt_logprobs( - prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token - ) - if request.return_token_ids: - chunk = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[ - CompletionResponseStreamChoice( - index=idx, - text="", - prompt_token_ids=list( - prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] - ), - prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), - prompt_tokens=prompt_tokens_list[ - idx // (1 if request.n is None else request.n) - ], - completion_token_ids=None, - ) - ], - ) - yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - api_server_logger.info( - f"Completion Streaming response send_idx 0: {chunk.model_dump_json()}" - ) - first_iteration[idx] = False - - await self._call_process_response_dict(res, request, stream=True) - if inference_start_time[idx] == 0: - arrival_time = res["metrics"]["first_token_time"] - inference_start_time[idx] = res["metrics"]["inference_start_time"] - else: - arrival_time = res["metrics"]["engine_recv_latest_token_time"] - inference_start_time[idx] - - await self._process_echo_logic(request, idx, res["outputs"]) - output = res["outputs"] - output_top_logprobs = output["top_logprobs"] - output_draft_top_logprobs = output["draft_top_logprobs"] - logprobs_res: Optional[CompletionLogprobs] = None - draft_logprobs_res: Optional[CompletionLogprobs] = None - if request.logprobs is not None and output_top_logprobs is not None: - num_logprobs = ( - request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size - ) - logprobs_res = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0) - - # draft logprobs - if request.include_draft_logprobs and output_draft_top_logprobs is not None: - draft_logprobs_res = self._create_completion_logprobs( - output_draft_top_logprobs, num_logprobs, 0 - ) - output_tokens[idx] += len(output.get("token_ids", [])) or 0 - num_cache_tokens[idx] += output.get("num_cache_tokens") or 0 - if output.get("num_image_tokens"): - output_tokens[idx] += output.get("num_image_tokens") - num_image_tokens[idx] += output.get("num_image_tokens") - reasoning_tokens[idx] += output.get("reasoning_token_num", 0) - output_speculate_metrics = res["metrics"].get("speculate_metrics", None) - delta_message = CompletionResponseStreamChoice( - index=idx, - 
text=output["text"], - prompt_token_ids=None, - completion_token_ids=output.get("token_ids") if request.return_token_ids else None, - tool_calls=None, - completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, - reasoning_content="", - arrival_time=arrival_time, - logprobs=logprobs_res, - prompt_logprobs=( - clamp_prompt_logprobs(prompt_logprobs_res) if not request.return_token_ids else None - ), - draft_logprobs=draft_logprobs_res, - speculate_metrics=output_speculate_metrics, - ) - if not res["finished"] and output["enable_parser"]: - delta_message_output = output["delta_message"] - if delta_message_output is None: + yield res + else: + yield response + + mock_processor_instance.process_response_chat = mock_process_response_chat_async + mock_processor_class.return_value = mock_processor_instance + + for case in test_cases: + with self.subTest(case=case["name"]): + request_dict = { + "messages": case["request"].messages, + "chat_template": "default", + "request_id": "test_chat_stream_0", + "max_tokens": case["request"].max_tokens, + } + await self.engine_client.add_requests(request_dict) + processed_req = self.multi_modal_processor.process_request_dict( + request_dict, self.engine_client.max_model_len + ) + + self.engine_client.data_processor.process_response_dict = ( + lambda data, stream, include_stop_str_in_output: data + ) + + mock_response_queue = AsyncMock() + stream_responses = self._generate_stream_inference_response( + request_id="test_chat_stream_0_0", + total_token_num=case["total_token_num"], + tool_call=case["tool_call"], + ) + mock_response_queue.get.side_effect = stream_responses + self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + generator = self.chat_serving.chat_completion_stream_generator( + request=case["request"], + request_id="test_chat_stream_0", + model_name="ernie4.5-vl", + prompt_token_ids=processed_req["prompt_token_ids"], + prompt_tokens="描述这张图片", + max_tokens=processed_req["max_tokens"], + ) + + final_finish_reason = None + chunks = [] + async for chunk in generator: + chunks.append(chunk) + if "[DONE]" in chunk: + break + + for chunk_str in chunks: + if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str: + try: + json_part = chunk_str.strip().lstrip("data: ").rstrip("\n\n") + chunk_dict = json.loads(json_part) + if chunk_dict.get("choices") and len(chunk_dict["choices"]) > 0: + finish_reason = chunk_dict["choices"][0].get("finish_reason") + if finish_reason: + final_finish_reason = finish_reason + break + except (json.JSONDecodeError, KeyError, IndexError): continue - delta_message.text = delta_message_output.content or "" - delta_message.reasoning_content = delta_message_output.reasoning_content or "" - if delta_message_output.tool_calls: - delta_message.tool_calls = delta_message_output.tool_calls - tool_called[idx] = True - - choices.append(delta_message) - - if res["finished"]: - choices[-1].finish_reason = self.calc_finish_reason( - max_tokens_list[idx // (1 if request.n is None else request.n)], - output_tokens[idx], - output, - tool_called[idx], - ) - inference_start_time[idx] = 0 - - send_idx = output.get("send_idx") - # 只有当 send_idx 明确为 0 时才记录日志 - if send_idx == 0 and not request.return_token_ids: - chunk_temp = chunk - chunk_temp.choices = choices - api_server_logger.info( - f"Completion Streaming response send_idx 0: {chunk_temp.model_dump_json()}" - ) - del chunk_temp - - if len(choices) == max_streaming_response_tokens or res["finished"]: - chunk = 
CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - metrics=res["metrics"] if request.collect_metrics else None, - ) - yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - choices = [] - - if res["finished"]: - trace_carrier = res.get("trace_carrier") - if trace_carrier: - tracing.trace_set_proc_propagate_context(request_id, trace_carrier) - start_time = res["metrics"]["engine_recv_latest_token_time"] - tracing.trace_report_span( - tracing.TraceSpanName.POSTPROCESSING, - request_id, - int(start_time * 1e9), - int(time.time() * 1e9), - thread_finish_flag=True, - ) - if "trace_carrier" in res: - del res["trace_carrier"] - num_choices -= 1 - if getattr(request, "stream_options", None) and request.stream_options.include_usage: - usage_chunk = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[], - usage=UsageInfo( - prompt_tokens=len( - prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] - ), - completion_tokens=output_tokens[idx], - total_tokens=len( - prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] - ) - + output_tokens[idx], - prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cache_tokens[idx]), - completion_tokens_details=CompletionTokenUsageInfo( - image_tokens=num_image_tokens[idx], reasoning_tokens=reasoning_tokens[idx] - ), - ), - metrics=res["metrics"] if request.collect_metrics else None, - ) - yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" - api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}") - - except Exception as e: - api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}") - yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n" - finally: - - tracing.trace_req_finish(request_id) - trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", "")) - del request - if dealer is not None: - await self.engine_client.connection_manager.cleanup_request(request_id) - self.engine_client.semaphore.release() - yield "data: [DONE]\n\n" - - def request_output_to_completion_response( - self, - final_res_batch: List[RequestOutput], - request: CompletionRequest, - request_id: str, - created_time: int, - model_name: str, - prompt_batched_token_ids: list(), - completion_batched_token_ids: list(), - prompt_tokens_list: list(), - max_tokens_list: list(), - ) -> CompletionResponse: - choices: List[CompletionResponseChoice] = [] - num_prompt_tokens = 0 - num_generated_tokens = 0 - num_cache_tokens = 0 - num_image_tokens = 0 - num_reasoning_tokens = 0 - - for idx in range(len(final_res_batch)): - final_res = final_res_batch[idx] - prompt_token_ids = prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] - assert prompt_token_ids is not None - prompt_text = request.prompt - completion_token_ids = completion_batched_token_ids[idx] - - output = final_res["outputs"] - output_top_logprobs = output.get("top_logprobs") or None - output_draft_top_logprobs = output.get("draft_top_logprobs") or None - - aggregated_logprobs: Optional[CompletionLogprobs] = None - num_logprobs = request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size - if output_top_logprobs is not None: - aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0) - - 
aggregated_draft_logprobs: Optional[CompletionLogprobs] = None - if output_draft_top_logprobs is not None: - aggregated_draft_logprobs = self._create_completion_logprobs( - output_draft_top_logprobs, num_logprobs, 0 + + self.assertEqual(final_finish_reason, case["expected_finish_reason"]) + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger") + async def test_completion_stream_max_tokens(self, mock_api_logger, mock_data_logger): + test_cases = [ + { + "name": "流式-生成数=7(等于max_tokens)→length", + "request": CompletionRequest( + model="ernie4.5-vl", + prompt=["描述这张图片:xxx"], + stream=True, + max_tokens=7, + return_token_ids=True, + ), + "total_token_num": 7, + "expected_finish_reason": "length", + }, + { + "name": "流式-生成数=9(小于max_tokens)→stop", + "request": CompletionRequest( + model="ernie4.5-vl", + prompt=["描述这张图片:xxx"], + stream=True, + max_tokens=15, + return_token_ids=True, + ), + "total_token_num": 9, + "expected_finish_reason": "stop", + }, + ] + + mock_dealer = Mock() + self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock())) + + for case in test_cases: + with self.subTest(case=case["name"]): + request_dict = { + "prompt": case["request"].prompt, + "multimodal_data": {"image": ["xxx"]}, + "request_id": "test_completion_stream_0", + "max_tokens": case["request"].max_tokens, + } + await self.engine_client.add_requests(request_dict) + processed_req = self.multi_modal_processor.process_request_dict( + request_dict, self.engine_client.max_model_len ) - prompt_logprobs_res: Optional[PromptLogprobs] = None - prompt_logprobs_tensors = final_res.get("prompt_logprobs_tensors", None) - if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: - num_prompt_logprobs = ( - request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size + self.engine_client.data_processor.process_response_dict = ( + lambda data, stream, include_stop_str_in_output: data ) - prompt_logprobs_res = self._build_prompt_logprobs( - prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token + + mock_response_queue = AsyncMock() + stream_responses = self._generate_stream_inference_response( + request_id="test_completion_stream_0", total_token_num=case["total_token_num"] ) - if request.echo: - prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n)) - token_ids = [*prompt_token_ids, *output["token_ids"]] - output_text = prompt_text + output["text"] - else: - token_ids = output["token_ids"] - output_text = output["text"] - finish_reason = self.calc_finish_reason( - max_tokens_list[idx // (1 if request.n is None else request.n)], - final_res["output_token_ids"], - output, - False, - ) + mock_response_queue.get.side_effect = stream_responses + self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + generator = self.completion_serving.completion_stream_generator( + request=case["request"], + num_choices=1, + created_time=0, + request_id="test_completion_stream", + model_name="ernie4.5-vl", + prompt_batched_token_ids=[processed_req["prompt_token_ids"]], + prompt_tokens_list=case["request"].prompt, + max_tokens_list=[processed_req["max_tokens"]], + ) + + final_finish_reason = None + chunks = [] + async for chunk in generator: + chunks.append(chunk) + if "[DONE]" in chunk: + break + + for chunk_str in chunks: + if chunk_str.startswith("data: ") and 
"[DONE]" not in chunk_str: + try: + json_part = chunk_str.strip().lstrip("data: ") + chunk_dict = json.loads(json_part) + if chunk_dict["choices"][0].get("finish_reason"): + final_finish_reason = chunk_dict["choices"][0]["finish_reason"] + break + except (json.JSONDecodeError, KeyError, IndexError): + continue - choice_data = CompletionResponseChoice( - token_ids=token_ids, - index=len(choices), - text=output_text, - prompt_token_ids=prompt_token_ids if request.return_token_ids else None, - completion_token_ids=completion_token_ids if request.return_token_ids else None, - completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, - prompt_tokens=( - prompt_tokens_list[idx // (1 if request.n is None else request.n)] - if request.return_token_ids - else None + self.assertEqual(final_finish_reason, case["expected_finish_reason"], f"场景 {case['name']} 失败") + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger") + async def test_completion_create_max_tokens_list_basic(self, mock_api_logger, mock_data_logger): + test_cases = [ + { + "name": "单prompt → max_tokens_list长度1", + "request": CompletionRequest( + request_id="test_single_prompt", + model="ernie4.5-vl", + prompt="请介绍人工智能的应用", + stream=False, + max_tokens=10, ), - reasoning_content=output.get("reasoning_content"), - tool_calls=output.get("tool_calls", None), - logprobs=aggregated_logprobs, - draft_logprobs=aggregated_draft_logprobs, - prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), - finish_reason=finish_reason, - speculate_metrics=final_res["metrics"].get("speculate_metrics", None), - ) - choices.append(choice_data) - - num_generated_tokens += final_res["output_token_ids"] - - num_prompt_tokens += len(prompt_token_ids) - num_cache_tokens += output.get("num_cache_tokens") or 0 - if output.get("num_image_tokens"): - num_generated_tokens += output.get("num_image_tokens") - num_image_tokens += output.get("num_image_tokens") - - num_reasoning_tokens += output.get("reasoning_token_num", 0) - - num_prompt_tokens = num_prompt_tokens // (1 if request.n is None else request.n) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cache_tokens), - completion_tokens_details=CompletionTokenUsageInfo( - reasoning_tokens=num_reasoning_tokens, image_tokens=num_image_tokens - ), - ) - del request - - return CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) + "mock_max_tokens": 8, + "expected_max_tokens_list_len": 1, + "expected_max_tokens_list": [8], + }, + { + "name": "多prompt → max_tokens_list长度2", + "request": CompletionRequest( + request_id="test_multi_prompt", + model="ernie4.5-vl", + prompt=["请介绍Python语言", "请说明机器学习的步骤"], + stream=False, + max_tokens=15, + ), + "mock_max_tokens": [12, 10], + "expected_max_tokens_list_len": 2, + "expected_max_tokens_list": [12, 10], + }, + ] + + async def mock_format_and_add_data(current_req_dict): + req_idx = int(current_req_dict["request_id"].split("_")[-1]) + if isinstance(case["mock_max_tokens"], list): + current_req_dict["max_tokens"] = case["mock_max_tokens"][req_idx] + else: + current_req_dict["max_tokens"] = case["mock_max_tokens"] + return [101, 102, 103, 104] - async def _call_process_response_dict(self, res, request, stream): - if self._is_process_response_dict_async 
is None: - self._is_process_response_dict_async = inspect.iscoroutinefunction( - self.engine_client.data_processor.process_response_dict - ) - if self._is_process_response_dict_async: - await self.engine_client.data_processor.process_response_dict( - res, stream=stream, include_stop_str_in_output=request.include_stop_str_in_output - ) - else: - self.engine_client.data_processor.process_response_dict( - res, stream=stream, include_stop_str_in_output=request.include_stop_str_in_output - ) + self.engine_client.format_and_add_data = AsyncMock(side_effect=mock_format_and_add_data) - def _create_completion_logprobs( - self, - output_top_logprobs, - request_logprobs: Optional[int] = None, - prompt_text_offset: Optional[int] = None, - ) -> Optional[CompletionLogprobs]: - """Create OpenAI-style logprobs for completions.""" - - # Parameter validation - if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs): - return None - - logprobs_res: Optional[CompletionLogprobs] = None - # Iterate over the top-k candidates for each token - for logprob_token_ids, logprobs, sampled_token_ranks in zip( - output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2] - ): - top_logprobs = LogprobsLists( - logprob_token_ids=[logprob_token_ids], - logprobs=[logprobs], - sampled_token_ranks=[sampled_token_ranks], + async def intercept_generator(**kwargs): + actual_max_tokens_list = kwargs["max_tokens_list"] + self.assertEqual( + len(actual_max_tokens_list), + case["expected_max_tokens_list_len"], + f"列表长度不匹配:实际{len(actual_max_tokens_list)},预期{case['expected_max_tokens_list_len']}", ) - # Build the logprobs response - step_logprobs_res = self._build_logprobs_response( - response_logprobs=top_logprobs, - request_top_logprobs=request_logprobs, - prompt_text_offset=prompt_text_offset, + self.assertEqual( + actual_max_tokens_list, + case["expected_max_tokens_list"], + f"列表元素不匹配:实际{actual_max_tokens_list},预期{case['expected_max_tokens_list']}", ) - if logprobs_res is None: - logprobs_res = step_logprobs_res - else: - # Append the new tokens to the existing logprobs response - logprobs_res.tokens.extend(step_logprobs_res.tokens) - logprobs_res.token_logprobs.extend(step_logprobs_res.token_logprobs) - logprobs_res.top_logprobs.extend(step_logprobs_res.top_logprobs) - - return logprobs_res - - def _build_logprobs_response( - self, - response_logprobs: Optional[LogprobsLists] = None, - request_top_logprobs: Optional[int] = None, - prompt_text_offset: Optional[int] = None, - ) -> Optional[CompletionLogprobs]: - """ - Construct a logprobs response object in line with the OpenAI style. - Retain the complete top-k candidates and avoid circular references. 
- """ - - # Parameter validation - if response_logprobs is None or request_top_logprobs is None or request_top_logprobs < 0: - return None - - try: - # The top-k candidates for the current token - topk_token_ids = [] - topk_logprobs = [] - - if response_logprobs.logprob_token_ids and len(response_logprobs.logprob_token_ids) > 0: - topk_token_ids = response_logprobs.logprob_token_ids[0][: request_top_logprobs + 1] - - if response_logprobs.logprobs and len(response_logprobs.logprobs) > 0: - topk_logprobs = response_logprobs.logprobs[0][: request_top_logprobs + 1] - - # Construct the sampled token object (avoid sharing references with top_logprob_entries) - tokens = [] - token_logprobs = [] - top_logprobs = {} - idx = 0 - for tid, lp in zip(topk_token_ids, topk_logprobs): - token_str = self.engine_client.data_processor.process_logprob_response( - [tid], clean_up_tokenization_spaces=False - ) - if "\ufffd" in token_str: - raw_token = self.engine_client.data_processor.tokenizer.convert_ids_to_tokens(tid) - token_bytes = raw_token.encode("utf-8", errors="replace") - token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes) - if idx == 0: - tokens.append(token_str) - token_logprobs.append(lp) - top_logprobs[token_str] = lp - idx += 1 - - # Construct the sampled token object (avoid sharing references with top_logprob_entries) - # text_offset = prompt_text_offset + len(tokens) - 1 - return CompletionLogprobs( - tokens=tokens, - token_logprobs=token_logprobs, - top_logprobs=[top_logprobs], - # text_offset=[text_offset], + return CompletionResponse( + id=kwargs["request_id"], + object="text_completion", + created=kwargs["created_time"], + model=kwargs["model_name"], + choices=[], + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), ) - except Exception as e: - api_server_logger.error(f"Error in _build_logprobs_response: {str(e)}, {str(traceback.format_exc())}") - return None - - def _build_prompt_logprobs( - self, - prompt_logprobs_tensors: LogprobsTensors, - num_prompt_logprobs: int, - include_logprobs_decode_token: bool, - ): - """Update with prompt logprobs from worker. - Args: - prompt_logprobs_tensors: tuple containing the prompt logprobs - tensors. - """ - - token_ids, logprobs, ranks = prompt_logprobs_tensors - - # Detokenize non-incrementally. - # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - if include_logprobs_decode_token: - decoded_tokens = [ - self.engine_client.data_processor.process_logprob_response(token_id) - for token_id in token_ids.flatten().tolist() - ] - else: - decoded_tokens = None - - # Recover shapes. - num_prompt_tokens, num_logprobs = logprobs.shape - - # Pythonize the paddle tensors. - prompt_token_ranks = ranks.tolist() - prompt_logprobs = logprobs.tolist() - token_ids = token_ids.tolist() - result: Optional[PromptLogprobs] = [None] - # Make Logprob for each position. - for pos in range(num_prompt_tokens): - # Handle flattening. - offset = pos * num_logprobs - offset_end = offset + num_logprobs - decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] - - # Update with the Logprob dictionary for this pos. 
- result.append( - self._make_logprob_dict( - prompt_logprobs[pos], - token_ids[pos], - decoded_tokens_for_pos, - prompt_token_ranks[pos], - num_prompt_logprobs, - ) - ) - return result - - @staticmethod - def _make_logprob_dict( - logprobs: list[float], - logprob_token_ids: list[int], - decoded_tokens: Iterable[str | None], - rank: int, - num_logprobs: int, - ) -> dict[int, Logprob]: - """Make a Logprob dictionary for a position. - Args: - logprobs: list of log probabilities - logprob_token_ids: list of top token ids - decoded_tokens: list of decoded top tokens - rank: rank of the sampled token - num_logprobs: number of logprobs requested - by the user (in addition to sampled logprob) - Returns: - dict[token id, Logprob] - """ - if num_logprobs == -1: - num_logprobs = len(logprobs) - # We do not need a special case for the sampled token - # being in the topk, since inserting duplicated data - # into a dictionary twice is the same as doing it once. - topk_ranks = range(1, num_logprobs + 1) - ranks = itertools.chain((rank,), topk_ranks) - - return { - token_id: Logprob( - logprob=logprob, - rank=rank, - decoded_token=token, - ) - for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) - } + self.completion_serving.completion_full_generator = AsyncMock(side_effect=intercept_generator) + + for case in test_cases: + with self.subTest(case=case["name"]): + result = await self.completion_serving.create_completion(request=case["request"]) + self.assertIsInstance(result, CompletionResponse) diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 3d380d5a258..c6c5ce7da21 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -116,7 +116,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "a", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1}, "finished": False, @@ -128,7 +130,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "b", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.2, "first_token_time": None}, "finished": False, @@ -140,7 +144,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "c", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.3, "first_token_time": None}, "finished": False, @@ -152,7 +158,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "d", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.4, "first_token_time": None}, "finished": False, @@ -164,7 +172,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "e", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": 
{"engine_recv_latest_token_time": 0.5, "first_token_time": None}, "finished": False, @@ -176,7 +186,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "f", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.6, "first_token_time": None}, "finished": False, @@ -188,7 +200,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class "text": "g", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.7, "first_token_time": None, "request_start_time": 0.1}, "finished": True, @@ -274,7 +288,9 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): "text": "a", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1}, "finished": False, @@ -286,7 +302,9 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): "text": "b", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": {"engine_recv_latest_token_time": 0.2, "first_token_time": None}, "finished": False, @@ -300,7 +318,9 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): "text": "g", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": { "engine_recv_latest_token_time": 0.7, @@ -357,7 +377,6 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): self.fail(f"{i + 1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}") self.assertEqual(len(parsed_chunks), 1) for chunk_dict in parsed_chunks: - print(f"======>{chunk_dict}") choices_list = chunk_dict["choices"] self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should has three choices") self.assertEqual( @@ -580,13 +599,29 @@ async def test_chat_stream_usage_fields(self, mock_response_processor, api_serve response_data = [ { "request_id": "test-request-id_0", - "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None, "draft_top_logprobs": None}, + "outputs": { + "token_ids": [1], + "text": "a", + "top_logprobs": None, + "draft_top_logprobs": None, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, + }, "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1, "request_start_time": 0.0}, "finished": False, }, { "request_id": "test-request-id_0", - "outputs": {"token_ids": [2, 3], "text": "bc", "top_logprobs": None, "draft_top_logprobs": None}, + "outputs": { + "token_ids": [2, 3], + "text": "bc", + "top_logprobs": None, + "draft_top_logprobs": None, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, + }, "metrics": {"engine_recv_latest_token_time": 0.3, "first_token_time": None, "request_start_time": 0.0}, "finished": True, }, @@ -708,7 +743,9 @@ async def test_completion_stream_usage_fields(self, mock_logger): "text": "a", "top_logprobs": None, "draft_top_logprobs": None, - "enable_parser": False, + "reasoning_content": "", + "tool_calls": None, + "skipped": False, }, "metrics": { "engine_recv_latest_token_time": 
0.3,
@@ -727,7 +764,9 @@ async def test_completion_stream_usage_fields(self, mock_logger):
                     "text": "bc",
                     "top_logprobs": None,
                     "draft_top_logprobs": None,
-                    "enable_parser": False,
+                    "reasoning_content": "",
+                    "tool_calls": None,
+                    "skipped": False,
                 },
                 "metrics": {
                     "engine_recv_latest_token_time": 0.3,
@@ -915,7 +954,9 @@ async def test_completion_stream_generator_async_process_response_dict(self, moc
                     "text": "a",
                     "top_logprobs": {"a": 0.98, "b": 0.02},
                     "draft_top_logprobs": {"a": 0.98, "b": 0.02},
-                    "enable_parser": False,
+                    "reasoning_content": "",
+                    "tool_calls": None,
+                    "skipped": False,
                 },
                 "finished": False,
                 "metrics": {
@@ -936,7 +977,9 @@ async def test_completion_stream_generator_async_process_response_dict(self, moc
                     "text": "b",
                     "top_logprobs": {"a": 0.98, "b": 0.02},
                     "draft_top_logprobs": {"a": 0.98, "b": 0.02},
-                    "enable_parser": False,
+                    "reasoning_content": "",
+                    "tool_calls": None,
+                    "skipped": False,
                 },
                 "finished": False,
                 "metrics": {
@@ -957,7 +1000,9 @@ async def test_completion_stream_generator_async_process_response_dict(self, moc
                     "text": "g",
                     "top_logprobs": {"a": 0.98, "b": 0.02},
                     "draft_top_logprobs": {"a": 0.98, "b": 0.02},
-                    "enable_parser": False,
+                    "reasoning_content": "",
+                    "tool_calls": None,
+                    "skipped": False,
                 },
                 "finished": True,
                 "metrics": {
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 325376abc05..a0e58fd4d46 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -527,13 +527,15 @@ async def test_chat_completion_stream_generator_with_both_logprobs(self):
                 ],
                 "draft_top_logprobs": None,
                 "multipart": [{"type": "text", "text": "Hi"}],
+                "reasoning_content": "",
+                "tool_calls": None,
+                "skipped": False,
             },
             "finished": True,
             "num_cached_tokens": 0,
             "num_input_image_tokens": 0,
             "num_input_video_tokens": 0,
         }
-
         mock_response = RequestOutput.from_dict(mock_response)
         mock_response_queue.get.return_value = mock_response
@@ -590,6 +592,8 @@ async def mock_async_generator():
         # Check for logprobs in subsequent chunks
         logprobs_found = False
         for result in results:
+            print("1")
+            print(result)
             # Skip [DONE] message
             if result.strip() == "data: [DONE]":
                 continue
@@ -1170,6 +1174,9 @@ async def mock_async_generator_with_cancel():
                 "draft_top_logprobs": None,
                 "multipart": [{"type": "text", "text": "Hi"}],
                 "enable_parser": False,
+                "reasoning_content": "",
+                "tool_calls": None,
+                "skipped": False,
             },
             "finished": False,
             "num_cached_tokens": 0,
diff --git a/tests/entrypoints/openai/test_serving_completion.py b/tests/entrypoints/openai/test_serving_completion.py
index cfc7f4402c6..874a3c128d2 100644
--- a/tests/entrypoints/openai/test_serving_completion.py
+++ b/tests/entrypoints/openai/test_serving_completion.py
@@ -83,7 +83,7 @@ def test_calc_finish_reason_tool_calls(self):
         # Create an OpenAIServingCompletion instance
         serving_completion = OpenAIServingCompletion(engine_client, None, "pid", "ips", 360)
         # Create a mock output whose tool call entry should make finish_reason resolve to "tool_calls"
-        output = {"tool_call": "tool_call"}
+        output = {"tool_calls": "tool_call"}
         # Call the calc_finish_reason method
         result = serving_completion.calc_finish_reason(None, 100, output, False)
         # Assert that the result is "tool_calls"
@@ -767,6 +767,9 @@ async def test_completion_stream_generator_without_logprobs(self):
                 "num_cache_tokens": 0,
                 "num_image_tokens": 0,
                 "reasoning_token_num": 0,
+                "tool_calls": None,
+                "reasoning_content": "",
+                "skipped": False,
             },
             "finished": True,
         }
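# A minimal, illustrative sketch (not part of the patch) of how a streaming consumer can
# read the reworked per-chunk "outputs" contract exercised by the fixtures above: the
# processor now pre-fills "reasoning_content", "tool_calls" and "skipped" on every chunk,
# so the serving layer only copies fields, records tool calls, and drops skipped chunks.
# The helper name and return shape below are assumptions for illustration only.
def consume_stream_chunk(outputs: dict, tool_called: dict, idx: int):
    if outputs.get("skipped"):
        # Nothing to emit for this chunk; the parser asked for it to be dropped.
        return None
    if outputs.get("tool_calls") is not None:
        # Remember that a tool call was streamed so the final finish_reason can become "tool_calls".
        tool_called[idx] = True
    return {
        "text": outputs.get("text", ""),
        "reasoning_content": outputs.get("reasoning_content", ""),
        "tool_calls": outputs.get("tool_calls"),
    }


# Example with a fixture-shaped chunk:
#   consume_stream_chunk({"text": "a", "reasoning_content": "", "tool_calls": None, "skipped": False}, {}, 0)
#   -> {"text": "a", "reasoning_content": "", "tool_calls": None}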
diff --git a/tests/input/test_ernie4_5_processor.py b/tests/input/test_ernie4_5_processor.py
index bacda53badb..088fe5f33ec 100644
--- a/tests/input/test_ernie4_5_processor.py
+++ b/tests/input/test_ernie4_5_processor.py
@@ -96,6 +96,7 @@ def extract_reasoning_content_streaming(
         class ReasoningDelta:
             def __init__(self, content):
                 self.reasoning_content = content
+                self.content = content

         return ReasoningDelta(delta_text)
@@ -227,13 +228,12 @@ def test_process_response_dict_streaming_with_reasoning_and_tool(self):
         response = {
             "finished": True,
             "request_id": "req-1",
-            "outputs": {"token_ids": [10, 11]},
+            "outputs": {"token_ids": [10, 11], "reasoning_content": "", "tool_calls": [1], "skipped": False},
         }
         result = proc.process_response_dict_streaming(
             response, enable_thinking=False, include_stop_str_in_output=False
         )
-
         outputs = result["outputs"]
         self.assertIn("completion_tokens", outputs)
@@ -243,9 +243,7 @@ def test_process_response_dict_streaming_with_reasoning_and_tool(self):
         self.assertIn("reasoning_token_num", outputs)
         self.assertGreaterEqual(outputs["reasoning_token_num"], 0)
-        self.assertIn("delta_message", outputs)
-        delta_msg = outputs["delta_message"]
-        self.assertTrue(hasattr(delta_msg, "tool_calls"))
+        self.assertIn("tool_calls", outputs)

         self.assertNotIn("req-1", proc.decode_status)
         self.assertNotIn("req-1", proc.tool_parser_dict)
@@ -351,8 +349,8 @@ def test_process_response_dict_normal_with_tool(self):

         result = proc.process_response_dict_normal(resp, enable_thinking=False, include_stop_str_in_output=False)

-        self.assertIn("tool_call", result["outputs"])
-        self.assertEqual(result["outputs"]["tool_call"][0]["name"], "fake_tool")
+        self.assertIn("tool_calls", result["outputs"])
+        self.assertEqual(result["outputs"]["tool_calls"][0]["name"], "fake_tool")


 if __name__ == "__main__":
diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py
index 4134b130e13..f4a6f2ec7de 100644
--- a/tests/input/test_text_processor.py
+++ b/tests/input/test_text_processor.py
@@ -19,12 +19,19 @@
 import sys
 import types
 import unittest
+from collections.abc import Sequence
 from pathlib import Path
 from types import SimpleNamespace
 from unittest import mock

 import numpy as np

+from fastdeploy.entrypoints.openai.protocol import (
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+)
+

 class DummyTokenizer:
     bos_token = ""
@@ -261,7 +268,7 @@ def __setitem__(self, key, value):

 class DataProcessorTestCase(unittest.TestCase):
     @staticmethod
-    def create_dummy_reasoning(tokenizer, reasoning_content="think"):
+    def create_dummy_reasoning(tokenizer, reasoning_content="think", content="content"):
         class DummyReasoning:
             def __init__(self, tokenizer):
                 self.tokenizer = tokenizer
@@ -269,6 +276,18 @@ def __init__(self, tokenizer):
             def extract_reasoning_content(self, full_text, response_dict, model_status):
                 return reasoning_content, f"{full_text}!"
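+            # Assumed parser contract, mirrored from the dummy parsers in this test file: the
+            # streaming hooks receive previous/current/delta text and token ids plus the model
+            # status, and return a DeltaMessage whose reasoning_content/content/tool_calls the
+            # processor copies into the per-chunk "outputs" dict.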
+            def extract_reasoning_content_streaming(
+                self,
+                previous_text: str,
+                current_text: str,
+                delta_text: str,
+                previous_token_ids: Sequence[int],
+                current_token_ids: Sequence[int],
+                delta_token_ids: Sequence[int],
+                model_status: str,
+            ):
+                return DeltaMessage(reasoning_content=reasoning_content, content=content)
+
         return DummyReasoning(tokenizer)

     @staticmethod
@@ -278,8 +297,30 @@ def __init__(self, tokenizer):
                 self.tokenizer = tokenizer

             def extract_tool_calls(self, full_text, response_dict):
+                # Simulate tool-call parsing; return fixed tool-call data for the test
                 return SimpleNamespace(tools_called=True, tool_calls=["tool"], content=content)

+            def extract_tool_calls_streaming(
+                self,
+                previous_text: str,
+                current_text: str,
+                delta_text: str,
+                previous_token_ids: Sequence[int],
+                current_token_ids: Sequence[int],
+                delta_token_ids: Sequence[int],
+                model_status: str,
+            ):
+                # Simulate streaming tool-call parsing; return fixed tool-call data for the test
+                tool_calls = [
+                    DeltaToolCall(
+                        index=0,
+                        type="function",
+                        id="text",
+                        function=DeltaFunctionCall(name="test").model_dump(exclude_none=True),
+                    )
+                ]
+                return DeltaMessage(tool_calls=tool_calls, content=content)
+
         return DummyToolParser

     def setUp(self):
@@ -434,6 +475,25 @@ def test_process_response_streaming_clears_state(self):
         self.assertEqual(result["outputs"]["text"], "7")
         self.assertNotIn(req_id, processor.decode_status)

+    def test_process_response_streaming_with_reasoning_and_tools(self):
+        processor = self.processor
+        self.processor.model_status_dict["normal"] = "think_start"
+        processor.reasoning_parser = self.create_dummy_reasoning(
+            processor.tokenizer, reasoning_content="because", content="tool-text"
+        )
+        processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-text")
+        response = {
+            "finished": True,
+            "request_id": "normal",
+            "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]},
+        }
+
+        result = processor.process_response_dict_streaming(response, enable_thinking=True)
+        self.assertEqual(result["outputs"]["completion_tokens"], "7")
+        self.assertEqual(result["outputs"]["text"], "tool-text")
+        self.assertEqual(result["outputs"]["reasoning_content"], "because")
+        self.assertEqual(result["outputs"]["reasoning_token_num"], 1)
+
     def test_process_response_dict_normal_with_reasoning(self):
         processor = self.processor
         processor.model_status_dict = {"normal": "normal"}