From 32d5a9d5a73e17afb5deef4cc74dbe3d93f42e20 Mon Sep 17 00:00:00 2001 From: eligotts <78387377+eligotts@users.noreply.github.com> Date: Thu, 18 Jun 2026 07:03:49 +0000 Subject: [PATCH 1/2] Support raw image refs for multimodal rendering --- renderers/base.py | 21 +-- renderers/client.py | 168 +++++++++++--------- renderers/configs.py | 87 ++++++++-- renderers/mm_store.py | 222 ++++++++++++++++++++++++++ renderers/qwen35.py | 78 ++------- renderers/qwen3_vl.py | 360 +++++++++++++++++++++++++++++++++--------- tests/test_client.py | 229 ++++++++++++++++++++++----- 7 files changed, 892 insertions(+), 273 deletions(-) create mode 100644 renderers/mm_store.py diff --git a/renderers/base.py b/renderers/base.py index 8f722d7..4dbb4f4 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -204,11 +204,10 @@ class MultiModalData: """Multimodal sidecar produced alongside the token stream. Renderer output is framework-agnostic: ``mm_items[modality][i]`` is a - plain ``dict`` mirroring the per-item output of a HuggingFace processor - (e.g. ``{"pixel_values": Tensor, "image_grid_thw": Tensor}`` for - Qwen3-VL images). Translation to engine-specific wire formats — vLLM's - ``MultiModalKwargsItem``, SGLang's payload, etc. — happens in the - inference glue layer (see ``renderers.client``). + plain descriptor dict (e.g. ``{"image_grid_thw": [[1, h, w]]}`` for + Qwen-VL images). Translation to engine-specific wire formats — vLLM image + refs, SGLang payloads, etc. — happens in the inference glue layer (see + ``renderers.client``). """ mm_hashes: dict[str, list[str]] = field(default_factory=dict) @@ -761,8 +760,8 @@ def bridge_to_next_turn( Text-only renderers return :class:`RenderedTokens` with ``multi_modal_data=None``. Multimodal renderers (see :class:`MultimodalRenderer`) populate ``multi_modal_data`` so - the caller can recover placeholder offsets + per-item processed - tensors for the new full prompt; they also accept a + the caller can recover placeholder offsets + per-item image + descriptors for the new full prompt; they also accept a ``previous_multi_modal_data`` kwarg via the :class:`MultimodalRenderer` Protocol override. @@ -818,8 +817,8 @@ def bridge_to_next_turn( the combined token sequence and silently falls back to hash-cache lookup (or errors) - returns :class:`RenderedTokens` (not ``list[int]``) so the - caller can recover the placeholder offsets + per-item - processed tensors for the new full prompt + caller can recover the placeholder offsets + per-item image + descriptors for the new full prompt """ ... @@ -967,6 +966,10 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No with self.checkout() as r: return r.bridge_to_next_turn(*args, **kwargs) + def materialize_image_refs(self, *args: Any, **kwargs: Any) -> "MultiModalData": + with self.checkout() as r: + return r.materialize_image_refs(*args, **kwargs) + # ``mm_token_type_id_map`` (the MultimodalRenderer protocol attribute) # is set in ``__init__`` only for pools wrapping multimodal renderers; # see the comment there for why this isn't a class-level property. diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..de9df0b 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -15,6 +15,7 @@ import json import logging from collections.abc import Mapping +from dataclasses import replace from typing import Any, cast import httpx @@ -156,6 +157,7 @@ async def generate( priority: int | None = None, extra_headers: dict[str, str] | None = None, max_prompt_len: int | None = None, + materialize_all_image_refs: bool = False, ) -> dict[str, Any]: """Tokenize messages, call vLLM /inference/v1/generate, parse the response. @@ -173,8 +175,9 @@ async def generate( For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes through ``renderer.render(...)`` to recover the ``multi_modal_data`` sidecar, then serializes it to vLLM's ``features`` schema (mm_hashes, - mm_placeholders, kwargs_data) before POSTing. The serializer imports - ``vllm.*`` lazily so text-only consumers never pay for the import. + mm_placeholders, kwargs_data) before POSTing. Qwen-family image + ``kwargs_data`` slots are either ``None`` (cache lookup for a prior + image) or run image refs (new/current images that vLLM should process). ``max_prompt_len`` controls the pre-flight overflow check. When the rendered prompt is strictly longer than the cap, the request is never @@ -248,11 +251,23 @@ def _prepare(): "token_ids": prompt_ids, "sampling_params": sp, } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) + + def _features_and_descriptor_mm() -> tuple[dict[str, Any] | None, MultiModalData | None]: + if mm_data is None or mm_data.is_empty(): + return None, mm_data + build_mm = mm_data + if materialize_all_image_refs: + materialize = getattr(renderer, "materialize_image_refs", None) + if materialize is None: + raise NotImplementedError( + f"{type(renderer).__name__} cannot materialize image refs for retry." + ) + build_mm = materialize(mm_data, messages) + return _build_vllm_mm_features(renderer, build_mm), _descriptor_only_mm_data(mm_data) + + features, out_mm_data = await _maybe_offload(renderer, _features_and_descriptor_mm) + if prompt_attr is not None and getattr(prompt_attr, "multi_modal_data", None) is not None: + prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data) if features is not None: body["features"] = features if cache_salt is not None: @@ -322,7 +337,7 @@ def _prepare(): # The mm sidecar consumed on the request side, surfaced back so # callers can persist it on the trajectory step for downstream # multi-turn bridging and training-sample construction. - "multi_modal_data": mm_data, + "multi_modal_data": out_mm_data, # The renderer's per-token attribution for the prompt — either # the RenderedTokens computed here via renderer.render(...) or # the one threaded in by the caller alongside prompt_ids (the @@ -334,7 +349,31 @@ def _prepare(): } -def _build_mm_features( +def _descriptor_only_mm_data(mm_data: MultiModalData) -> MultiModalData: + """Drop one-request image-ref fields before callers persist mm_data.""" + from renderers.mm_store import IMAGE_REF_PAYLOAD_KEY + + new_items: dict[str, list[dict[str, Any]]] = {} + for modality, items in mm_data.mm_items.items(): + new_items[modality] = [ + { + key: value + for key, value in item.items() + if key + not in { + "pixel_values", + "raw_uri", + "raw_image_id", + "image_layout_fingerprint", + IMAGE_REF_PAYLOAD_KEY, + } + } + for item in items + ] + return replace(mm_data, mm_items=new_items) + + +def _build_vllm_mm_features( renderer: Renderer | RendererPool, mm_data: MultiModalData, ) -> dict[str, Any] | None: @@ -342,22 +381,9 @@ def _build_mm_features( vLLM's ``MultiModalFeatures`` carries three things: hashes (for cache lookup), placeholder positions (so the engine knows where in the - token stream each item lives), and per-item ``MultiModalKwargsItem`` - base64-encoded. The encoding requires vLLM-side type info — what - fields belong to each modality, how they batch — and is currently - model-family specific. For now we dispatch on the renderer class; - extend the dispatch table as more multimodal renderers land. - - NOTE — future engine pluggability: this encoder is vLLM 0.20-specific - (uses ``vllm.multimodal.inputs.MultiModalKwargsItems``, - ``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and - ``_create_qwen2vl_field_factory``). When a second inference engine - arrives (SGLang, MAX, ...) the renderer client should be parameterized - on engine: either (a) move the encoder onto the renderer as - ``encode_mm_for_(mm_data)`` methods, or (b) accept an - ``Encoder`` strategy at the ``generate(...)`` call site. The data type - (``MultiModalData``) is already framework-agnostic and does not need - to change. Don't pre-build the abstraction with one engine in tree. + token stream each item lives), and per-item payload selectors. For + Qwen images, payload selectors are ``None`` for cache-only prior images + or run image refs for images vLLM should process. """ from renderers.qwen3_vl import Qwen3VLRenderer from renderers.qwen35 import Qwen35Renderer @@ -369,43 +395,27 @@ def _build_mm_features( renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer) ) - # Qwen3-VL and Qwen3.5 both ship ``pixel_values`` + ``image_grid_thw`` - # via the shared Qwen2-VL field factory. ``spatial_merge_size=2`` is - # the family default and matches every Qwen-VL processor in tree. if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)): - return _build_qwen_vl_features(mm_data, spatial_merge_size=2) + return _build_qwen_vl_image_ref_features(mm_data) raise NotImplementedError( f"Multimodal serialization not implemented for {renderer_cls.__name__}. " - "Add a dispatch branch in renderers.client._build_mm_features." + "Add a dispatch branch in renderers.client._build_vllm_mm_features." ) -def _build_qwen_vl_features( - mm_data: MultiModalData, *, spatial_merge_size: int -) -> dict[str, Any]: - """vLLM features payload for the Qwen-VL family (Qwen2-VL / Qwen3-VL). - - Stacks per-image processor outputs back into a batched ``BatchFeature``, - runs the Qwen2-VL field factory (shared across the family), wraps as - ``MultiModalKwargsItems``, base64-encodes each item, and assembles a - JSON-serializable dict matching vLLM's ``MultiModalFeatures`` schema. +def _build_qwen_vl_image_ref_features(mm_data: MultiModalData) -> dict[str, Any]: + """vLLM features payload for Qwen-VL image refs. Returns ``None`` semantics live one level up — this helper assumes the caller already verified ``mm_data`` is non-empty. """ - try: - import torch - from transformers.feature_extraction_utils import BatchFeature - from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item - from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory - from vllm.multimodal.inputs import MultiModalKwargsItems - except ImportError as exc: - raise RuntimeError( - "Multimodal generate via /inference/v1/generate requires `vllm` " - "and `torch` to encode the features payload. Install vLLM in this " - "environment, or pre-build features upstream." - ) from exc + from renderers.mm_store import ( + IMAGE_REF_PAYLOAD_KEY, + IMAGE_REF_PAYLOAD_VALUE, + current_run_id, + image_ref, + ) out: dict[str, Any] = { "mm_hashes": {}, @@ -415,32 +425,44 @@ def _build_qwen_vl_features( image_items = mm_data.mm_items.get("image") or [] if image_items: - # mm_items now ship numpy arrays (the renderer is torch-free); - # convert at this vLLM-glue boundary where torch is already a - # hard dependency. - pixel_values = torch.cat( - [torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0 - ) - image_grid_thw = torch.cat( - [torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0 - ) - hf_inputs = BatchFeature( - data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} - ) - config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) - kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) - encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]] + mm_hashes = list(mm_data.mm_hashes.get("image") or []) + placeholders = list(mm_data.mm_placeholders.get("image") or []) + if len(mm_hashes) != len(image_items) or len(placeholders) != len(image_items): + raise ValueError( + "Qwen-VL mm sidecar length mismatch: " + f"items={len(image_items)} hashes={len(mm_hashes)} placeholders={len(placeholders)}" + ) + + encoded: list[Any] = [None] * len(image_items) + run_id = current_run_id() + for idx, item in enumerate(image_items): + if item.get(IMAGE_REF_PAYLOAD_KEY) != IMAGE_REF_PAYLOAD_VALUE: + continue + raw_image_id = item.get("raw_image_id") + grid_thw = item.get("image_grid_thw") + fingerprint = item.get("image_layout_fingerprint") + if not isinstance(raw_image_id, str) or not raw_image_id: + raise ValueError("image-ref multimodal item is missing raw_image_id") + if grid_thw is None: + raise ValueError("image-ref multimodal item is missing image_grid_thw") + if not isinstance(fingerprint, str) or not fingerprint: + raise ValueError("image-ref multimodal item is missing image_layout_fingerprint") + encoded[idx] = image_ref( + run_id=run_id, + fingerprint=fingerprint, + modality="image", + mm_hash=mm_hashes[idx], + raw_image_id=raw_image_id, + grid_thw=grid_thw, + ) + out["kwargs_data"]["image"] = encoded - out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) + out["mm_hashes"]["image"] = mm_hashes out["mm_placeholders"]["image"] = [ - {"offset": p.offset, "length": p.length} - for p in mm_data.mm_placeholders.get("image") or [] + {"offset": p.offset, "length": p.length} for p in placeholders ] - # If kwargs_data is empty across all modalities, drop the key so vLLM - # falls back to the hash-only (cache-hit) path. Otherwise hand it the - # full payload. - if not any(out["kwargs_data"].values()): + if not any(item is not None for values in out["kwargs_data"].values() for item in values): out["kwargs_data"] = None return out diff --git a/renderers/configs.py b/renderers/configs.py index d500f8e..b07d97e 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -25,6 +25,12 @@ from pydantic import ConfigDict, Field from pydantic_config import BaseConfig +QWEN_VL_IMAGE_PATCH_SIZE = 16 +QWEN_VL_IMAGE_TEMPORAL_PATCH_SIZE = 2 +QWEN_VL_IMAGE_MERGE_SIZE = 2 +QWEN_VL_IMAGE_MIN_PIXELS = 65536 +QWEN_VL_IMAGE_MAX_PIXELS = 16777216 + class BaseRendererConfig(BaseConfig): """Shared fields and config for every renderer config variant. @@ -148,11 +154,30 @@ class Qwen35RendererConfig(BaseRendererConfig): running across the entire conversation. Mirrors the chat template's ``add_vision_id`` toggle.""" - image_cache_max: int = 256 - """FIFO bound on the per-renderer image processor cache. Renderer- - internal — not a Jinja chat-template kwarg.""" + image_patch_size: int = QWEN_VL_IMAGE_PATCH_SIZE + """Qwen image patch size used to compute placeholder layout.""" - _internal_fields = frozenset({"image_cache_max"}) + image_temporal_patch_size: int = QWEN_VL_IMAGE_TEMPORAL_PATCH_SIZE + """Qwen temporal patch size used in the image layout fingerprint.""" + + image_merge_size: int = QWEN_VL_IMAGE_MERGE_SIZE + """Qwen spatial merge size used to compute image pad-token counts.""" + + image_min_pixels: int = QWEN_VL_IMAGE_MIN_PIXELS + """Minimum resized image area used by Qwen smart-resize layout math.""" + + image_max_pixels: int = QWEN_VL_IMAGE_MAX_PIXELS + """Maximum resized image area used by Qwen smart-resize layout math.""" + + _internal_fields = frozenset( + { + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class Qwen36RendererConfig(BaseRendererConfig): @@ -166,10 +191,30 @@ class Qwen36RendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" + image_patch_size: int = QWEN_VL_IMAGE_PATCH_SIZE + """See :class:`Qwen35RendererConfig.image_patch_size`.""" - _internal_fields = frozenset({"image_cache_max"}) + image_temporal_patch_size: int = QWEN_VL_IMAGE_TEMPORAL_PATCH_SIZE + """See :class:`Qwen35RendererConfig.image_temporal_patch_size`.""" + + image_merge_size: int = QWEN_VL_IMAGE_MERGE_SIZE + """See :class:`Qwen35RendererConfig.image_merge_size`.""" + + image_min_pixels: int = QWEN_VL_IMAGE_MIN_PIXELS + """See :class:`Qwen35RendererConfig.image_min_pixels`.""" + + image_max_pixels: int = QWEN_VL_IMAGE_MAX_PIXELS + """See :class:`Qwen35RendererConfig.image_max_pixels`.""" + + _internal_fields = frozenset( + { + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class Qwen3VLRendererConfig(BaseRendererConfig): @@ -180,10 +225,30 @@ class Qwen3VLRendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" + image_patch_size: int = QWEN_VL_IMAGE_PATCH_SIZE + """See :class:`Qwen35RendererConfig.image_patch_size`.""" - _internal_fields = frozenset({"image_cache_max"}) + image_temporal_patch_size: int = QWEN_VL_IMAGE_TEMPORAL_PATCH_SIZE + """See :class:`Qwen35RendererConfig.image_temporal_patch_size`.""" + + image_merge_size: int = QWEN_VL_IMAGE_MERGE_SIZE + """See :class:`Qwen35RendererConfig.image_merge_size`.""" + + image_min_pixels: int = QWEN_VL_IMAGE_MIN_PIXELS + """See :class:`Qwen35RendererConfig.image_min_pixels`.""" + + image_max_pixels: int = QWEN_VL_IMAGE_MAX_PIXELS + """See :class:`Qwen35RendererConfig.image_max_pixels`.""" + + _internal_fields = frozenset( + { + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class GLM5RendererConfig(BaseRendererConfig): @@ -295,7 +360,7 @@ class KimiK25RendererConfig(BaseRendererConfig): template's native variable name.""" image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" + """FIFO bound on Kimi's per-renderer image processor cache.""" _internal_fields = frozenset({"image_cache_max"}) diff --git a/renderers/mm_store.py b/renderers/mm_store.py new file mode 100644 index 0000000..8cabd14 --- /dev/null +++ b/renderers/mm_store.py @@ -0,0 +1,222 @@ +"""Run-scoped image asset helpers for multimodal rendering. + +The renderer stack does not ship processed multimodal features. Images are +written once into the run output tree and messages carry ``file://`` URLs to +those files. Renderers then emit lightweight image refs for vLLM only when the +engine needs to process an image. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +import re +import threading +from pathlib import Path + +RUN_OUTPUT_ROOT = Path("/data/outputs") + +IMAGE_OFFLOAD_DIR_ENV = "VF_RENDERER_IMAGE_OFFLOAD_DIR" +RUN_DIR_ENV = "PRIME_RL_RUN_DIR" +RUN_ID_ENV = "RUN_ID" + +IMAGE_ASSET_SUBDIR = Path("assets/images") +IMAGE_REF_PREFIX = "mmraw:v1" +IMAGE_REF_PAYLOAD_KEY = "_prime_rl_image_ref" +IMAGE_REF_PAYLOAD_VALUE = "raw_image" + +_SAFE_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_FINGERPRINT_RE = re.compile(r"^[a-f0-9]{16,64}$") +_SAFE_MM_HASH_RE = re.compile(r"^[a-f0-9]{16,128}$") +_SAFE_IMAGE_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_GRID_THW_RE = re.compile(r"^[0-9]+x[0-9]+x[0-9]+$") + +_MEDIA_TYPE_EXT = {"jpeg": ".jpg", "jpg": ".jpg", "png": ".png", "webp": ".webp", "gif": ".gif"} + + +def normalize_run_id(run_id: str) -> str: + """Return the canonical run id, without the directory's ``run_`` prefix.""" + value = run_id.strip() + if value.startswith("run_"): + value = value[len("run_") :] + if not value or not _SAFE_RUN_ID_RE.fullmatch(value): + raise ValueError(f"Invalid run id: {run_id!r}") + return value + + +def run_dir_name(run_id: str) -> str: + return f"run_{normalize_run_id(run_id)}" + + +def current_run_id() -> str: + """Best-effort run id for refs emitted by this process.""" + raw = os.getenv(RUN_ID_ENV, "").strip() + if raw: + return normalize_run_id(raw) + + run_dir = os.getenv(RUN_DIR_ENV, "").strip() + if run_dir: + return normalize_run_id(Path(run_dir).name) + + image_dir = os.getenv(IMAGE_OFFLOAD_DIR_ENV, "").strip() + if image_dir: + # Expected shape is /assets/images. If callers pass another + # explicit directory, the ref's run segment is only a stable label; the + # path resolver will use the explicit directory in every process. + path = Path(image_dir).resolve() + if path.name == "images" and path.parent.name == "assets": + try: + return normalize_run_id(path.parent.parent.name) + except ValueError: + pass + return "explicit" + + raise RuntimeError( + f"Set {IMAGE_OFFLOAD_DIR_ENV}, {RUN_DIR_ENV}, or {RUN_ID_ENV} before emitting image refs." + ) + + +def run_dir(run_id: str | None = None) -> Path: + """Resolve the run output directory. + + Resolution order: + 1. ``PRIME_RL_RUN_DIR`` as an exact run directory. + 2. ``RUN_ID`` or explicit ``run_id`` under ``/data/outputs/run_``. + """ + explicit = os.getenv(RUN_DIR_ENV, "").strip() + if explicit: + return Path(explicit).resolve() + + value = run_id or os.getenv(RUN_ID_ENV, "").strip() + if not value: + raise RuntimeError(f"Set {RUN_DIR_ENV} or {RUN_ID_ENV} before resolving a run directory.") + return (RUN_OUTPUT_ROOT / run_dir_name(value)).resolve() + + +def run_image_dir(run_id: str | None = None) -> Path: + """Resolve the directory for raw image assets for a run.""" + explicit = os.getenv(IMAGE_OFFLOAD_DIR_ENV, "").strip() + if explicit: + return Path(explicit).resolve() + return (run_dir(run_id) / IMAGE_ASSET_SUBDIR).resolve() + + +def image_asset_dir(run_id: str | None = None) -> Path: + """Alias for callers that already use the assets terminology.""" + return run_image_dir(run_id) + + +def _media_type_ext(media_type: str) -> str: + subtype = media_type.split("/", 1)[-1].split(";", 1)[0].strip().lower() + return _MEDIA_TYPE_EXT.get(subtype, ".img") + + +def offload_image_to_run_assets(url: object, image_dir: Path | None = None) -> tuple[str, int] | None: + """Decode a base64 data image into the run image assets directory. + + Returns ``(file_url, byte_count)`` when ``url`` was rewritten and ``None`` + for non-data-image values. Writes are content-addressed and atomic. + """ + if not isinstance(url, str) or not url.startswith("data:image/"): + return None + marker = ";base64," + if marker not in url: + return None + + header, b64 = url.split(marker, 1) + try: + raw = base64.b64decode(b64) + except Exception: + return None + + root = (image_dir or run_image_dir()).resolve() + root.mkdir(parents=True, exist_ok=True) + digest = hashlib.sha256(raw).hexdigest()[:16] + path = root / f"{digest}{_media_type_ext(header[len('data:') :])}" + if not path.exists(): + tmp = path.with_name(f".{path.name}.{os.getpid()}.{threading.get_ident()}.tmp") + tmp.write_bytes(raw) + os.replace(tmp, path) + else: + try: + path.touch() + except OSError: + pass + return path.as_uri(), len(raw) + + +def raw_image_path(*, run_id: str, raw_image_id: str) -> Path: + if not _SAFE_IMAGE_ID_RE.fullmatch(raw_image_id): + raise ValueError(f"Invalid raw image id: {raw_image_id!r}") + root = run_image_dir(run_id) + path = (root / raw_image_id).resolve() + if not path.is_relative_to(root): + raise ValueError(f"Raw image path escaped root: {path}") + return path + + +def image_layout_fingerprint( + *, + family: str, + patch_size: int, + merge_size: int, + temporal_patch_size: int, + min_pixels: int, + max_pixels: int, +) -> str: + raw = ( + f"image-layout:v1:{family}:{int(patch_size)}:{int(merge_size)}:" + f"{int(temporal_patch_size)}:{int(min_pixels)}:{int(max_pixels)}" + ).encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:32] + + +def _grid_to_ref(grid_thw: object) -> str: + data = grid_thw.tolist() if hasattr(grid_thw, "tolist") else grid_thw + if isinstance(data, list) and data and isinstance(data[0], list): + data = data[0] + if not isinstance(data, (list, tuple)) or len(data) != 3: + raise ValueError(f"Invalid image grid_thw for image ref: {grid_thw!r}") + return "x".join(str(int(v)) for v in data) + + +def _grid_from_ref(value: str) -> list[int]: + if not _SAFE_GRID_THW_RE.fullmatch(value): + raise ValueError(f"Invalid image grid_thw ref segment: {value!r}") + return [int(v) for v in value.split("x")] + + +def image_ref( + *, + run_id: str, + fingerprint: str, + modality: str, + mm_hash: str, + raw_image_id: str, + grid_thw: object, +) -> str: + run_id = normalize_run_id(run_id) + if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): + raise ValueError(f"Invalid image layout fingerprint: {fingerprint!r}") + if modality != "image": + raise ValueError(f"Unsupported image ref modality: {modality!r}") + if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): + raise ValueError(f"Invalid image hash: {mm_hash!r}") + raw_image_path(run_id=run_id, raw_image_id=raw_image_id) + return f"{IMAGE_REF_PREFIX}:{run_id}:{fingerprint}:{modality}:{mm_hash}:{raw_image_id}:{_grid_to_ref(grid_thw)}" + + +def split_image_ref(ref: str) -> tuple[str, str, str, str, str, list[int]]: + parts = ref.split(":") + if parts[:2] != ["mmraw", "v1"] or len(parts) != 8: + raise ValueError(f"Invalid image ref shape: {ref!r}") + return normalize_run_id(parts[2]), parts[3], parts[4], parts[5], parts[6], _grid_from_ref(parts[7]) + + +# Backwards-compatible names for consumers that already speak the mmraw wire format. +MMRAW_PREFIX = IMAGE_REF_PREFIX +MM_RAW_PAYLOAD_KEY = IMAGE_REF_PAYLOAD_KEY +MM_RAW_PAYLOAD_VALUE = IMAGE_REF_PAYLOAD_VALUE +mmraw_ref = image_ref +split_mmraw_ref = split_image_ref diff --git a/renderers/qwen35.py b/renderers/qwen35.py index cdb8ee1..c0b76d6 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -7,9 +7,9 @@ processor class ``Qwen3VLProcessor``). When a user/tool message carries an ``ImagePart``, the renderer emits the same ``<|vision_start|>``+N×``<|image_pad|>`` +``<|vision_end|>`` expansion as the HF chat template (``N = -image_grid_thw.prod() // merge_size**2``) and ships processed pixel_values via -``RenderedTokens.multi_modal_data``. Text-only inputs take the original fast -path and remain byte-identical to ``apply_chat_template``. +image_grid_thw.prod() // merge_size**2``) using renderer-declared image layout +metadata. It does not call the HF image processor; vLLM receives run image refs +for images it must process. """ from __future__ import annotations @@ -35,10 +35,10 @@ from renderers.configs import Qwen35RendererConfig from renderers.parsing import parse_qwen35 from renderers.qwen3_vl import ( - _image_hash, _is_image_part, _is_video_part, - _load_pil_image, + materialize_image_refs, + qwen_image_item_for_render, ) # --------------------------------------------------------------------------- @@ -120,7 +120,7 @@ def __init__( processor: Any = None, ): self._tokenizer = tokenizer - self._processor = processor + _ = processor cfg = config or type(self)._config_cls() # ``enable_thinking=None`` defers to the model's known default (see # ``_ENABLE_THINKING_DEFAULTS``). Materialise here so downstream reads @@ -147,11 +147,6 @@ def __init__( self._image_pad = self._token_id("<|image_pad|>") self._video_pad = self._token_id("<|video_pad|>") - # Per-instance image-processor cache; see Qwen3VLRenderer for the - # rationale (FIFO-bounded; same image seen across rollouts / - # bridge re-renders). - self._image_cache: dict[str, tuple[Any, int]] = {} - @property def mm_token_type_id_map(self) -> dict[int, int]: """Token-id → modality marker (1 = image, 2 = video) used by the @@ -160,45 +155,10 @@ def mm_token_type_id_map(self) -> dict[int, int]: """ return {self._image_pad: 1, self._video_pad: 2} - def _get_processor(self): - if self._processor is not None: - return self._processor - from transformers import AutoProcessor - - name = getattr(self._tokenizer, "name_or_path", None) - if not name: - raise RuntimeError( - "Qwen35Renderer needs a processor to render image / video parts. " - "Pass `processor=AutoProcessor.from_pretrained(...)` to the " - "constructor, or load the tokenizer with a known name_or_path " - "so the processor can be auto-loaded." - ) - self._processor = AutoProcessor.from_pretrained(name) - return self._processor - - def _process_image(self, part: dict[str, Any]): - """Resolve, process, and characterize a single image part. - - Returns ``(pil, processor_out, num_image_tokens, image_hash)``. - Mirrors ``Qwen3VLRenderer._process_image``: hashes the loaded PIL, - consults ``self._image_cache``, runs the HF image processor on - miss, FIFO-evicts on overflow. - """ - pil = _load_pil_image(part) - h = _image_hash(pil) - cached = self._image_cache.get(h) - if cached is not None: - out, num_image_tokens = cached - return pil, out, num_image_tokens, h - proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="np") - grid_thw = out["image_grid_thw"][0] - merge_size = proc.image_processor.merge_size - num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) - return pil, out, num_image_tokens, h + def materialize_image_refs( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + return materialize_image_refs(self, mm_data, messages) @staticmethod def _content_has_media(content: Any) -> bool: @@ -364,7 +324,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # specials are template scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text( @@ -386,12 +346,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: """Emit a user message whose content list contains image parts. @@ -715,7 +670,7 @@ def emit_text_segments( content_mask.append(is_content) def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text(f"Picture {vision_counts['image']}: ", msg_idx) @@ -728,12 +683,7 @@ def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: emit_special(self._im_start, msg_idx) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 7b82d7e..9b865d0 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -6,14 +6,11 @@ for image inputs as the HF processor (``N = image_grid_thw.prod() // merge_size**2``). -Image data is shipped to the inference engine via -``RenderedTokens.multi_modal_data``: ``mm_placeholders`` records the -``(offset, length)`` span of each image's placeholder tokens in the -prompt, ``mm_items`` carries the per-image processor output -(``pixel_values``, ``image_grid_thw``), and ``mm_hashes`` carries a -stable identifier for cache lookup. The wire-format conversion to -vLLM's ``/inference/v1/generate`` ``features`` field lives in -``renderers.client``. +Image data is shipped to the inference engine via run image refs, not +processed image-processor payloads. ``RenderedTokens.multi_modal_data`` +records placeholder spans, stable image hashes, and Qwen layout metadata +(``image_grid_thw``) so vLLM can cache-match prior images and process new +image refs itself. BPE boundary discipline: text runs that the chat template emits contiguously (e.g. ``"user\\n" + content_text``) must be encoded as a @@ -30,8 +27,11 @@ import hashlib import io import json +import math +from dataclasses import dataclass, replace +from pathlib import Path from typing import Any -from urllib.parse import urlparse +from urllib.parse import unquote, urlparse from transformers.tokenization_utils import PreTrainedTokenizer @@ -48,6 +48,11 @@ trim_to_turn_close, ) from renderers.configs import Qwen3VLRendererConfig +from renderers.mm_store import ( + IMAGE_REF_PAYLOAD_KEY, + IMAGE_REF_PAYLOAD_VALUE, + image_layout_fingerprint, +) from renderers.parsing import parse_qwen3 _TOOLS_HEADER = ( @@ -163,6 +168,261 @@ def _image_hash(pil_image) -> str: return h.hexdigest()[:32] +@dataclass(frozen=True) +class QwenImageLayoutConfig: + patch_size: int + temporal_patch_size: int + merge_size: int + min_pixels: int + max_pixels: int + + +@dataclass(frozen=True) +class QwenImageLayoutDescriptor: + mm_hash: str + image_grid_thw: list[list[int]] + num_image_tokens: int + fingerprint: str + raw_uri: str | None = None + raw_image_id: str | None = None + + +def qwen_image_layout_config_for_renderer(renderer: Any) -> QwenImageLayoutConfig: + config = renderer.config + values = { + "patch_size": getattr(config, "image_patch_size", None), + "temporal_patch_size": getattr(config, "image_temporal_patch_size", None), + "merge_size": getattr(config, "image_merge_size", None), + "min_pixels": getattr(config, "image_min_pixels", None), + "max_pixels": getattr(config, "image_max_pixels", None), + } + missing = [name for name, value in values.items() if value is None] + if missing: + raise RuntimeError( + "Qwen image layout must be declared on the renderer config; missing " + + ", ".join(missing) + ) + return QwenImageLayoutConfig( + patch_size=int(values["patch_size"]), + temporal_patch_size=int(values["temporal_patch_size"]), + merge_size=int(values["merge_size"]), + min_pixels=int(values["min_pixels"]), + max_pixels=int(values["max_pixels"]), + ) + + +def _smart_resize( + height: int, + width: int, + *, + factor: int, + min_pixels: int, + max_pixels: int, +) -> tuple[int, int]: + """Qwen image resize math without materializing resized pixels.""" + if height <= 0 or width <= 0: + raise ValueError(f"image dimensions must be positive, got {height}x{width}") + if max(height, width) / min(height, width) > 200: + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def _image_source(item: dict[str, Any]) -> Any: + if "image" in item: + return item["image"] + if "image_url" in item: + image_url = item.get("image_url") + return image_url.get("url") if isinstance(image_url, dict) else image_url + return item.get("url") or item.get("path") + + +def _file_path_from_source(source: Any) -> Path | None: + if not isinstance(source, str): + return None + parsed = urlparse(source) + if parsed.scheme == "file": + return Path(unquote(parsed.path)).resolve() + if parsed.scheme == "": + return Path(source).resolve() + return None + + +def _image_dimensions(source: Any) -> tuple[int, int]: + try: + from PIL import Image + except ImportError as exc: + raise RuntimeError( + "Pillow is required to read image dimensions for multimodal rendering." + ) from exc + + path = _file_path_from_source(source) + if path is not None: + with Image.open(path) as image: + return image.height, image.width + + image = _load_pil_image({"image": source}) + return image.height, image.width + + +def _image_content_hash(source: Any) -> str: + path = _file_path_from_source(source) + if path is not None: + return hashlib.sha256(path.read_bytes()).hexdigest()[:32] + return _image_hash(_load_pil_image({"image": source})) + + +def _raw_uri_and_id(source: Any) -> tuple[str | None, str | None]: + path = _file_path_from_source(source) + if path is None: + return None, None + return path.as_uri(), path.name + + +def describe_qwen_image_layout(renderer: Any, part: dict[str, Any]) -> QwenImageLayoutDescriptor: + """Return Qwen image layout metadata without invoking an image processor.""" + source = _image_source(part) + height, width = _image_dimensions(source) + layout = qwen_image_layout_config_for_renderer(renderer) + resized_h, resized_w = _smart_resize( + height, + width, + factor=layout.patch_size * layout.merge_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + grid_t = 1 + grid_h = resized_h // layout.patch_size + grid_w = resized_w // layout.patch_size + num_image_tokens = grid_t * grid_h * grid_w // (layout.merge_size * layout.merge_size) + fingerprint = image_layout_fingerprint( + family="qwen_vl", + patch_size=layout.patch_size, + merge_size=layout.merge_size, + temporal_patch_size=layout.temporal_patch_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + raw_uri, raw_image_id = _raw_uri_and_id(source) + return QwenImageLayoutDescriptor( + mm_hash=_image_content_hash(source), + image_grid_thw=[[grid_t, grid_h, grid_w]], + num_image_tokens=num_image_tokens, + fingerprint=fingerprint, + raw_uri=raw_uri, + raw_image_id=raw_image_id, + ) + + +def qwen_image_item_for_render(renderer: Any, part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: + desc = describe_qwen_image_layout(renderer, part) + item: dict[str, Any] = {"image_grid_thw": desc.image_grid_thw} + if desc.raw_uri is not None and desc.raw_image_id is not None: + item.update( + { + "raw_uri": desc.raw_uri, + "raw_image_id": desc.raw_image_id, + "image_layout_fingerprint": desc.fingerprint, + IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, + } + ) + return desc.num_image_tokens, desc.mm_hash, item + + +def _iter_image_parts(messages: list[Any]): + for msg in messages or []: + content = msg.get("content") if isinstance(msg, dict) else None + if not isinstance(content, list): + continue + for item in content: + if isinstance(item, dict) and _is_image_part(item): + yield item + + +def _grids_equal(a: Any, b: Any) -> bool: + if a is None or b is None: + return False + al = a.tolist() if hasattr(a, "tolist") else list(a) + bl = b.tolist() if hasattr(b, "tolist") else list(b) + return al == bl + + +def materialize_image_refs(renderer: Any, mm_data: MultiModalData, messages: list[Message]) -> MultiModalData: + """Attach run-image refs to every Qwen image descriptor that can be found.""" + image_items = mm_data.mm_items.get("image") or [] + if not image_items: + return mm_data + hashes = mm_data.mm_hashes.get("image") or [] + if len(hashes) != len(image_items): + raise ValueError( + "materialize_image_refs: mm_hashes/mm_items length mismatch " + f"({len(hashes)} vs {len(image_items)})" + ) + + missing = set(hashes) + resolved: dict[str, QwenImageLayoutDescriptor] = {} + for part in _iter_image_parts(messages): + if not missing: + break + desc = describe_qwen_image_layout(renderer, part) + if desc.mm_hash in missing: + resolved[desc.mm_hash] = desc + missing.discard(desc.mm_hash) + if missing: + raise ValueError( + f"materialize_image_refs: {len(missing)} image hash(es) not found in messages" + ) + + new_image_items: list[dict[str, Any]] = [] + for i, item in enumerate(image_items): + desc = resolved[hashes[i]] + if desc.raw_uri is None or desc.raw_image_id is None: + raise ValueError("materialize_image_refs requires file-backed image URLs") + item_grid = item.get("image_grid_thw") + if item_grid is not None and not _grids_equal(desc.image_grid_thw, item_grid): + raise ValueError( + "materialize_image_refs: reconstructed image_grid_thw " + f"{desc.image_grid_thw!r} != descriptor {item_grid!r}" + ) + new_item = { + k: v + for k, v in item.items() + if k + not in { + "raw_uri", + "raw_image_id", + "image_layout_fingerprint", + IMAGE_REF_PAYLOAD_KEY, + } + } + new_item.update( + { + "image_grid_thw": item_grid if item_grid is not None else desc.image_grid_thw, + "raw_uri": desc.raw_uri, + "raw_image_id": desc.raw_image_id, + "image_layout_fingerprint": desc.fingerprint, + IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, + } + ) + new_image_items.append(new_item) + + new_items = dict(mm_data.mm_items) + new_items["image"] = new_image_items + return replace(mm_data, mm_items=new_items) + + class _Emitter: """Token-stream builder with BPE-safe text buffering. @@ -296,11 +556,9 @@ class Qwen3VLRenderer: config: Typed renderer config (see :class:`renderers.Qwen3VLRendererConfig`). Defaults to a blank config with template defaults. - processor: Optional ``Qwen3VLProcessor``. Required when rendering - messages that contain image / video parts. If not supplied, - the renderer lazy-loads it via ``AutoProcessor.from_pretrained`` - keyed off ``tokenizer.name_or_path`` the first time a - multimodal part is seen. + processor: Deprecated and ignored. Image layout is declared by the + renderer config; the renderer never loads or calls an HF image + processor. ``preserve_all_thinking`` / ``preserve_thinking_between_tool_calls`` on the config are no-ops here — the chat template drops past @@ -315,7 +573,7 @@ def __init__( processor: Any = None, ): self._tokenizer = tokenizer - self._processor = processor + _ = processor self.config = config or Qwen3VLRendererConfig() self._im_start = self._token_id("<|im_start|>") @@ -331,16 +589,6 @@ def __init__( self._image_pad = self._token_id("<|image_pad|>") self._video_pad = self._token_id("<|video_pad|>") - # Per-instance image-processor cache. The HF image processor is the - # most expensive step on the renderer hot path (~tens of ms per - # image for typical grid_thw). The same image gets re-seen across - # ``rollouts_per_example`` rollouts of one example and (for - # multi-turn) across turn boundaries when the bridge re-renders - # rather than extends. Cache keyed by content hash — values are - # tuples of ``(processor_out, num_image_tokens)`` — bounded to - # avoid unbounded growth on long-lived pools. - self._image_cache: dict[str, tuple[Any, int]] = {} - def _token_id(self, token: str) -> int: tid = self._tokenizer.convert_tokens_to_ids(token) assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( @@ -366,22 +614,6 @@ def _encode(self, text: str) -> list[int]: return [] return self._tokenizer.encode(text, add_special_tokens=False) - def _get_processor(self): - if self._processor is not None: - return self._processor - from transformers import AutoProcessor - - name = getattr(self._tokenizer, "name_or_path", None) - if not name: - raise RuntimeError( - "Qwen3VLRenderer needs a processor to render image / video parts. " - "Pass `processor=AutoProcessor.from_pretrained(...)` to the " - "constructor, or load the tokenizer with a known name_or_path " - "so the processor can be auto-loaded." - ) - self._processor = AutoProcessor.from_pretrained(name) - return self._processor - @staticmethod def _render_text_content(content: Any) -> str: """Flatten a content list to a single text string, dropping media parts. @@ -410,30 +642,10 @@ def _render_text_content(content: Any) -> str: return "".join(parts) raise TypeError(f"Unexpected content type: {type(content)}") - def _process_image(self, part: dict[str, Any]): - """Resolve, process, and characterize a single image part. - - Returns ``(pil, processor_out, num_image_tokens, image_hash)``. - Hashes the loaded PIL first and consults ``self._image_cache``; - on hit the HF image-processor call is skipped entirely. - """ - pil = _load_pil_image(part) - h = _image_hash(pil) - cached = self._image_cache.get(h) - if cached is not None: - out, num_image_tokens = cached - return pil, out, num_image_tokens, h - proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="np") - grid_thw = out["image_grid_thw"][0] - merge_size = proc.image_processor.merge_size - num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - # FIFO eviction — Python dicts preserve insertion order, so - # ``next(iter(...))`` is the oldest key. - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) - return pil, out, num_image_tokens, h + def materialize_image_refs( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + return materialize_image_refs(self, mm_data, messages) def render( self, @@ -464,7 +676,7 @@ def emit_image(part: dict[str, Any]) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # markers are renderer-emitted scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -481,12 +693,7 @@ def emit_image(part: dict[str, Any]) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: """Emit a user/tool content list with media handled inline. @@ -730,7 +937,7 @@ def bridge_to_next_turn( vision_counts = {"image": prev_image_count, "video": prev_video_count} def emit_image(part: dict[str, Any]) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -747,12 +954,7 @@ def emit_image(part: dict[str, Any]) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: if isinstance(content, str): diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..c1c7aaf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,6 @@ import asyncio import base64 +import hashlib import json import httpx @@ -101,6 +102,141 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): ) +def test_run_image_dir_resolution_prefers_explicit_image_dir(tmp_path, monkeypatch): + from renderers.mm_store import run_image_dir + + image_dir = tmp_path / "custom-images" + monkeypatch.setenv("VF_RENDERER_IMAGE_OFFLOAD_DIR", str(image_dir)) + monkeypatch.setenv("PRIME_RL_RUN_DIR", str(tmp_path / "run_other")) + monkeypatch.setenv("RUN_ID", "other") + + assert run_image_dir() == image_dir.resolve() + + +def test_run_image_dir_resolution_owns_run_prefix(monkeypatch): + from renderers.mm_store import run_image_dir + + monkeypatch.delenv("VF_RENDERER_IMAGE_OFFLOAD_DIR", raising=False) + monkeypatch.delenv("PRIME_RL_RUN_DIR", raising=False) + monkeypatch.setenv("RUN_ID", "run_abc") + + assert run_image_dir().as_posix() == "/data/outputs/run_abc/assets/images" + + +class _TinyQwenTokenizer: + unk_token_id = -1 + _specials = { + "<|im_start|>": 1, + "<|im_end|>": 2, + "<|endoftext|>": 3, + "": 4, + "": 5, + "": 6, + "": 7, + "": 8, + "<|vision_start|>": 9, + "<|vision_end|>": 10, + "<|image_pad|>": 11, + "<|video_pad|>": 12, + } + + def convert_tokens_to_ids(self, token): + return self._specials.get(token, self.unk_token_id) + + def encode(self, text, add_special_tokens=False): + return [100 + ord(ch) % 50 for ch in text] + + +def test_qwen3_vl_render_emits_image_descriptor_without_processor(tmp_path): + pytest.importorskip("PIL") + from PIL import Image + from renderers.mm_store import IMAGE_REF_PAYLOAD_KEY, IMAGE_REF_PAYLOAD_VALUE + from renderers.qwen3_vl import Qwen3VLRenderer + + image_path = tmp_path / "image.png" + Image.new("RGB", (32, 32), color=(255, 0, 0)).save(image_path) + renderer = Qwen3VLRenderer(_TinyQwenTokenizer()) + + rendered = renderer.render( + [ + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": image_path.as_uri()}}], + } + ], + add_generation_prompt=True, + ) + + item = rendered.multi_modal_data.mm_items["image"][0] + assert "pixel_values" not in item + assert item["image_grid_thw"] == [[1, 16, 16]] + assert item["raw_image_id"] == "image.png" + assert item[IMAGE_REF_PAYLOAD_KEY] == IMAGE_REF_PAYLOAD_VALUE + assert rendered.multi_modal_data.mm_placeholders["image"][0].length == 64 + + +def test_generate_materialize_all_image_refs_rehydrates_descriptor_slots(tmp_path, monkeypatch): + pytest.importorskip("PIL") + from PIL import Image + + from renderers.base import MultiModalData, ParsedResponse, PlaceholderRange + from renderers.mm_store import split_image_ref + from renderers.qwen3_vl import Qwen3VLRenderer + + class _RetryRenderer(Qwen3VLRenderer): + supports_tools = True + + def get_stop_token_ids(self): + return [99] + + def parse_response(self, completion_ids, *, tools=None): + return ParsedResponse(content="done") + + image_dir = tmp_path / "run_retry" / "assets" / "images" + image_dir.mkdir(parents=True) + image_path = image_dir / "image.png" + Image.new("RGB", (32, 32), color=(0, 255, 0)).save(image_path) + monkeypatch.setenv("VF_RENDERER_IMAGE_OFFLOAD_DIR", str(image_dir)) + monkeypatch.setenv("RUN_ID", "retry") + + mm_hash = hashlib.sha256(image_path.read_bytes()).hexdigest()[:32] + mm_data = MultiModalData( + mm_hashes={"image": [mm_hash]}, + mm_placeholders={"image": [PlaceholderRange(offset=5, length=64)]}, + mm_items={"image": [{"image_grid_thw": [[1, 16, 16]]}]}, + ) + renderer = _RetryRenderer(_TinyQwenTokenizer()) + client = _FakeClient() + + asyncio.run( + generate( + client=client, + renderer=renderer, + messages=[ + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": image_path.as_uri()}}], + } + ], + model="qwen3-vl", + prompt_ids=list(range(20)), + multi_modal_data=mm_data, + sampling_params={"max_tokens": 4}, + materialize_all_image_refs=True, + ) + ) + + ref = client.calls[0]["body"]["features"]["kwargs_data"]["image"][0] + run_id, _fingerprint, modality, parsed_hash, raw_image_id, grid = split_image_ref(ref) + assert (run_id, modality, parsed_hash, raw_image_id, grid) == ( + "retry", + "image", + mm_hash, + "image.png", + [1, 16, 16], + ) + + def test_generate_builds_request_body_and_parses_response(): client = _FakeClient() renderer = _FakeRenderer() @@ -281,47 +417,63 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): @pytest.mark.parametrize( - "model_id,renderer_class_path", + "renderer_class_path", [ - ("Qwen/Qwen3-VL-4B-Instruct", "renderers.qwen3_vl:Qwen3VLRenderer"), - ("Qwen/Qwen3.5-2B", "renderers.qwen35:Qwen35Renderer"), + "renderers.qwen3_vl:Qwen3VLRenderer", + "renderers.qwen35:Qwen35Renderer", ], ids=["qwen3_vl", "qwen35"], ) -def test_generate_serializes_multimodal_features_for_qwen_vl_family( - model_id, renderer_class_path +def test_generate_serializes_image_refs_for_qwen_vl_family( + tmp_path, monkeypatch, renderer_class_path ): """When the renderer emits ``MultiModalData``, ``generate`` translates it into vLLM's ``features`` payload (mm_hashes + mm_placeholders + - base64-encoded kwargs_data) and sticks it in the request body. Covers - every renderer routed through ``_build_qwen_vl_features``.""" + image-ref kwargs_data) and sticks it in the request body. Descriptor-only + images stay ``None`` so vLLM can resolve them from its cache.""" import importlib - pytest.importorskip("torch") - pytest.importorskip("vllm", reason="vllm needed for features serialization") - - import torch as _torch from renderers.base import ( MultiModalData, + ParsedResponse, PlaceholderRange, - load_tokenizer, + ) + from renderers.mm_store import ( + IMAGE_REF_PAYLOAD_KEY, + IMAGE_REF_PAYLOAD_VALUE, + image_layout_fingerprint, + split_image_ref, ) mod_name, cls_name = renderer_class_path.split(":") renderer_cls = getattr(importlib.import_module(mod_name), cls_name) - # Build a minimal real renderer so type dispatch in - # _build_mm_features hits the qwen branch. The tokenizer is only - # touched in __init__ to grab special-token ids; render() / etc. - # aren't called here because we pre-supply prompt_ids + mm_data. - tokenizer = load_tokenizer(model_id) - renderer = renderer_cls(tokenizer) + class _BareRenderer(renderer_cls): + supports_tools = True + + def get_stop_token_ids(self): + return [99] + + def parse_response(self, completion_ids, *, tools=None): + return ParsedResponse(content="done") + + renderer = _BareRenderer.__new__(_BareRenderer) + image_dir = tmp_path / "run_rawtest" / "assets" / "images" + image_dir.mkdir(parents=True) + (image_dir / "image.png").write_bytes(b"image-bytes") + monkeypatch.setenv("VF_RENDERER_IMAGE_OFFLOAD_DIR", str(image_dir)) + monkeypatch.setenv("RUN_ID", "rawtest") + fingerprint = image_layout_fingerprint( + family="qwen_vl", + patch_size=16, + merge_size=2, + temporal_patch_size=2, + min_pixels=65536, + max_pixels=16777216, + ) - # Two synthetic 1×2×2 images. Field factory expects pixel_values - # shape ``(sum_HW, embed_dim)`` and grid_thw shape ``(N, 3)``; the - # values themselves don't matter for the encoding round-trip. mm_data = MultiModalData( - mm_hashes={"image": ["aaa", "bbb"]}, + mm_hashes={"image": ["a" * 32, "b" * 32]}, mm_placeholders={ "image": [ PlaceholderRange(offset=5, length=1), @@ -331,19 +483,18 @@ def test_generate_serializes_multimodal_features_for_qwen_vl_family( mm_items={ "image": [ { - "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), - "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), - }, - { - "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), - "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), + "image_grid_thw": [[1, 2, 2]], + "raw_image_id": "image.png", + "image_layout_fingerprint": fingerprint, + IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, }, + {"image_grid_thw": [[1, 2, 2]]}, ], }, ) client = _FakeClient() - asyncio.run( + result = asyncio.run( generate( client=client, renderer=renderer, @@ -358,17 +509,21 @@ def test_generate_serializes_multimodal_features_for_qwen_vl_family( body = client.calls[0]["body"] assert "features" in body, "multimodal call should attach features" features = body["features"] - assert features["mm_hashes"] == {"image": ["aaa", "bbb"]} + assert features["mm_hashes"] == {"image": ["a" * 32, "b" * 32]} assert features["mm_placeholders"] == { "image": [{"offset": 5, "length": 1}, {"offset": 10, "length": 1}], } - assert "kwargs_data" in features - assert features["kwargs_data"] is not None - assert "image" in features["kwargs_data"] - assert len(features["kwargs_data"]["image"]) == 2 - # Items are base64 strings (encode_mm_kwargs_item output). - for item in features["kwargs_data"]["image"]: - assert isinstance(item, str) and len(item) > 0 + items = features["kwargs_data"]["image"] + assert items[1] is None + assert split_image_ref(items[0]) == ( + "rawtest", + fingerprint, + "image", + "a" * 32, + "image.png", + [1, 2, 2], + ) + assert "raw_image_id" not in result["multi_modal_data"].mm_items["image"][0] # --------------------------------------------------------------------------- From 4bc1766c024da0acdb8f0c2481631fe8b184d43c Mon Sep 17 00:00:00 2001 From: eligotts <78387377+eligotts@users.noreply.github.com> Date: Sat, 20 Jun 2026 07:41:16 +0000 Subject: [PATCH 2/2] Emit generic raw multimodal refs --- renderers/base.py | 6 +- renderers/client.py | 105 ++++++++---------- renderers/configs.py | 42 ++++++- renderers/kimi_k25.py | 249 +++++++++++++++++++++++++++++++++++++++--- renderers/mm_store.py | 151 ++++++++++++++++++------- renderers/qwen3_vl.py | 94 +++++++++++----- tests/test_client.py | 45 +++++--- 7 files changed, 522 insertions(+), 170 deletions(-) diff --git a/renderers/base.py b/renderers/base.py index 4dbb4f4..b6cc1ca 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -204,9 +204,9 @@ class MultiModalData: """Multimodal sidecar produced alongside the token stream. Renderer output is framework-agnostic: ``mm_items[modality][i]`` is a - plain descriptor dict (e.g. ``{"image_grid_thw": [[1, h, w]]}`` for - Qwen-VL images). Translation to engine-specific wire formats — vLLM image - refs, SGLang payloads, etc. — happens in the inference glue layer (see + plain raw descriptor envelope with a model-family key and an adapter-owned + payload. Translation to engine-specific wire formats — vLLM image refs, + SGLang payloads, etc. — happens in the inference glue layer (see ``renderers.client``). """ diff --git a/renderers/client.py b/renderers/client.py index de9df0b..df3dffb 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -172,12 +172,12 @@ async def generate( attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` / ``message_roles``) into the result without re-rendering. - For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes + For multimodal renderers, the call goes through ``renderer.render(...)`` to recover the ``multi_modal_data`` sidecar, then serializes it to vLLM's ``features`` schema (mm_hashes, - mm_placeholders, kwargs_data) before POSTing. Qwen-family image - ``kwargs_data`` slots are either ``None`` (cache lookup for a prior - image) or run image refs (new/current images that vLLM should process). + mm_placeholders, kwargs_data) before POSTing. Raw image ``kwargs_data`` + slots are either ``None`` (cache lookup for a prior image) or descriptor + refs (new/current images that vLLM should process). ``max_prompt_len`` controls the pre-flight overflow check. When the rendered prompt is strictly longer than the cap, the request is never @@ -364,7 +364,6 @@ def _descriptor_only_mm_data(mm_data: MultiModalData) -> MultiModalData: "pixel_values", "raw_uri", "raw_image_id", - "image_layout_fingerprint", IMAGE_REF_PAYLOAD_KEY, } } @@ -380,41 +379,18 @@ def _build_vllm_mm_features( """Serialize ``MultiModalData`` to vLLM's ``/inference/v1/generate`` features payload. vLLM's ``MultiModalFeatures`` carries three things: hashes (for cache - lookup), placeholder positions (so the engine knows where in the - token stream each item lives), and per-item payload selectors. For - Qwen images, payload selectors are ``None`` for cache-only prior images - or run image refs for images vLLM should process. - """ - from renderers.qwen3_vl import Qwen3VLRenderer - from renderers.qwen35 import Qwen35Renderer - - # Type dispatch only needs the renderer class. Pools expose - # ``renderer_cls`` as a snapshot attribute, so we don't have to check - # out a slot just to read ``type(r)``. - renderer_cls = ( - renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer) - ) - - if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)): - return _build_qwen_vl_image_ref_features(mm_data) - - raise NotImplementedError( - f"Multimodal serialization not implemented for {renderer_cls.__name__}. " - "Add a dispatch branch in renderers.client._build_vllm_mm_features." - ) - - -def _build_qwen_vl_image_ref_features(mm_data: MultiModalData) -> dict[str, Any]: - """vLLM features payload for Qwen-VL image refs. - - Returns ``None`` semantics live one level up — this helper assumes - the caller already verified ``mm_data`` is non-empty. + lookup), placeholder positions (so the engine knows where in the token + stream each item lives), and per-item payload selectors. Raw multimodal + descriptors use the common envelope emitted by renderers; family-specific + geometry stays inside the descriptor payload and is interpreted downstream + by prime-rl/vLLM adapters. """ from renderers.mm_store import ( IMAGE_REF_PAYLOAD_KEY, IMAGE_REF_PAYLOAD_VALUE, + RAW_MM_ITEM_KIND, current_run_id, - image_ref, + raw_mm_ref, ) out: dict[str, Any] = { @@ -423,45 +399,54 @@ def _build_qwen_vl_image_ref_features(mm_data: MultiModalData) -> dict[str, Any] "kwargs_data": {}, } - image_items = mm_data.mm_items.get("image") or [] - if image_items: - mm_hashes = list(mm_data.mm_hashes.get("image") or []) - placeholders = list(mm_data.mm_placeholders.get("image") or []) - if len(mm_hashes) != len(image_items) or len(placeholders) != len(image_items): + run_id = current_run_id() + for source_modality, items in mm_data.mm_items.items(): + if not items: + continue + mm_hashes = list(mm_data.mm_hashes.get(source_modality) or []) + placeholders = list(mm_data.mm_placeholders.get(source_modality) or []) + if len(mm_hashes) != len(items) or len(placeholders) != len(items): raise ValueError( - "Qwen-VL mm sidecar length mismatch: " - f"items={len(image_items)} hashes={len(mm_hashes)} placeholders={len(placeholders)}" + "Multimodal sidecar length mismatch: " + f"modality={source_modality} items={len(items)} " + f"hashes={len(mm_hashes)} placeholders={len(placeholders)}" ) - encoded: list[Any] = [None] * len(image_items) - run_id = current_run_id() - for idx, item in enumerate(image_items): + for idx, item in enumerate(items): + if item.get("kind") != RAW_MM_ITEM_KIND: + raise NotImplementedError( + "Multimodal serialization requires raw descriptor envelopes; " + f"got item keys {sorted(item)} for modality {source_modality!r}." + ) + feature_modality = item.get("vllm_modality") or source_modality + if not isinstance(feature_modality, str) or not feature_modality: + raise ValueError("raw multimodal item has invalid vllm_modality") + out["mm_hashes"].setdefault(feature_modality, []).append(mm_hashes[idx]) + out["mm_placeholders"].setdefault(feature_modality, []).append( + {"offset": placeholders[idx].offset, "length": placeholders[idx].length} + ) + out["kwargs_data"].setdefault(feature_modality, []).append(None) if item.get(IMAGE_REF_PAYLOAD_KEY) != IMAGE_REF_PAYLOAD_VALUE: continue raw_image_id = item.get("raw_image_id") - grid_thw = item.get("image_grid_thw") - fingerprint = item.get("image_layout_fingerprint") + family = item.get("family") + fingerprint = item.get("layout_fingerprint") if not isinstance(raw_image_id, str) or not raw_image_id: - raise ValueError("image-ref multimodal item is missing raw_image_id") - if grid_thw is None: - raise ValueError("image-ref multimodal item is missing image_grid_thw") + raise ValueError("raw multimodal item is missing raw_image_id") + if not isinstance(family, str) or not family: + raise ValueError("raw multimodal item is missing family") if not isinstance(fingerprint, str) or not fingerprint: - raise ValueError("image-ref multimodal item is missing image_layout_fingerprint") - encoded[idx] = image_ref( + raise ValueError("raw multimodal item is missing layout_fingerprint") + out["kwargs_data"][feature_modality][-1] = raw_mm_ref( run_id=run_id, + family=family, fingerprint=fingerprint, - modality="image", + modality=feature_modality, mm_hash=mm_hashes[idx], raw_image_id=raw_image_id, - grid_thw=grid_thw, + payload=item.get("payload") or {}, ) - out["kwargs_data"]["image"] = encoded - out["mm_hashes"]["image"] = mm_hashes - out["mm_placeholders"]["image"] = [ - {"offset": p.offset, "length": p.length} for p in placeholders - ] - if not any(item is not None for values in out["kwargs_data"].values() for item in values): out["kwargs_data"] = None diff --git a/renderers/configs.py b/renderers/configs.py index b07d97e..54ac342 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -31,6 +31,14 @@ QWEN_VL_IMAGE_MIN_PIXELS = 65536 QWEN_VL_IMAGE_MAX_PIXELS = 16777216 +KIMI_K25_IMAGE_PATCH_SIZE = 14 +KIMI_K25_IMAGE_MERGE_KERNEL_SIZE = 2 +KIMI_K25_IMAGE_IN_PATCH_LIMIT = 16384 +KIMI_K25_IMAGE_PATCH_LIMIT_ON_ONE_SIDE = 512 +KIMI_K25_IMAGE_FIXED_OUTPUT_TOKENS: int | None = None +KIMI_K25_IMAGE_MEAN = (0.5, 0.5, 0.5) +KIMI_K25_IMAGE_STD = (0.5, 0.5, 0.5) + class BaseRendererConfig(BaseConfig): """Shared fields and config for every renderer config variant. @@ -362,7 +370,39 @@ class KimiK25RendererConfig(BaseRendererConfig): image_cache_max: int = 256 """FIFO bound on Kimi's per-renderer image processor cache.""" - _internal_fields = frozenset({"image_cache_max"}) + image_patch_size: int = KIMI_K25_IMAGE_PATCH_SIZE + """Kimi MoonViT patch size used to compute raw image layout descriptors.""" + + image_merge_kernel_size: int = KIMI_K25_IMAGE_MERGE_KERNEL_SIZE + """Kimi spatial merge kernel used to compute output media-token layout.""" + + image_in_patch_limit: int = KIMI_K25_IMAGE_IN_PATCH_LIMIT + """Kimi NavIT input patch budget used by image resize layout math.""" + + image_patch_limit_on_one_side: int = KIMI_K25_IMAGE_PATCH_LIMIT_ON_ONE_SIDE + """Kimi per-side patch cap used by image resize layout math.""" + + image_fixed_output_tokens: int | None = KIMI_K25_IMAGE_FIXED_OUTPUT_TOKENS + """Optional fixed Kimi output token count. Current K2.5/K2.6 configs use ``None``.""" + + image_mean: tuple[float, float, float] = KIMI_K25_IMAGE_MEAN + """Kimi image normalization mean, included in processor fingerprints.""" + + image_std: tuple[float, float, float] = KIMI_K25_IMAGE_STD + """Kimi image normalization std, included in processor fingerprints.""" + + _internal_fields = frozenset( + { + "image_cache_max", + "image_patch_size", + "image_merge_kernel_size", + "image_in_patch_limit", + "image_patch_limit_on_one_side", + "image_fixed_output_tokens", + "image_mean", + "image_std", + } + ) class LagunaXS2RendererConfig(BaseRendererConfig): diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index bca4464..a9bbf4a 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -22,7 +22,9 @@ from __future__ import annotations import json +import math import re +from dataclasses import dataclass from typing import Any from transformers.tokenization_utils import PreTrainedTokenizer @@ -44,11 +46,16 @@ from renderers.configs import KimiK25RendererConfig from renderers.parsing import _reasoning_end_token_index, parse_kimi_k2_section from renderers.qwen3_vl import ( + _image_content_hash, + _image_dimensions, _image_hash, + _image_source, _is_image_part, _is_video_part, _load_pil_image, + _raw_uri_and_id, ) +from renderers.mm_store import image_layout_fingerprint, raw_mm_item # --------------------------------------------------------------------------- # Constants @@ -56,6 +63,9 @@ _DEFAULT_SYSTEM_PROMPT = "You are Kimi, an AI assistant created by Moonshot AI." +KIMI_K25_FAMILY = "kimi_k25" +KIMI_K25_VLLM_MODALITY = "vision_chunk" + # --------------------------------------------------------------------------- # TypeScript-style tool declaration # --------------------------------------------------------------------------- @@ -401,6 +411,218 @@ def _encode_tools_typescript(tools: list[ToolSpec]) -> str: return "# Tools\n\n## functions\nnamespace functions {\n" + functions_str + "\n}\n" +@dataclass(frozen=True) +class KimiImageLayoutConfig: + patch_size: int + merge_kernel_size: int + in_patch_limit: int + patch_limit_on_one_side: int + fixed_output_tokens: int | None + image_mean: tuple[float, ...] + image_std: tuple[float, ...] + + +@dataclass(frozen=True) +class KimiImageLayoutDescriptor: + mm_hash: str + grid_thws: list[list[int]] + num_media_tokens: int + fingerprint: str + raw_uri: str | None = None + raw_image_id: str | None = None + + +def kimi_image_layout_config_for_renderer(renderer: Any) -> KimiImageLayoutConfig: + config = renderer.config + values = { + "patch_size": getattr(config, "image_patch_size", None), + "merge_kernel_size": getattr(config, "image_merge_kernel_size", None), + "in_patch_limit": getattr(config, "image_in_patch_limit", None), + "patch_limit_on_one_side": getattr(config, "image_patch_limit_on_one_side", None), + "fixed_output_tokens": getattr(config, "image_fixed_output_tokens", None), + "image_mean": getattr(config, "image_mean", None), + "image_std": getattr(config, "image_std", None), + } + missing = [ + name + for name, value in values.items() + if value is None and name != "fixed_output_tokens" + ] + if missing: + raise RuntimeError( + "Kimi image layout must be declared on the renderer config; missing " + + ", ".join(missing) + ) + return KimiImageLayoutConfig( + patch_size=int(values["patch_size"]), + merge_kernel_size=int(values["merge_kernel_size"]), + in_patch_limit=int(values["in_patch_limit"]), + patch_limit_on_one_side=int(values["patch_limit_on_one_side"]), + fixed_output_tokens=( + None if values["fixed_output_tokens"] is None else int(values["fixed_output_tokens"]) + ), + image_mean=tuple(float(v) for v in values["image_mean"]), + image_std=tuple(float(v) for v in values["image_std"]), + ) + + +def _ceil_to_factor(value: int, factor: int) -> int: + return max(factor, math.ceil(value / factor) * factor) + + +def _kimi_resize_config(width: int, height: int, layout: KimiImageLayoutConfig) -> tuple[int, int, int]: + """Kimi MoonViT/NavIT image resize layout without materializing pixels.""" + if height <= 0 or width <= 0: + raise ValueError(f"image dimensions must be positive, got {height}x{width}") + patch_size = layout.patch_size + patch_limit_pixels = layout.patch_limit_on_one_side * patch_size + s1 = math.sqrt( + layout.in_patch_limit + / ( + max(1.0, width // patch_size) + * max(1.0, height // patch_size) + ) + ) + s2 = patch_limit_pixels / width + s3 = patch_limit_pixels / height + scale = min(1.0, s1, s2, s3) + resized_w = min(max(1, int(width * scale)), patch_limit_pixels) + resized_h = min(max(1, int(height * scale)), patch_limit_pixels) + + factor = layout.merge_kernel_size * patch_size + padded_w = _ceil_to_factor(resized_w, factor) + padded_h = _ceil_to_factor(resized_h, factor) + if layout.fixed_output_tokens is not None: + num_tokens = layout.fixed_output_tokens + else: + num_tokens = (padded_h // factor) * (padded_w // factor) + return padded_w, padded_h, int(num_tokens) + + +def describe_kimi_image_layout(renderer: Any, part: dict[str, Any]) -> KimiImageLayoutDescriptor: + source = _image_source(part) + height, width = _image_dimensions(source) + layout = kimi_image_layout_config_for_renderer(renderer) + padded_w, padded_h, num_media_tokens = _kimi_resize_config(width, height, layout) + grid_thws = [[1, padded_h // layout.patch_size, padded_w // layout.patch_size]] + fingerprint = image_layout_fingerprint( + family=KIMI_K25_FAMILY, + patch_size=layout.patch_size, + merge_kernel_size=layout.merge_kernel_size, + in_patch_limit=layout.in_patch_limit, + patch_limit_on_one_side=layout.patch_limit_on_one_side, + fixed_output_tokens=layout.fixed_output_tokens, + image_mean=list(layout.image_mean), + image_std=list(layout.image_std), + ) + raw_uri, raw_image_id = _raw_uri_and_id(source) + return KimiImageLayoutDescriptor( + mm_hash=_image_content_hash(source), + grid_thws=grid_thws, + num_media_tokens=num_media_tokens, + fingerprint=fingerprint, + raw_uri=raw_uri, + raw_image_id=raw_image_id, + ) + + +def kimi_image_item_for_render(renderer: Any, part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: + desc = describe_kimi_image_layout(renderer, part) + item = raw_mm_item( + modality="image", + family=KIMI_K25_FAMILY, + layout_fingerprint=desc.fingerprint, + payload={ + "grid_thws": desc.grid_thws, + "num_media_tokens": desc.num_media_tokens, + }, + raw_uri=desc.raw_uri, + raw_image_id=desc.raw_image_id, + vllm_modality=KIMI_K25_VLLM_MODALITY, + ) + return 1, desc.mm_hash, item + + +def _kimi_grid_from_item(item: dict[str, Any]) -> Any: + payload = item.get("payload") + if isinstance(payload, dict) and payload.get("grid_thws") is not None: + return payload["grid_thws"] + return item.get("grid_thws") + + +def _kimi_grids_equal(a: Any, b: Any) -> bool: + if a is None or b is None: + return False + al = a.tolist() if hasattr(a, "tolist") else a + bl = b.tolist() if hasattr(b, "tolist") else b + return al == bl + + +def materialize_kimi_image_refs(renderer: Any, mm_data: MultiModalData, messages: list[Message]) -> MultiModalData: + """Attach run-image refs to every Kimi image descriptor that can be found.""" + from dataclasses import replace + + image_items = mm_data.mm_items.get("image") or [] + if not image_items: + return mm_data + hashes = mm_data.mm_hashes.get("image") or [] + if len(hashes) != len(image_items): + raise ValueError( + "materialize_kimi_image_refs: mm_hashes/mm_items length mismatch " + f"({len(hashes)} vs {len(image_items)})" + ) + + missing = set(hashes) + resolved: dict[str, KimiImageLayoutDescriptor] = {} + for msg in messages or []: + content = msg.get("content") if isinstance(msg, dict) else None + if not isinstance(content, list): + continue + for part in content: + if not missing: + break + if not (isinstance(part, dict) and _is_image_part(part)): + continue + desc = describe_kimi_image_layout(renderer, part) + if desc.mm_hash in missing: + resolved[desc.mm_hash] = desc + missing.discard(desc.mm_hash) + if missing: + raise ValueError( + f"materialize_kimi_image_refs: {len(missing)} image hash(es) not found in messages" + ) + + new_image_items: list[dict[str, Any]] = [] + for i, item in enumerate(image_items): + desc = resolved[hashes[i]] + if desc.raw_uri is None or desc.raw_image_id is None: + raise ValueError("materialize_kimi_image_refs requires file-backed image URLs") + item_grid = _kimi_grid_from_item(item) + if item_grid is not None and not _kimi_grids_equal(desc.grid_thws, item_grid): + raise ValueError( + "materialize_kimi_image_refs: reconstructed grid_thws " + f"{desc.grid_thws!r} != descriptor {item_grid!r}" + ) + new_image_items.append( + raw_mm_item( + modality="image", + family=KIMI_K25_FAMILY, + layout_fingerprint=desc.fingerprint, + payload={ + "grid_thws": item_grid if item_grid is not None else desc.grid_thws, + "num_media_tokens": desc.num_media_tokens, + }, + raw_uri=desc.raw_uri, + raw_image_id=desc.raw_image_id, + vllm_modality=KIMI_K25_VLLM_MODALITY, + ) + ) + + new_items = dict(mm_data.mm_items) + new_items["image"] = new_image_items + return replace(mm_data, mm_items=new_items) + + # --------------------------------------------------------------------------- # Kimi K2.5 response parsing (mirrors K2 format, same token structure) # --------------------------------------------------------------------------- @@ -647,6 +869,11 @@ def mm_token_type_id_map(self) -> dict[int, int]: internally from ``pixel_values``.""" return {self._media_pad: 1} + def materialize_image_refs( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + return materialize_kimi_image_refs(self, mm_data, messages) + def _get_processor(self): if self._processor is not None: return self._processor @@ -815,7 +1042,7 @@ def emit_image( ``<|media_content|>``, ``<|media_end|>``, the trailing ``\\n``) are template-injected scaffold. """ - _, out, _num_patches, h = self._process_image(part) + _placeholder_len, h, mm_item = kimi_image_item_for_render(self, part) emit_special( self._media_begin, msg_idx, is_sampled=is_sampled, is_content=False ) @@ -838,16 +1065,7 @@ def emit_image( mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=1) ) - # ``grid_thws`` (Kimi) is the per-image equivalent of Qwen-VL's - # ``image_grid_thw``. Ship under Kimi's native key so the - # orchestrator's generic ``torch.cat``-based packer routes it - # directly into the model's forward kwargs. - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "grid_thws": out["grid_thws"], - } - ) + mm_items.setdefault("image", []).append(mm_item) # ── Tool declaration prefix (comes first) ── # K2.5/K2.6's tokenizer auto-computes ``tools_ts_str`` and threads @@ -1110,7 +1328,7 @@ def emit_image( is_sampled: bool = False, is_content: bool = False, ) -> None: - _, out, _num_patches, h = self._process_image(part) + _placeholder_len, h, mm_item = kimi_image_item_for_render(self, part) emit_special(self._media_begin, msg_idx) emit_text("image", msg_idx) emit_special(self._media_content, msg_idx) @@ -1124,12 +1342,7 @@ def emit_image( new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=1) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "grid_thws": out["grid_thws"], - } - ) + new_items.setdefault("image", []).append(mm_item) # Bridge handles user/system/tool only (reject_assistant_in_extension # blocks assistants), so no hist/suffix split needed. diff --git a/renderers/mm_store.py b/renderers/mm_store.py index 8cabd14..a27ea8a 100644 --- a/renderers/mm_store.py +++ b/renderers/mm_store.py @@ -10,9 +10,11 @@ import base64 import hashlib +import json import os import re import threading +from dataclasses import dataclass from pathlib import Path RUN_OUTPUT_ROOT = Path("/data/outputs") @@ -22,15 +24,19 @@ RUN_ID_ENV = "RUN_ID" IMAGE_ASSET_SUBDIR = Path("assets/images") -IMAGE_REF_PREFIX = "mmraw:v1" +IMAGE_REF_PREFIX = "mmraw:v2" IMAGE_REF_PAYLOAD_KEY = "_prime_rl_image_ref" IMAGE_REF_PAYLOAD_VALUE = "raw_image" +RAW_MM_ITEM_KIND = "prime_raw_mm_item" +RAW_MM_ITEM_VERSION = 1 _SAFE_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_FAMILY_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_MODALITY_RE = re.compile(r"^[A-Za-z0-9_.-]+$") _SAFE_FINGERPRINT_RE = re.compile(r"^[a-f0-9]{16,64}$") _SAFE_MM_HASH_RE = re.compile(r"^[a-f0-9]{16,128}$") _SAFE_IMAGE_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") -_SAFE_GRID_THW_RE = re.compile(r"^[0-9]+x[0-9]+x[0-9]+$") +_SAFE_REF_PAYLOAD_RE = re.compile(r"^[A-Za-z0-9_-]*$") _MEDIA_TYPE_EXT = {"jpeg": ".jpg", "jpg": ".jpg", "png": ".png", "webp": ".webp", "gif": ".gif"} @@ -156,67 +162,132 @@ def raw_image_path(*, run_id: str, raw_image_id: str) -> Path: return path -def image_layout_fingerprint( - *, - family: str, - patch_size: int, - merge_size: int, - temporal_patch_size: int, - min_pixels: int, - max_pixels: int, -) -> str: - raw = ( - f"image-layout:v1:{family}:{int(patch_size)}:{int(merge_size)}:" - f"{int(temporal_patch_size)}:{int(min_pixels)}:{int(max_pixels)}" - ).encode("utf-8") - return hashlib.sha256(raw).hexdigest()[:32] +def _json_fingerprint_value(value: object) -> str: + return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str) -def _grid_to_ref(grid_thw: object) -> str: - data = grid_thw.tolist() if hasattr(grid_thw, "tolist") else grid_thw - if isinstance(data, list) and data and isinstance(data[0], list): - data = data[0] - if not isinstance(data, (list, tuple)) or len(data) != 3: - raise ValueError(f"Invalid image grid_thw for image ref: {grid_thw!r}") - return "x".join(str(int(v)) for v in data) - - -def _grid_from_ref(value: str) -> list[int]: - if not _SAFE_GRID_THW_RE.fullmatch(value): - raise ValueError(f"Invalid image grid_thw ref segment: {value!r}") - return [int(v) for v in value.split("x")] +def image_layout_fingerprint(*, family: str, **values: object) -> str: + """Stable adapter-owned fingerprint for raw multimodal layout contracts.""" + if not _SAFE_FAMILY_RE.fullmatch(family): + raise ValueError(f"Invalid multimodal family: {family!r}") + encoded_values = ":".join(f"{key}={_json_fingerprint_value(values[key])}" for key in sorted(values)) + raw = f"image-layout:v1:{family}:{encoded_values}".encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:32] -def image_ref( +def raw_mm_item( + *, + modality: str, + family: str, + layout_fingerprint: str, + payload: dict[str, object], + raw_uri: str | None = None, + raw_image_id: str | None = None, + vllm_modality: str | None = None, +) -> dict[str, object]: + """Build the JSON-safe raw multimodal descriptor envelope. + + ``payload`` is intentionally adapter-owned. Shared consumers may route by + ``family`` and validate the common envelope, but must not inspect adapter + payload keys. + """ + if not _SAFE_FAMILY_RE.fullmatch(family): + raise ValueError(f"Invalid multimodal family: {family!r}") + if not _SAFE_MODALITY_RE.fullmatch(modality): + raise ValueError(f"Invalid raw multimodal modality: {modality!r}") + if not _SAFE_FINGERPRINT_RE.fullmatch(layout_fingerprint): + raise ValueError(f"Invalid image layout fingerprint: {layout_fingerprint!r}") + out: dict[str, object] = { + "kind": RAW_MM_ITEM_KIND, + "version": RAW_MM_ITEM_VERSION, + "modality": modality, + "family": family, + "layout_fingerprint": layout_fingerprint, + "payload": payload, + } + if vllm_modality is not None: + out["vllm_modality"] = vllm_modality + if raw_uri is not None and raw_image_id is not None: + out.update( + { + "raw_uri": raw_uri, + "raw_image_id": raw_image_id, + IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, + } + ) + return out + + +@dataclass(frozen=True) +class RawMMRef: + run_id: str + family: str + fingerprint: str + modality: str + mm_hash: str + raw_image_id: str + payload: dict[str, object] + + +def raw_mm_ref( *, run_id: str, + family: str, fingerprint: str, modality: str, mm_hash: str, raw_image_id: str, - grid_thw: object, + payload: dict[str, object] | None = None, ) -> str: + """Generic raw multimodal asset ref. + + Adapter-owned details stay in the descriptor payload so refs can serve + future families without baking shape names into the wire id. + """ run_id = normalize_run_id(run_id) + if not _SAFE_FAMILY_RE.fullmatch(family): + raise ValueError(f"Invalid multimodal family: {family!r}") if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): raise ValueError(f"Invalid image layout fingerprint: {fingerprint!r}") - if modality != "image": - raise ValueError(f"Unsupported image ref modality: {modality!r}") + if not _SAFE_MODALITY_RE.fullmatch(modality): + raise ValueError(f"Invalid raw multimodal modality: {modality!r}") if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): raise ValueError(f"Invalid image hash: {mm_hash!r}") raw_image_path(run_id=run_id, raw_image_id=raw_image_id) - return f"{IMAGE_REF_PREFIX}:{run_id}:{fingerprint}:{modality}:{mm_hash}:{raw_image_id}:{_grid_to_ref(grid_thw)}" + encoded_payload = base64.urlsafe_b64encode( + json.dumps(payload or {}, sort_keys=True, separators=(",", ":")).encode("utf-8") + ).decode("ascii").rstrip("=") + return ( + f"{IMAGE_REF_PREFIX}:{run_id}:{family}:{fingerprint}:" + f"{modality}:{mm_hash}:{raw_image_id}:{encoded_payload}" + ) -def split_image_ref(ref: str) -> tuple[str, str, str, str, str, list[int]]: +def split_raw_mm_ref(ref: str) -> RawMMRef: parts = ref.split(":") - if parts[:2] != ["mmraw", "v1"] or len(parts) != 8: - raise ValueError(f"Invalid image ref shape: {ref!r}") - return normalize_run_id(parts[2]), parts[3], parts[4], parts[5], parts[6], _grid_from_ref(parts[7]) + if parts[:2] != ["mmraw", "v2"] or len(parts) != 9: + raise ValueError(f"Invalid raw multimodal ref shape: {ref!r}") + run_id, family, fingerprint, modality, mm_hash, raw_image_id, encoded_payload = parts[2:] + if not _SAFE_REF_PAYLOAD_RE.fullmatch(encoded_payload): + raise ValueError("Invalid raw multimodal ref payload segment") + padded = encoded_payload + "=" * (-len(encoded_payload) % 4) + payload = json.loads(base64.urlsafe_b64decode(padded.encode("ascii")).decode("utf-8")) + if not isinstance(payload, dict): + raise ValueError("Raw multimodal ref payload must decode to a dict") + return RawMMRef( + run_id=normalize_run_id(run_id), + family=family, + fingerprint=fingerprint, + modality=modality, + mm_hash=mm_hash, + raw_image_id=raw_image_id, + payload=payload, + ) # Backwards-compatible names for consumers that already speak the mmraw wire format. MMRAW_PREFIX = IMAGE_REF_PREFIX MM_RAW_PAYLOAD_KEY = IMAGE_REF_PAYLOAD_KEY MM_RAW_PAYLOAD_VALUE = IMAGE_REF_PAYLOAD_VALUE -mmraw_ref = image_ref -split_mmraw_ref = split_image_ref +mmraw_ref = raw_mm_ref +split_mmraw_ref = split_raw_mm_ref diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 9b865d0..1cde900 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -52,6 +52,7 @@ IMAGE_REF_PAYLOAD_KEY, IMAGE_REF_PAYLOAD_VALUE, image_layout_fingerprint, + raw_mm_item, ) from renderers.parsing import parse_qwen3 @@ -328,16 +329,14 @@ def describe_qwen_image_layout(renderer: Any, part: dict[str, Any]) -> QwenImage def qwen_image_item_for_render(renderer: Any, part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: desc = describe_qwen_image_layout(renderer, part) - item: dict[str, Any] = {"image_grid_thw": desc.image_grid_thw} - if desc.raw_uri is not None and desc.raw_image_id is not None: - item.update( - { - "raw_uri": desc.raw_uri, - "raw_image_id": desc.raw_image_id, - "image_layout_fingerprint": desc.fingerprint, - IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, - } - ) + item = raw_mm_item( + modality="image", + family="qwen_vl", + layout_fingerprint=desc.fingerprint, + payload={"image_grid_thw": desc.image_grid_thw}, + raw_uri=desc.raw_uri, + raw_image_id=desc.raw_image_id, + ) return desc.num_image_tokens, desc.mm_hash, item @@ -359,6 +358,54 @@ def _grids_equal(a: Any, b: Any) -> bool: return al == bl +def _qwen_grid_from_item(item: dict[str, Any]) -> Any: + payload = item.get("payload") + if isinstance(payload, dict) and payload.get("image_grid_thw") is not None: + return payload["image_grid_thw"] + return item.get("image_grid_thw") + + +def _qwen_item_with_grid_and_ref( + item: dict[str, Any], + *, + image_grid_thw: Any, + fingerprint: str, + raw_uri: str, + raw_image_id: str, +) -> dict[str, Any]: + new_item = { + k: v + for k, v in item.items() + if k + not in { + "raw_uri", + "raw_image_id", + "image_layout_fingerprint", + IMAGE_REF_PAYLOAD_KEY, + } + } + if new_item.get("family") == "qwen_vl" and isinstance(new_item.get("payload"), dict): + payload = dict(new_item["payload"]) + payload["image_grid_thw"] = image_grid_thw + new_item["payload"] = payload + new_item["layout_fingerprint"] = fingerprint + else: + new_item = raw_mm_item( + modality="image", + family="qwen_vl", + layout_fingerprint=fingerprint, + payload={"image_grid_thw": image_grid_thw}, + ) + new_item.update( + { + "raw_uri": raw_uri, + "raw_image_id": raw_image_id, + IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, + } + ) + return new_item + + def materialize_image_refs(renderer: Any, mm_data: MultiModalData, messages: list[Message]) -> MultiModalData: """Attach run-image refs to every Qwen image descriptor that can be found.""" image_items = mm_data.mm_items.get("image") or [] @@ -390,31 +437,18 @@ def materialize_image_refs(renderer: Any, mm_data: MultiModalData, messages: lis desc = resolved[hashes[i]] if desc.raw_uri is None or desc.raw_image_id is None: raise ValueError("materialize_image_refs requires file-backed image URLs") - item_grid = item.get("image_grid_thw") + item_grid = _qwen_grid_from_item(item) if item_grid is not None and not _grids_equal(desc.image_grid_thw, item_grid): raise ValueError( "materialize_image_refs: reconstructed image_grid_thw " f"{desc.image_grid_thw!r} != descriptor {item_grid!r}" ) - new_item = { - k: v - for k, v in item.items() - if k - not in { - "raw_uri", - "raw_image_id", - "image_layout_fingerprint", - IMAGE_REF_PAYLOAD_KEY, - } - } - new_item.update( - { - "image_grid_thw": item_grid if item_grid is not None else desc.image_grid_thw, - "raw_uri": desc.raw_uri, - "raw_image_id": desc.raw_image_id, - "image_layout_fingerprint": desc.fingerprint, - IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, - } + new_item = _qwen_item_with_grid_and_ref( + item, + image_grid_thw=item_grid if item_grid is not None else desc.image_grid_thw, + fingerprint=desc.fingerprint, + raw_uri=desc.raw_uri, + raw_image_id=desc.raw_image_id, ) new_image_items.append(new_item) diff --git a/tests/test_client.py b/tests/test_client.py index c1c7aaf..ac0ec16 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -169,7 +169,8 @@ def test_qwen3_vl_render_emits_image_descriptor_without_processor(tmp_path): item = rendered.multi_modal_data.mm_items["image"][0] assert "pixel_values" not in item - assert item["image_grid_thw"] == [[1, 16, 16]] + assert item["family"] == "qwen_vl" + assert item["payload"]["image_grid_thw"] == [[1, 16, 16]] assert item["raw_image_id"] == "image.png" assert item[IMAGE_REF_PAYLOAD_KEY] == IMAGE_REF_PAYLOAD_VALUE assert rendered.multi_modal_data.mm_placeholders["image"][0].length == 64 @@ -180,7 +181,7 @@ def test_generate_materialize_all_image_refs_rehydrates_descriptor_slots(tmp_pat from PIL import Image from renderers.base import MultiModalData, ParsedResponse, PlaceholderRange - from renderers.mm_store import split_image_ref + from renderers.mm_store import split_raw_mm_ref from renderers.qwen3_vl import Qwen3VLRenderer class _RetryRenderer(Qwen3VLRenderer): @@ -226,14 +227,14 @@ def parse_response(self, completion_ids, *, tools=None): ) ) - ref = client.calls[0]["body"]["features"]["kwargs_data"]["image"][0] - run_id, _fingerprint, modality, parsed_hash, raw_image_id, grid = split_image_ref(ref) - assert (run_id, modality, parsed_hash, raw_image_id, grid) == ( + ref_item = client.calls[0]["body"]["features"]["kwargs_data"]["image"][0] + ref = split_raw_mm_ref(ref_item) + assert ref.payload["image_grid_thw"] == [[1, 16, 16]] + assert (ref.run_id, ref.modality, ref.mm_hash, ref.raw_image_id) == ( "retry", "image", mm_hash, "image.png", - [1, 16, 16], ) @@ -439,10 +440,9 @@ def test_generate_serializes_image_refs_for_qwen_vl_family( PlaceholderRange, ) from renderers.mm_store import ( - IMAGE_REF_PAYLOAD_KEY, - IMAGE_REF_PAYLOAD_VALUE, image_layout_fingerprint, - split_image_ref, + raw_mm_item, + split_raw_mm_ref, ) mod_name, cls_name = renderer_class_path.split(":") @@ -482,13 +482,20 @@ def parse_response(self, completion_ids, *, tools=None): }, mm_items={ "image": [ - { - "image_grid_thw": [[1, 2, 2]], - "raw_image_id": "image.png", - "image_layout_fingerprint": fingerprint, - IMAGE_REF_PAYLOAD_KEY: IMAGE_REF_PAYLOAD_VALUE, - }, - {"image_grid_thw": [[1, 2, 2]]}, + raw_mm_item( + modality="image", + family="qwen_vl", + layout_fingerprint=fingerprint, + payload={"image_grid_thw": [[1, 2, 2]]}, + raw_uri=(image_dir / "image.png").as_uri(), + raw_image_id="image.png", + ), + raw_mm_item( + modality="image", + family="qwen_vl", + layout_fingerprint=fingerprint, + payload={"image_grid_thw": [[1, 2, 2]]}, + ), ], }, ) @@ -515,13 +522,15 @@ def parse_response(self, completion_ids, *, tools=None): } items = features["kwargs_data"]["image"] assert items[1] is None - assert split_image_ref(items[0]) == ( + ref = split_raw_mm_ref(items[0]) + assert ref.payload == {"image_grid_thw": [[1, 2, 2]]} + assert (ref.run_id, ref.family, ref.fingerprint, ref.modality, ref.mm_hash, ref.raw_image_id) == ( "rawtest", + "qwen_vl", fingerprint, "image", "a" * 32, "image.png", - [1, 2, 2], ) assert "raw_image_id" not in result["multi_modal_data"].mm_items["image"][0]