Route "qwen" model types through the multimodal path of the XPU model runner.

Summary of the hunks below:
  * __init__: when enable_mm is set, _init_image_preprocess() is now called
    only if the model type contains "ernie".
  * _init_share_inputs: rope_head_dim is set to head_dim for model types
    containing "qwen" as well as "paddleocr" (the branch commented
    "neox style = True"); every other type keeps head_dim // 2.
  * _preprocess_mm_task: images are converted to "uint8" only when the model
    type contains "ernie"; every other model type is converted to "bfloat16".
  * extract_vision_features_qwen (new): asserts that at least one image is
    present, builds an int64 grid_thw tensor, concatenates the images and casts
    them to bfloat16, then calls self.model.visual.extract_feature under
    paddle.amp.auto_cast (level "O2") with the runner's black/white op lists.
  * extract_vision_features: the commented-out "qwen" TODO branch is replaced
    by a real dispatch to extract_vision_features_qwen.

NOTE(review): the _preprocess_mm_task hunk replaces an unconditional "uint8"
conversion, so it changes the image dtype for every non-"ernie" model type
(for example "paddleocr"), not only for "qwen" -- confirm this is intended.
NOTE(review): indentation, blank lines and the spacing inside comments were
restored from the hunk line counts of a whitespace-collapsed copy of this
patch; confirm them against the target file before applying.

diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 8be2e785fd9..f6864570874 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -110,7 +110,8 @@ def __init__(
 
         # VL model config:
         if self.enable_mm:
-            self._init_image_preprocess()
+            if "ernie" in self.fd_config.model_config.model_type:
+                self._init_image_preprocess()
 
         self.amp_black = [
             "reduce_sum",
@@ -1001,7 +1002,7 @@ def _init_share_inputs(self, max_num_seqs: int):
 
         if self.enable_mm:
             head_dim = self.model_config.head_dim
-            if "paddleocr" in self.model_config.model_type:  # neox style = True
+            if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type:  # neox style = True
                 rope_head_dim = head_dim
             else:  # neox style = False
                 rope_head_dim = head_dim // 2
@@ -1697,7 +1698,7 @@ def _preprocess_mm_task(self, one: dict) -> None:
             image_type_ids = one["image_type_ids"][np.newaxis, :]
             images = one["images"]
             image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
-            images = paddle.to_tensor(images, dtype="uint8")
+            images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16")
             grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
         else:
             image_type_ids = None
@@ -1752,6 +1753,22 @@ def extract_vision_features_ernie(self, vision_inputs: dict[str, list[paddle.Ten
             )
         return image_features
 
+    def extract_vision_features_qwen(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
+        assert len(vision_inputs["images_lst"]) > 0, "at least one image needed"
+
+        grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64)
+        images = paddle.concat(vision_inputs["images_lst"]).cast("bfloat16")
+        with paddle.amp.auto_cast(
+            True,
+            custom_black_list=self.amp_black,
+            custom_white_list=self.amp_white,
+            level="O2",
+            dtype=self.model_config.dtype,
+        ):
+            image_features = self.model.visual.extract_feature(images, grid_thw)
+
+        return image_features
+
     def extract_vision_features_paddleocr(self, inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
         if envs.FD_ENABLE_MAX_PREFILL:
             inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
@@ -1800,9 +1817,8 @@ def extract_vision_features(self, multi_vision_inputs: dict[str, list[paddle.Ten
         """extract_vision_features"""
         if "ernie" in self.model_config.model_type:
             return self.extract_vision_features_ernie(multi_vision_inputs)
-        # TODO support VL
-        # elif "qwen" in self.model_config.model_type:
-        #     return self.extract_vision_features_qwen(multi_vision_inputs)
+        elif "qwen" in self.model_config.model_type:
+            return self.extract_vision_features_qwen(multi_vision_inputs)
         elif "paddleocr" in self.model_config.model_type:
             return self.extract_vision_features_paddleocr(multi_vision_inputs)
         else: