[feat] Add opt-in dynamic generation batching #1453

Open
macthecadillac wants to merge 11 commits into hao-ai-lab:main from macthecadillac:multimodal-gen-batching

Conversation

@macthecadillac

Purpose

Adds an opt-in dynamic batching path for compatible text-only generation requests. This lets prompt-file generation and the OpenAI-compatible video server coalesce compatible requests
into a single forward pass while preserving the historical single-request path by default.
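The coalescing idea can be sketched with a small asyncio queue. This is a minimal illustration, not the PR's VideoBatchScheduler: the MicroBatcher class, its run_batch callable, and the demo function are hypothetical names, and only max_size and delay_ms loosely mirror the batching_max_size and batching_delay_ms options named in this PR.

```python
import asyncio


class MicroBatcher:
    """Hypothetical sketch: gather requests for a short window, then run them together."""

    def __init__(self, run_batch, max_size=4, delay_ms=10):
        self._run_batch = run_batch          # assumed callable: list[str] -> list[result]
        self._max_size = max_size
        self._delay = delay_ms / 1000.0
        self._queue = asyncio.Queue()

    async def submit(self, prompt):
        # Each caller gets a future that resolves to its own result.
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt, future))
        return await future

    async def run_forever(self):
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]            # wait for the first request
            deadline = loop.time() + self._delay
            while len(batch) < self._max_size:
                timeout = deadline - loop.time()
                if timeout <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), timeout))
                except asyncio.TimeoutError:
                    break                                  # window closed
            results = self._run_batch([prompt for prompt, _ in batch])
            for (_, future), result in zip(batch, results):
                future.set_result(result)                 # split results back per request


async def demo():
    calls = []

    def fake_forward(prompts):
        # Stand-in for a single batched forward pass.
        calls.append(list(prompts))
        return [p.upper() for p in prompts]

    batcher = MicroBatcher(fake_forward, max_size=4, delay_ms=50)
    worker = asyncio.create_task(batcher.run_forever())
    outputs = await asyncio.gather(*(batcher.submit(p) for p in ["a", "b", "c"]))
    worker.cancel()
    return outputs, calls
```

Running asyncio.run(demo()) submits three prompts concurrently; in this sketch they are served by one call to fake_forward, and each caller still receives only its own result.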

Fixes # N/A

Changes

  • Add typed batching config and CLI/API compatibility wiring for batching_mode, batching_max_size, batching_delay_ms, admission config, and batching metrics.
  • Add batching primitives for request compatibility signatures, conservative text-only admission, max batch caps, and model/resolution/cost-based JSON rules.
  • Add VideoGenerator.generate_video_batch() with merge/split behavior, per-request seed/output preservation, prompt-file dynamic batching, and sequential fallback for incompatible
    requests.
  • Make standard input-validation, text-encoding, and denoising stages batch-aware for the conservative text-only path.
  • Add an async OpenAI video VideoBatchScheduler that starts in server lifespan and routes eligible requests through dynamic batching.
  • Add focused unit tests, a Modal L40S parity/benchmark helper, SSIM validation notes, and final validation docs.
  • Added follow-up validation commit: da55c18b [docs]: record SSIM validation follow-up.
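To make the "compatibility signature" and "max batch cap" items concrete, here is a rough sketch of how requests might be grouped. The Request fields, the signature function, and plan_batches are illustrative assumptions; the PR's actual signature fields are not shown on this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Request:
    # Hypothetical request shape; field names are assumptions for illustration.
    prompt: str
    model: str
    height: int
    width: int
    num_frames: int
    num_steps: int
    seed: int


def signature(req):
    """Fields that must match for two requests to share a forward pass (assumed)."""
    return (req.model, req.height, req.width, req.num_frames, req.num_steps)


def plan_batches(requests, max_size):
    """Greedily pack consecutive compatible requests, in arrival order, up to max_size."""
    batches = []
    for req in requests:
        last = batches[-1] if batches else None
        if last and len(last) < max_size and signature(last[0]) == signature(req):
            last.append(req)
        else:
            batches.append([req])   # incompatible or full: start a new batch
    return batches


reqs = [
    Request("a cat", "wan-1.3b", 480, 832, 17, 8, seed=1),
    Request("a dog", "wan-1.3b", 480, 832, 17, 8, seed=2),
    Request("a fox", "wan-1.3b", 480, 832, 17, 8, seed=3),
    Request("a owl", "wan-1.3b", 720, 1280, 17, 8, seed=4),
]
batches = plan_batches(reqs, max_size=2)
```

Note that prompt and seed are deliberately left out of the signature, so each request keeps its own seed inside a merged batch. A request with a different resolution ends up in a batch of one, which corresponds to the sequential fallback.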

Test Plan

pytest fastvideo/tests/batching \
  fastvideo/tests/api/test_compat_translation.py \
  fastvideo/tests/entrypoints/test_video_generator.py \
  fastvideo/tests/entrypoints/test_openai_api.py \
  fastvideo/tests/stages/test_input_validation_batching.py \
  fastvideo/tests/stages/test_text_encoding.py -q

SSIM:

FASTVIDEO_SSIM_MODEL_ID=Wan2.1-T2V-1.3B-Diffusers \
  pytest fastvideo/tests/ssim/test_wan_t2v_similarity.py -vs

Also ran pre-commit on the changed files (pre-commit run --files ...), attempted pre-commit run --all-files, and ran GPU validation through fastvideo/tests/modal/launch_l40s_job.py.

Test Results

Focused validation

Final changed-file suite on Modal L40S:

  • Modal app: ap-1mFqrE5eCwPkEKnffQcQou
  • Code validation commit: 2304837
  • Result: 119 passed, 14 warnings
  • pre-commit: passed

Focused suites:

  • batching + compat: 27 passed
  • generator batching: 51 passed
  • OpenAI scheduler: 61 passed
  • batch compat fix: 23 passed
  • text padding fix: 28 passed
  • single-text-encode fix: 29 passed

SSIM validation

H100 SSIM attempt:

  • Modal app: ap-KaJr2loSTefvmj8ijYwWOK
  • GPU: H100:2
  • Result: generated both Wan T2V videos, but failed before SSIM comparison
  • Reason: missing H100 reference folders for both FLASH_ATTN and TORCH_SDPA

L40S SSIM fallback:

  • Modal app: ap-iWP6PA1IyZbXHDKtIE1LQH
  • GPU: L40S:2
  • Command:
    FASTVIDEO_SSIM_MODEL_ID=Wan2.1-T2V-1.3B-Diffusers pytest fastvideo/tests/ssim/test_wan_t2v_similarity.py -vs
  • Result: 2 passed, 6 warnings
  • FLASH_ATTN mean SSIM: 0.9786614696
  • TORCH_SDPA mean SSIM: 0.9743387236

Pre-commit validation

Changed-file pre-commit:

  • Passed in prior validation.

Full pre-commit attempt:

  • Modal app: ap-r20n8jCBwqQnh8I5Us1yTN
  • Command: pre-commit run --all-files
  • Result: failed because yapf/ruff rewrote a large number of pre-existing files across the repo.
  • Action taken: did not commit that unrelated formatting churn.

Docs follow-up validation:

  • git diff --check: passed
  • pre-commit run --files .agents/exploration/multimodal-gen-batching-final-report.md .agents/exploration/multimodal-gen-batching-port.md: passed/skipped as expected

Known Limitations

Dynamic batching remains disabled by default and is intentionally limited to compatible text-only requests. Image, video, audio, refine, continuation, and other conditioning-heavy
requests fall back to sequential execution.
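The admission rule described here amounts to a simple partition. The sketch below uses plain dicts and made-up field names (image, video, audio, refine, continuation) purely for illustration; the PR's real request type and admission config are not reproduced on this page.

```python
# Hypothetical conditioning keys; any of these being set disqualifies a request.
CONDITIONING_FIELDS = ("image", "video", "audio", "refine", "continuation")


def is_text_only(request):
    """True when no conditioning input is attached to the request."""
    return all(request.get(field) is None for field in CONDITIONING_FIELDS)


def route(requests, batching_enabled=False):
    """Split requests into a batchable group and a sequential-fallback group."""
    batchable, sequential = [], []
    for request in requests:
        if batching_enabled and is_text_only(request):
            batchable.append(request)
        else:
            sequential.append(request)
    return batchable, sequential


requests = [
    {"prompt": "a cat"},
    {"prompt": "a dog", "image": "dog.png"},
    {"prompt": "a fox"},
]
default_route = route(requests)                       # batching is off by default
opt_in_route = route(requests, batching_enabled=True)
```

With batching left at its default, everything goes down the sequential path; with the opt-in flag set, only the two text-only prompts are eligible for a merged batch.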

The Wan latent parity run completed, but batched denoising is not close to bit-identical with sequential denoising in the current implementation:

  • aggregate max abs diff: 0.1457520127
  • aggregate mean abs diff: 0.0122398026
  • allclose(atol=1e-4, rtol=1e-4): false
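For readers unfamiliar with these three numbers, the following pure-Python stand-in shows how they relate. It is not the PR's parity helper: parity_report is a hypothetical name, it works on flat lists of floats rather than tensors, and the closeness test follows the documented torch.allclose rule |a - b| <= atol + rtol * |b|.

```python
def parity_report(batched, sequential, atol=1e-4, rtol=1e-4):
    """Compare two equal-length sequences of floats element by element."""
    diffs = [abs(a - b) for a, b in zip(batched, sequential)]
    # Same tolerance rule as torch.allclose, with `sequential` as the reference.
    close = all(d <= atol + rtol * abs(b) for d, b in zip(diffs, sequential))
    return {
        "max_abs_diff": max(diffs),
        "mean_abs_diff": sum(diffs) / len(diffs),
        "allclose": close,
    }


# Toy values, not measurements from this PR.
report = parity_report([1.0, 2.0], [1.0, 2.5])
```

A single element outside the tolerance is enough to make allclose false, which is why the aggregate result above is false even though the mean difference is much smaller than the maximum.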

Benchmarks on small latent-output Wan2.1 T2V 1.3B workloads showed:

  • batch 2, 2 steps: +1.6%
  • batch 4, 2 steps: +0.1%
  • batch 2, 8 steps: +10.7%

Checklist

  • I ran pre-commit run --all-files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified targeted Wan T2V SSIM regression tests pass on L40S
  • I updated the support matrix if adding a new model

@github-actions github-actions Bot left a comment

Welcome to FastVideo! Thanks for your first pull request.

How our CI works:

PRs run a three-tier CI system:

  1. Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
  2. Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
  3. Full Suite — integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the ready label.

Before your PR is reviewed:

  • pre-commit run --all-files passes locally
  • You've added or updated tests for your changes
  • The PR description explains what and why

If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.

@mergify mergify Bot added the labels type: feat (New feature or capability), scope: inference (Inference pipeline, serving, CLI), and scope: infra (CI, tests, Docker, build) on Jun 12, 2026
@mergify

mergify Bot commented Jun 12, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]
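The title condition is an ordinary regular expression, so it can be checked locally. The snippet below uses Python's re module and assumes that the ~= operator behaves like a regex search; that behaviour is an assumption here, not something stated on this page.

```python
import re

# Pattern copied from the merge protection rule above.
TITLE_RULE = re.compile(
    r"(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|"
    r"kernel|new.?model|skill|skills|infra)\]"
)


def title_ok(title):
    """Return True when the PR title starts with an accepted [tag]."""
    return TITLE_RULE.search(title) is not None


this_pr = title_ok("[feat] Add opt-in dynamic generation batching")
untagged = title_ok("Add opt-in dynamic generation batching")
```

This PR's title starts with [feat], so it satisfies the title condition; the remaining conditions depend on reviews and CI.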

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces an SGLang-style dynamic generation batching path to FastVideo for compatible text-only requests, covering configuration, admission control, request merging/splitting, batch-aware pipeline stages, and an OpenAI server queue scheduler. The review feedback highlights three key improvements: implementing a tensor padding helper in the text encoding stage to prevent shape mismatches when concatenating variable-length prompt embeddings, using appendleft in the scheduler to preserve strict FIFO queue ordering for incompatible requests, and expanding the _optional_bool helper to robustly support numeric boolean representations like 1 and 0.

Comment on lines +119 to +174
    def _encode_prompt_list_individually(
        self,
        texts: list[str],
        fastvideo_args: FastVideoArgs,
        *,
        encoder_index: list[int],
        return_attention_mask: bool,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        per_prompt_embeds: list[list[torch.Tensor]] = []
        per_prompt_masks: list[list[torch.Tensor]] = []
        per_prompt_audio_embeds: list[list[torch.Tensor] | None] = []

        for text in texts:
            embeds, masks = self.encode_text(
                text,
                fastvideo_args,
                encoder_index=encoder_index,
                return_attention_mask=return_attention_mask,
            )
            per_prompt_embeds.append(embeds)
            per_prompt_masks.append(masks)
            per_prompt_audio_embeds.append(self._last_audio_embeds)

        merged_embeds = [
            torch.cat([prompt_embeds[encoder_pos] for prompt_embeds in per_prompt_embeds], dim=0)
            for encoder_pos in range(len(per_prompt_embeds[0]))
        ]
        merged_masks = [
            self._cat_attention_masks([prompt_masks[encoder_pos] for prompt_masks in per_prompt_masks])
            for encoder_pos in range(len(per_prompt_masks[0]))
        ]
        if per_prompt_audio_embeds and all(audio_embeds is not None for audio_embeds in per_prompt_audio_embeds):
            audio_embed_lists = [audio_embeds for audio_embeds in per_prompt_audio_embeds if audio_embeds is not None]
            self._last_audio_embeds = [
                torch.cat([audio_embeds[encoder_pos] for audio_embeds in audio_embed_lists], dim=0)
                for encoder_pos in range(len(audio_embed_lists[0]))
            ]
        else:
            self._last_audio_embeds = None
        return merged_embeds, merged_masks

    @staticmethod
    def _cat_attention_masks(masks: list[torch.Tensor]) -> torch.Tensor:
        base_shape = masks[0].shape[1:]
        if all(mask.shape[1:] == base_shape for mask in masks):
            return torch.cat(masks, dim=0)
        if all(mask.ndim == 2 for mask in masks):
            max_length = max(mask.shape[1] for mask in masks)
            padded_masks = []
            for mask in masks:
                pad_width = max_length - mask.shape[1]
                if pad_width > 0:
                    mask = torch.nn.functional.pad(mask, (0, pad_width), value=0)
                padded_masks.append(mask)
            return torch.cat(padded_masks, dim=0)
        raise ValueError(f"Cannot concatenate attention masks with shapes: {[list(mask.shape) for mask in masks]}")

high

When encoding a batch of prompts of different lengths individually, the resulting prompt_embeds (and self._last_audio_embeds) can have different sequence lengths (dimension 1). Attempting to concatenate them directly using torch.cat along dim=0 will raise a RuntimeError due to shape mismatch.

To prevent this, we should introduce a helper method _cat_tensors that pads the sequence dimension of 3D tensors (like prompt_embeds and self._last_audio_embeds) to the maximum sequence length in the batch, similar to how _cat_attention_masks handles 2D masks.

    def _encode_prompt_list_individually(
        self,
        texts: list[str],
        fastvideo_args: FastVideoArgs,
        *,
        encoder_index: list[int],
        return_attention_mask: bool,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        per_prompt_embeds: list[list[torch.Tensor]] = []
        per_prompt_masks: list[list[torch.Tensor]] = []
        per_prompt_audio_embeds: list[list[torch.Tensor] | None] = []

        for text in texts:
            embeds, masks = self.encode_text(
                text,
                fastvideo_args,
                encoder_index=encoder_index,
                return_attention_mask=return_attention_mask,
            )
            per_prompt_embeds.append(embeds)
            per_prompt_masks.append(masks)
            per_prompt_audio_embeds.append(self._last_audio_embeds)

        merged_embeds = [
            self._cat_tensors([prompt_embeds[encoder_pos] for prompt_embeds in per_prompt_embeds])
            for encoder_pos in range(len(per_prompt_embeds[0]))
        ]
        merged_masks = [
            self._cat_attention_masks([prompt_masks[encoder_pos] for prompt_masks in per_prompt_masks])
            for encoder_pos in range(len(per_prompt_masks[0]))
        ]
        if per_prompt_audio_embeds and all(audio_embeds is not None for audio_embeds in per_prompt_audio_embeds):
            audio_embed_lists = [audio_embeds for audio_embeds in per_prompt_audio_embeds if audio_embeds is not None]
            self._last_audio_embeds = [
                self._cat_tensors([audio_embeds[encoder_pos] for audio_embeds in audio_embed_lists])
                for encoder_pos in range(len(audio_embed_lists[0]))
            ]
        else:
            self._last_audio_embeds = None
        return merged_embeds, merged_masks

    @staticmethod
    def _cat_tensors(tensors: list[torch.Tensor]) -> torch.Tensor:
        base_shape = tensors[0].shape[1:]
        if all(t.shape[1:] == base_shape for t in tensors):
            return torch.cat(tensors, dim=0)
        if all(t.ndim == 3 for t in tensors):
            max_len = max(t.shape[1] for t in tensors)
            padded_tensors = []
            for t in tensors:
                pad_width = max_len - t.shape[1]
                if pad_width > 0:
                    t = torch.nn.functional.pad(t, (0, 0, 0, pad_width), value=0.0)
                padded_tensors.append(t)
            return torch.cat(padded_tensors, dim=0)
        raise ValueError(f"Cannot concatenate tensors with shapes: {[list(t.shape) for t in tensors]}")

    @staticmethod
    def _cat_attention_masks(masks: list[torch.Tensor]) -> torch.Tensor:
        base_shape = masks[0].shape[1:]
        if all(mask.shape[1:] == base_shape for mask in masks):
            return torch.cat(masks, dim=0)
        if all(mask.ndim == 2 for mask in masks):
            max_length = max(mask.shape[1] for mask in masks)
            padded_masks = []
            for mask in masks:
                pad_width = max_length - mask.shape[1]
                if pad_width > 0:
                    mask = torch.nn.functional.pad(mask, (0, pad_width), value=0)
                padded_masks.append(mask)
            return torch.cat(padded_masks, dim=0)
        raise ValueError(f"Cannot concatenate attention masks with shapes: {[list(mask.shape) for mask in masks]}")

Comment on lines +100 to +104
    if self._jobs_are_compatible(batch[0], candidate):
        batch.append(candidate)
        continue
    self._pending.append(candidate)
    break

medium

When a candidate job is found to be incompatible with the current batch, appending it back to the right side of self._pending (using append) violates the strict FIFO ordering of the queue. This can cause newer jobs to be processed before older ones, potentially leading to unfairness or starvation.

Using self._pending.appendleft(candidate) ensures that the incompatible job remains at the front of the queue to be processed first in the next batch.

Suggested change
      if self._jobs_are_compatible(batch[0], candidate):
          batch.append(candidate)
          continue
    - self._pending.append(candidate)
    + self._pending.appendleft(candidate)
      break
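The ordering difference is easy to reproduce with a plain collections.deque. The sketch assumes that _pending is a deque drained with popleft(), which the snippet under review does not show; take_batch, same_kind, and the job names are made up for the demonstration.

```python
from collections import deque


def take_batch(pending, compatible, requeue_front):
    """Pop a batch from the front of `pending`, stopping at the first incompatible job."""
    batch = [pending.popleft()]
    while pending:
        candidate = pending.popleft()
        if compatible(batch[0], candidate):
            batch.append(candidate)
            continue
        if requeue_front:
            pending.appendleft(candidate)   # candidate stays next in line
        else:
            pending.append(candidate)       # candidate is sent to the back
        break
    return batch


def same_kind(a, b):
    # Toy compatibility rule: jobs are compatible when their first letter matches.
    return a[0] == b[0]


jobs = ["a1", "b1", "a2", "a3"]

back = deque(jobs)
take_batch(back, same_kind, requeue_front=False)

front = deque(jobs)
take_batch(front, same_kind, requeue_front=True)
```

After one batch, back holds a2, a3, b1 while front holds b1, a2, a3: with append, the older job b1 now waits behind two newer jobs, which is the reordering the comment describes.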

Comment on lines +285 to +293
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
    raise ValueError(f"cannot parse boolean batching config value: {value!r}")

medium

The _optional_bool helper raises a ValueError if the input is an integer or float (like 1 or 0), which are common representations of booleans in JSON configurations.

We should explicitly support 1 and 0 (or 1.0 and 0.0) to make config parsing more robust and user-friendly.

Suggested change
      if isinstance(value, bool):
          return value
    + if isinstance(value, (int, float)):
    +     if value == 1.0:
    +         return True
    +     if value == 0.0:
    +         return False
      if isinstance(value, str):
          lowered = value.strip().lower()
          if lowered in ("1", "true", "yes", "y", "on"):
              return True
          if lowered in ("0", "false", "no", "n", "off"):
              return False
      raise ValueError(f"cannot parse boolean batching config value: {value!r}")
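As a standalone illustration of the suggested behaviour, here is a self-contained version of such a parser. The function name optional_bool and the None passthrough are assumptions (the surrounding code of _optional_bool is not visible in this review); the accepted strings and the error message are taken from the snippet above.

```python
def optional_bool(value):
    """Parse a loosely typed config value into True, False, or None."""
    if value is None:
        return None                      # assumed: an unset option stays unset
    if isinstance(value, bool):          # must come before the int check: bool is an int subclass
        return value
    if isinstance(value, (int, float)) and value in (0, 1):
        return bool(value)               # accepts 0, 1, 0.0, 1.0 only
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
    # Anything else, including numbers such as 2 or 0.5, is rejected loudly.
    raise ValueError(f"cannot parse boolean batching config value: {value!r}")


def rejects(value):
    """Helper for the demonstration: True when parsing raises ValueError."""
    try:
        optional_bool(value)
    except ValueError:
        return True
    return False
```

Keeping the final raise means that out-of-range numbers such as 2 fail fast instead of silently yielding None.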

Labels

scope: inference (Inference pipeline, serving, CLI), scope: infra (CI, tests, Docker, build), type: feat (New feature or capability)

1 participant