Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,10 @@ class MultiModalData:
"""Multimodal sidecar produced alongside the token stream.

Renderer output is framework-agnostic: ``mm_items[modality][i]`` is a
plain ``dict`` mirroring the per-item output of a HuggingFace processor
(e.g. ``{"pixel_values": Tensor, "image_grid_thw": Tensor}`` for
Qwen3-VL images). Translation to engine-specific wire formats — vLLM's
``MultiModalKwargsItem``, SGLang's payload, etc. — happens in the
inference glue layer (see ``renderers.client``).
plain raw descriptor envelope with a model-family key and an adapter-owned
payload. Translation to engine-specific wire formats — vLLM image refs,
SGLang payloads, etc. — happens in the inference glue layer (see
``renderers.client``).
"""

mm_hashes: dict[str, list[str]] = field(default_factory=dict)
Expand Down Expand Up @@ -761,8 +760,8 @@ def bridge_to_next_turn(
Text-only renderers return :class:`RenderedTokens` with
``multi_modal_data=None``. Multimodal renderers (see
:class:`MultimodalRenderer`) populate ``multi_modal_data`` so
the caller can recover placeholder offsets + per-item processed
tensors for the new full prompt; they also accept a
the caller can recover placeholder offsets + per-item image
descriptors for the new full prompt; they also accept a
``previous_multi_modal_data`` kwarg via the
:class:`MultimodalRenderer` Protocol override.

Expand Down Expand Up @@ -818,8 +817,8 @@ def bridge_to_next_turn(
the combined token sequence and silently falls back to
hash-cache lookup (or errors)
- returns :class:`RenderedTokens` (not ``list[int]``) so the
caller can recover the placeholder offsets + per-item
processed tensors for the new full prompt
caller can recover the placeholder offsets + per-item image
descriptors for the new full prompt
"""
...

Expand Down Expand Up @@ -967,6 +966,10 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No
with self.checkout() as r:
return r.bridge_to_next_turn(*args, **kwargs)

def materialize_image_refs(self, *args: Any, **kwargs: Any) -> "MultiModalData":
with self.checkout() as r:
return r.materialize_image_refs(*args, **kwargs)

# ``mm_token_type_id_map`` (the MultimodalRenderer protocol attribute)
# is set in ``__init__`` only for pools wrapping multimodal renderers;
# see the comment there for why this isn't a class-level property.
Expand Down
207 changes: 107 additions & 100 deletions renderers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
from collections.abc import Mapping
from dataclasses import replace
from typing import Any, cast

import httpx
Expand Down Expand Up @@ -156,6 +157,7 @@ async def generate(
priority: int | None = None,
extra_headers: dict[str, str] | None = None,
max_prompt_len: int | None = None,
materialize_all_image_refs: bool = False,
) -> dict[str, Any]:
"""Tokenize messages, call vLLM /inference/v1/generate, parse the response.

Expand All @@ -170,11 +172,12 @@ async def generate(
attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` /
``message_roles``) into the result without re-rendering.

For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes
For multimodal renderers, the call goes
through ``renderer.render(...)`` to recover the ``multi_modal_data``
sidecar, then serializes it to vLLM's ``features`` schema (mm_hashes,
mm_placeholders, kwargs_data) before POSTing. The serializer imports
``vllm.*`` lazily so text-only consumers never pay for the import.
mm_placeholders, kwargs_data) before POSTing. Raw image ``kwargs_data``
slots are either ``None`` (cache lookup for a prior image) or descriptor
refs (new/current images that vLLM should process).

``max_prompt_len`` controls the pre-flight overflow check. When the
rendered prompt is strictly longer than the cap, the request is never
Expand Down Expand Up @@ -248,11 +251,23 @@ def _prepare():
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)

def _features_and_descriptor_mm() -> tuple[dict[str, Any] | None, MultiModalData | None]:
if mm_data is None or mm_data.is_empty():
return None, mm_data
build_mm = mm_data
if materialize_all_image_refs:
materialize = getattr(renderer, "materialize_image_refs", None)
if materialize is None:
raise NotImplementedError(
f"{type(renderer).__name__} cannot materialize image refs for retry."
)
build_mm = materialize(mm_data, messages)
return _build_vllm_mm_features(renderer, build_mm), _descriptor_only_mm_data(mm_data)

features, out_mm_data = await _maybe_offload(renderer, _features_and_descriptor_mm)
if prompt_attr is not None and getattr(prompt_attr, "multi_modal_data", None) is not None:
prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data)
if features is not None:
body["features"] = features
if cache_salt is not None:
Expand Down Expand Up @@ -322,7 +337,7 @@ def _prepare():
# The mm sidecar consumed on the request side, surfaced back so
# callers can persist it on the trajectory step for downstream
# multi-turn bridging and training-sample construction.
"multi_modal_data": mm_data,
"multi_modal_data": out_mm_data,
# The renderer's per-token attribution for the prompt — either
# the RenderedTokens computed here via renderer.render(...) or
# the one threaded in by the caller alongside prompt_ids (the
Expand All @@ -334,113 +349,105 @@ def _prepare():
}


def _build_mm_features(
def _descriptor_only_mm_data(mm_data: MultiModalData) -> MultiModalData:
"""Drop one-request image-ref fields before callers persist mm_data."""
from renderers.mm_store import IMAGE_REF_PAYLOAD_KEY

new_items: dict[str, list[dict[str, Any]]] = {}
for modality, items in mm_data.mm_items.items():
new_items[modality] = [
{
key: value
for key, value in item.items()
if key
not in {
"pixel_values",
"raw_uri",
"raw_image_id",
IMAGE_REF_PAYLOAD_KEY,
}
}
for item in items
]
return replace(mm_data, mm_items=new_items)


def _build_vllm_mm_features(
renderer: Renderer | RendererPool,
mm_data: MultiModalData,
) -> dict[str, Any] | None:
"""Serialize ``MultiModalData`` to vLLM's ``/inference/v1/generate`` features payload.

vLLM's ``MultiModalFeatures`` carries three things: hashes (for cache
lookup), placeholder positions (so the engine knows where in the
token stream each item lives), and per-item ``MultiModalKwargsItem``
base64-encoded. The encoding requires vLLM-side type info — what
fields belong to each modality, how they batch — and is currently
model-family specific. For now we dispatch on the renderer class;
extend the dispatch table as more multimodal renderers land.

NOTE — future engine pluggability: this encoder is vLLM 0.20-specific
(uses ``vllm.multimodal.inputs.MultiModalKwargsItems``,
``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and
``_create_qwen2vl_field_factory``). When a second inference engine
arrives (SGLang, MAX, ...) the renderer client should be parameterized
on engine: either (a) move the encoder onto the renderer as
``encode_mm_for_<engine>(mm_data)`` methods, or (b) accept an
``Encoder`` strategy at the ``generate(...)`` call site. The data type
(``MultiModalData``) is already framework-agnostic and does not need
to change. Don't pre-build the abstraction with one engine in tree.
lookup), placeholder positions (so the engine knows where in the token
stream each item lives), and per-item payload selectors. Raw multimodal
descriptors use the common envelope emitted by renderers; family-specific
geometry stays inside the descriptor payload and is interpreted downstream
by prime-rl/vLLM adapters.
"""
from renderers.qwen3_vl import Qwen3VLRenderer
from renderers.qwen35 import Qwen35Renderer

# Type dispatch only needs the renderer class. Pools expose
# ``renderer_cls`` as a snapshot attribute, so we don't have to check
# out a slot just to read ``type(r)``.
renderer_cls = (
renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer)
)

# Qwen3-VL and Qwen3.5 both ship ``pixel_values`` + ``image_grid_thw``
# via the shared Qwen2-VL field factory. ``spatial_merge_size=2`` is
# the family default and matches every Qwen-VL processor in tree.
if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)):
return _build_qwen_vl_features(mm_data, spatial_merge_size=2)

raise NotImplementedError(
f"Multimodal serialization not implemented for {renderer_cls.__name__}. "
"Add a dispatch branch in renderers.client._build_mm_features."
from renderers.mm_store import (
IMAGE_REF_PAYLOAD_KEY,
IMAGE_REF_PAYLOAD_VALUE,
RAW_MM_ITEM_KIND,
current_run_id,
raw_mm_ref,
)


def _build_qwen_vl_features(
mm_data: MultiModalData, *, spatial_merge_size: int
) -> dict[str, Any]:
"""vLLM features payload for the Qwen-VL family (Qwen2-VL / Qwen3-VL).

Stacks per-image processor outputs back into a batched ``BatchFeature``,
runs the Qwen2-VL field factory (shared across the family), wraps as
``MultiModalKwargsItems``, base64-encodes each item, and assembles a
JSON-serializable dict matching vLLM's ``MultiModalFeatures`` schema.

Returns ``None`` semantics live one level up — this helper assumes
the caller already verified ``mm_data`` is non-empty.
"""
try:
import torch
from transformers.feature_extraction_utils import BatchFeature
from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item
from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory
from vllm.multimodal.inputs import MultiModalKwargsItems
except ImportError as exc:
raise RuntimeError(
"Multimodal generate via /inference/v1/generate requires `vllm` "
"and `torch` to encode the features payload. Install vLLM in this "
"environment, or pre-build features upstream."
) from exc

out: dict[str, Any] = {
"mm_hashes": {},
"mm_placeholders": {},
"kwargs_data": {},
}

image_items = mm_data.mm_items.get("image") or []
if image_items:
# mm_items now ship numpy arrays (the renderer is torch-free);
# convert at this vLLM-glue boundary where torch is already a
# hard dependency.
pixel_values = torch.cat(
[torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0
)
image_grid_thw = torch.cat(
[torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0
)
hf_inputs = BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
)
config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs)
kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config)
encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]]
out["kwargs_data"]["image"] = encoded
out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or [])
out["mm_placeholders"]["image"] = [
{"offset": p.offset, "length": p.length}
for p in mm_data.mm_placeholders.get("image") or []
]
run_id = current_run_id()
for source_modality, items in mm_data.mm_items.items():
if not items:
continue
mm_hashes = list(mm_data.mm_hashes.get(source_modality) or [])
placeholders = list(mm_data.mm_placeholders.get(source_modality) or [])
if len(mm_hashes) != len(items) or len(placeholders) != len(items):
raise ValueError(
"Multimodal sidecar length mismatch: "
f"modality={source_modality} items={len(items)} "
f"hashes={len(mm_hashes)} placeholders={len(placeholders)}"
)

for idx, item in enumerate(items):
if item.get("kind") != RAW_MM_ITEM_KIND:
raise NotImplementedError(
"Multimodal serialization requires raw descriptor envelopes; "
f"got item keys {sorted(item)} for modality {source_modality!r}."
)
feature_modality = item.get("vllm_modality") or source_modality
if not isinstance(feature_modality, str) or not feature_modality:
raise ValueError("raw multimodal item has invalid vllm_modality")
out["mm_hashes"].setdefault(feature_modality, []).append(mm_hashes[idx])
out["mm_placeholders"].setdefault(feature_modality, []).append(
{"offset": placeholders[idx].offset, "length": placeholders[idx].length}
)
out["kwargs_data"].setdefault(feature_modality, []).append(None)
if item.get(IMAGE_REF_PAYLOAD_KEY) != IMAGE_REF_PAYLOAD_VALUE:
continue
raw_image_id = item.get("raw_image_id")
family = item.get("family")
fingerprint = item.get("layout_fingerprint")
if not isinstance(raw_image_id, str) or not raw_image_id:
raise ValueError("raw multimodal item is missing raw_image_id")
if not isinstance(family, str) or not family:
raise ValueError("raw multimodal item is missing family")
if not isinstance(fingerprint, str) or not fingerprint:
raise ValueError("raw multimodal item is missing layout_fingerprint")
out["kwargs_data"][feature_modality][-1] = raw_mm_ref(
run_id=run_id,
family=family,
fingerprint=fingerprint,
modality=feature_modality,
mm_hash=mm_hashes[idx],
raw_image_id=raw_image_id,
payload=item.get("payload") or {},
)

# If kwargs_data is empty across all modalities, drop the key so vLLM
# falls back to the hash-only (cache-hit) path. Otherwise hand it the
# full payload.
if not any(out["kwargs_data"].values()):
if not any(item is not None for values in out["kwargs_data"].values() for item in values):
out["kwargs_data"] = None

return out
Loading
Loading