diff --git a/.agents/exploration/multimodal-gen-batching-final-report.md b/.agents/exploration/multimodal-gen-batching-final-report.md new file mode 100644 index 000000000..64340dd8b --- /dev/null +++ b/.agents/exploration/multimodal-gen-batching-final-report.md @@ -0,0 +1,188 @@ +# Multimodal Generation Batching Port Report + +Date: 2026-05-30 +Branch: `multimodal-gen-batching` + +## Summary + +This port adds an SGLang-style generation batching path to FastVideo for +compatible text-only generation requests. The implementation covers typed +batching config, request compatibility/admission rules, generator merge/split, +batch-safe standard pipeline stages, and an OpenAI server queue scheduler. + +Initial scope is intentionally conservative: text-only compatible requests can +batch. Requests with image, video, audio, action, continuation, refine, or other +conditioning inputs are rejected by the compatibility signature and continue on +the sequential path. + +## Commits + +- `3d3157fb` - `[feat]: add generation batching primitives` +- `ce758820` - `[misc]: record batching stage 1 state` +- `fe5572f3` - `[feat]: add generator dynamic batching path` +- `39cb1d43` - `[misc]: record batching stage 2 state` +- `4c6a9dc2` - `[feat]: add OpenAI video batching scheduler` +- `1bea1dee` - `[misc]: record batching stage 5 state` +- `9c6eb355` - `[fix]: harden dynamic generation batching` +- `2304837e` - `[docs]: record multimodal batching validation report` + +## Implementation + +Added batching configuration: + +- `fastvideo/api/schema.py`: `BatchingConfig` under `EngineConfig`. +- `fastvideo/fastvideo_args.py`: `batching_mode`, `batching_max_size`, + `batching_delay_ms`, `batching_config`, and `enable_batching_metrics`. +- `fastvideo/api/compat.py`: legacy kwarg translation into typed batching + config and back into `FastVideoArgs`. + +Added batching primitives: + +- `fastvideo/batching/admission.py`: max batch size and JSON-rule admission. 
+- `fastvideo/batching/signature.py`: request compatibility signatures and + unsupported multimodal/conditioning fields. +- `PipelineConfig.estimate_request_cost()`: default hook for future admission + cost rules. + +Added generator batching: + +- `VideoGenerator.generate_video_batch()` accepts legacy request kwargs. +- Compatible work items are merged into one `ForwardBatch` with prompt lists and + per-request seeds, then split back into per-request result dictionaries. +- Incompatible adjacent requests fall back to existing single-request execution. +- Prompt-file generation uses dynamic batching when enabled. +- `generate_video_batch()` now routes each request through the same compatibility + adapter as `generate_video()`, including pipeline overrides such as + `embedded_cfg_scale`. + +Made standard stages batch-aware: + +- `InputValidationStage` preserves explicit per-request seeds and fans out + prompt-list seeds. +- `TextEncodingStage` repeats a single negative prompt across prompt lists. +- `TextEncodingStage.encode_text()` adds tokenizer padding for direct prompt-list + tokenization. +- `TextEncodingStage.forward()` preserves the existing single-prompt text + encoding path for merged prompt lists, then concatenates postprocessed + embeddings/masks. This reduces text-encoder-induced parity drift. +- `DenoisingStage` no longer asserts batch size 1 on the standard path. + +Added OpenAI server batching: + +- `fastvideo/entrypoints/openai/batching.py`: async FIFO `VideoBatchScheduler`. +- `fastvideo/entrypoints/openai/state.py`: scheduler lifecycle state. +- `fastvideo/entrypoints/openai/api_server.py`: scheduler start/stop in + lifespan. +- `fastvideo/entrypoints/openai/video_api.py`: routes requests through the + scheduler when dynamic batching is enabled. + +Added validation helper: + +- `fastvideo/tests/batching/run_dynamic_batching_parity.py` +- Modes: + - `parity`: sequential `generate_video()` vs dynamic `generate_video_batch()` + latent comparison. 
+ - `sequential`: disabled-batching benchmark baseline. + - `dynamic`: dynamic batching benchmark. + +## Remote Validation + +All validation was run through +`fastvideo/tests/modal/launch_l40s_job.py` on Modal L40S. + +Focused tests and hooks: + +| Stage | Modal app | Command summary | Result | +| --- | --- | --- | --- | +| Config/admission/signature | `ap-q9SRzkYvwXzoPA5oWoaM3i` | `pytest fastvideo/tests/batching fastvideo/tests/api/test_compat_translation.py -q && pre-commit run --files ...` | `27 passed`, pre-commit passed | +| Generator batching | `ap-0LrnrxcMw4Ni3M9YJ6gtZv` | batching, API compat, generator, input validation tests plus pre-commit | `51 passed`, pre-commit passed | +| OpenAI scheduler | `ap-sgV5gRsGJeHE4g9a1Dswk3` | `pytest fastvideo/tests/entrypoints/test_openai_api.py -q && pre-commit run --files ...` | `61 passed`, pre-commit passed | +| Batch compat fix | `ap-gmob40FWTO5knEPA39s3bd` | `pytest fastvideo/tests/entrypoints/test_video_generator.py -q && pre-commit run --files ...` | `23 passed`, pre-commit passed | +| Text padding fix | `ap-IJIJjdLogSGkzeN4sCP46E` | entrypoint and text encoding tests plus pre-commit | `28 passed`, pre-commit passed | +| Single-text-encode fix | `ap-DIIEE6Wy0I728fqsc63C6s` | entrypoint and text encoding tests plus pre-commit | `29 passed`, pre-commit passed | +| Final changed-file suite | `ap-1mFqrE5eCwPkEKnffQcQou` | batching, generator, text encoding, OpenAI API, compat, and input-validation tests plus pre-commit | `119 passed`, pre-commit passed | + +Post-report validation: + +| Check | Modal app | Command summary | Result | +| --- | --- | --- | --- | +| Wan T2V SSIM on H100 | `ap-KaJr2loSTefvmj8ijYwWOK` | `FASTVIDEO_SSIM_MODEL_ID=Wan2.1-T2V-1.3B-Diffusers pytest fastvideo/tests/ssim/test_wan_t2v_similarity.py -vs` on `H100:2` | Generated both videos, but failed before SSIM comparison because `H100_reference_videos` are missing for both `FLASH_ATTN` and `TORCH_SDPA` | +| Wan T2V SSIM on L40S | 
`ap-iWP6PA1IyZbXHDKtIE1LQH` | same targeted Wan T2V SSIM command on `L40S:2`, `--install-extra none` | `2 passed`, `6 warnings`; mean SSIM `0.9786614696` for `FLASH_ATTN`, `0.9743387236` for `TORCH_SDPA` | +| Full pre-commit attempt | `ap-r20n8jCBwqQnh8I5Us1yTN` | `pre-commit run --all-files` on `L40S:1` | Failed because yapf/ruff rewrote a large set of pre-existing repository files; not taken as PR-local evidence | + +## Parity + +Parity workload: + +- Model: `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` +- GPU: one L40S +- Output: latent tensors, no video save +- Shape: `256x256`, 9 frames +- Steps: 2 +- Batch size: 2 +- Guidance: `guidance_scale=1.0`, `embedded_cfg_scale=6.0` + +Final dynamic parity run: + +- Modal app: `ap-fYT25LrzvFbUZv50JhxmWC` +- Sequential time: `2.8672916360s` +- Dynamic time: `2.1450136500s` +- Speedup in this run: `1.3367x` + +Tensor comparison: + +| Request | Shape | Max abs diff | Mean abs diff | allclose 1e-4 | +| --- | --- | ---: | ---: | --- | +| 0 | `[1, 16, 3, 32, 32]` | `0.0380859375` | `0.0047932155` | false | +| 1 | `[1, 16, 3, 32, 32]` | `0.1457520127` | `0.0196863897` | false | +| Aggregate | - | `0.1457520127` | `0.0122398026` | false | + +Conclusion: dynamic batched denoising does not match sequential denoising within +the `1e-4` `allclose` tolerance for this Wan latent test (it is not near bit-identical). The text-encoding path was adjusted to match +single-request encoding, so the remaining difference is likely from batched +Wan transformer/denoising math. The existing disabled/sequential path remains +available and is the default unless `batching_mode=dynamic` is explicitly set. 
+ +## Benchmarks + +Benchmark environment: + +- Modal L40S +- Wan2.1 T2V 1.3B +- Latent output, no video save +- `256x256`, 9 frames +- `guidance_scale=1.0` +- One warmup per mode + +| Workload | Mode | Runs | Avg seconds | Throughput req/s | +| --- | --- | --- | ---: | ---: | +| batch 2, 2 steps | sequential | 3 | `2.1733781003` | `0.9202264437` | +| batch 2, 2 steps | dynamic | 3 | `2.1396027907` | `0.9347529405` | +| batch 4, 2 steps | sequential | 3 | `4.3423643910` | `0.9211571485` | +| batch 4, 2 steps | dynamic | 3 | `4.3374464800` | `0.9222015807` | +| batch 2, 8 steps | sequential | 2 | `2.6575391855` | `0.7525759210` | +| batch 2, 8 steps | dynamic | 2 | `2.4006625770` | `0.8331033354` | + +Observed throughput changes: + +- Batch 2, 2 steps: about `+1.6%` +- Batch 4, 2 steps: about `+0.1%` +- Batch 2, 8 steps: about `+10.7%` + +Interpretation: with the exact per-prompt text-encoding safeguard, very short +2-step workloads are dominated by text encoding and launch overhead. Dynamic +batching shows clearer benefit once denoising compute is a larger fraction of +the request. + +## Limitations And Follow-Up + +- Dynamic batched Wan denoising is not near bit-identical to sequential Wan + denoising in the current implementation. +- The implementation is text-only for dynamic batching. Multimodal and + conditioning-heavy requests are intentionally routed to sequential execution. +- The benchmarks used small latent-output workloads to keep Modal iteration + time reasonable. Larger production shapes and longer schedules should be + benchmarked before enabling dynamic batching by default. +- A future exact-parity mode could preserve the scheduler/admission queue but + execute denoising per request. That would satisfy strict numerical parity but + would give up most batching speedup. 
diff --git a/.agents/exploration/multimodal-gen-batching-port.md b/.agents/exploration/multimodal-gen-batching-port.md new file mode 100644 index 000000000..c5d8a9f33 --- /dev/null +++ b/.agents/exploration/multimodal-gen-batching-port.md @@ -0,0 +1,529 @@ +# Exploration Log: Multimodal Generation Batching Port + +## Status: complete with documented dynamic parity limitation + +## Context +User requested a staged plan, approval before implementation, and then a port of +SGLang `python/sglang/multimodal_gen` batching into FastVideo on branch +`multimodal-gen-batching`. + +Upstream reference inspected locally from a sparse clone at: +`/tmp/sglang-multimodal-gen/python/sglang/multimodal_gen`. + +Primary upstream files inspected: +- `runtime/managers/scheduler.py` +- `runtime/managers/dynamic_batch_admission.py` +- `runtime/managers/gpu_worker.py` +- `runtime/pipelines_core/schedule_batch.py` +- `runtime/pipelines_core/composed_pipeline_base.py` +- `runtime/pipelines_core/executors/{pipeline_executor,sync_executor,parallel_executor}.py` +- `runtime/pipelines_core/stages/{base,dedup,latent_preparation,denoising}.py` +- `configs/sample/sampling_params.py` +- `configs/pipeline_configs/base.py` + +FastVideo files inspected: +- `fastvideo/entrypoints/video_generator.py` +- `fastvideo/entrypoints/openai/{api_server,video_api}.py` +- `fastvideo/worker/{executor,multiproc_executor,gpu_worker,worker_base}.py` +- `fastvideo/pipelines/{pipeline_batch_info,composed_pipeline_base}.py` +- `fastvideo/pipelines/stages/{base,text_encoding,latent_preparation,timestep_preparation,denoising,decoding}.py` +- `fastvideo/api/schema.py` +- `fastvideo/configs/pipelines/base.py` +- `fastvideo/tests/modal/launch_l40s_job.py` + +## Progress +- [x] Read repository onboarding, codebase map, relevant AGENTS guidance, and + exploration template. +- [x] Sparse-cloned upstream SGLang multimodal generation source to `/tmp`. +- [x] Identified upstream batching mechanism and corresponding FastVideo gaps. 
+- [x] Drafted implementation plan for user approval. +- [x] User approved the staged implementation plan on 2026-05-30. +- [x] Stage 1 implementation: typed batching config, FastVideoArgs wiring, + batching admission rules, compatibility signatures, and focused unit + tests have been added locally. +- [x] Stage 1 remote validation passed on Modal L40S. +- [x] Stage 1 commit: `3d3157fb` (`[feat]: add generation batching primitives`). +- [x] Stage 1 state commit: `ce758820` + (`[misc]: record batching stage 1 state`). +- [x] Stage 2 local implementation: + - added generator prepared-work-item merge/split helpers, + - added `generate_video_batch()` for later server queue use, + - enabled prompt-file dynamic batching, + - preserved explicit per-request seeds for prompt-list batches, + - repeated negative prompts for CFG with prompt lists, + - removed the standard denoising stage batch-size-1 assertion, + - added focused CPU-light tests with fake forward execution. +- [x] Stage 2 remote validation passed on Modal L40S. +- [x] Stage 2 commit: `fe5572f3` + (`[feat]: add generator dynamic batching path`). +- [x] Stage 2 state commit: `39cb1d43` + (`[misc]: record batching stage 2 state`). +- [x] Stage 5 local implementation: + - added `VideoBatchScheduler` for OpenAI video requests, + - stores the scheduler in OpenAI server state, + - starts/stops the scheduler in API server lifespan, + - routes `video_api._run_generation()` through the scheduler when + dynamic batching is enabled, + - added async scheduler tests for compatible grouping and incompatible + fallback. +- [x] Stage 5 remote validation passed on Modal L40S. +- [x] Stage 5 commit: `4c6a9dc2` + (`[feat]: add OpenAI video batching scheduler`). +- [x] Added GPU parity/benchmark helper script: + `fastvideo/tests/batching/run_dynamic_batching_parity.py`. +- [x] Remote GPU parity run completed; dynamic batched denoising did not meet + near-bit-identical tolerance and is documented as a limitation. 
+- [x] Remote benchmark runs completed for batch size 2, batch size 4, and an + 8-step batch size 2 workload. +- [x] Commit GPU helper, parity fixes, and benchmark state: + `9c6eb355` (`[fix]: harden dynamic generation batching`). +- [x] Commit final Markdown write-up with test and benchmark results: + `2304837e` (`[docs]: record multimodal batching validation report`). +- [x] Push branch to origin after final report commit. +- [x] Final changed-file validation passed on Modal L40S. + +## Findings +Upstream SGLang batching is not just a larger prompt list. It has four pieces: + +1. Queue/admission in `runtime/managers/scheduler.py`. + Requests are received into a FIFO queue with enqueue timestamps. Compatible + text-only generation requests are coalesced up to `batching_max_size` or + after `batching_delay_ms`, with metrics and rejection reasons. + +2. Compatibility and admission control. + Compatibility is based on sampling-param signatures with selected fields + excluded via `metadata={"batch_sig_exclude": True}`. Admission applies user + max batch size plus optional JSON rules keyed by model/resolution/memory and + a rough request cost from pipeline config. + +3. Request merge/split. + Compatible requests are deep-copied into one request with `prompt: list[str]` + and per-request seeds/output paths stashed in `extra`. The worker runs one + merged request and the scheduler splits tensor/list/path outputs back into + one output per original request. + +4. Pipeline support. + SGLang added grouped pipeline execution, stage-level dedup hooks, request-local + schedulers, grouped latent preparation that preserves per-request RNG streams, + and denoising code that can process latent batch dimensions greater than one. + +FastVideo state before this port (as first inspected; the Implementation Log below changes some of these): +- `ForwardBatch.prompt` already accepts `str | list[str]`. +- Text encoding already handles list prompts. +- Latent preparation mostly handles batch sizes, but RNG equivalence must be + preserved for grouped requests. 
+- Standard `DenoisingStage` has an explicit `assert latent_model_input.shape[0] == 1`; + this is a hard blocker for true native batching. +- `VideoGenerator._generate_request_impl` expands prompt lists sequentially. +- OpenAI video API launches each request as an independent background thread + against one global `VideoGenerator`; there is no scheduler queue/admission + layer. + +## Approved Plan +Approved by the user on 2026-05-30. + +Stages: +1. Add batching config and pure admission/signature primitives. +2. Refactor generator/executor/pipeline surfaces for merge/split. +3. Enable true native batching for a conservative text-only path. +4. Integrate the OpenAI server queue scheduler. +5. Run all validation remotely through + `fastvideo/tests/modal/launch_l40s_job.py`. +6. Run parity and before/after benchmarks. +7. Save a final Markdown report, commit it, and push the branch. + +Scope guard: initial dynamic batching only supports compatible text-only +requests. Image/video/audio/refine/continuation/control inputs remain routed +to sequential execution until separately audited. + +## Implementation Log + +### Stage 1: Config And Pure Batching Primitives +Files changed/added: +- `fastvideo/api/schema.py` + - Added `BatchingConfig` under `EngineConfig`. +- `fastvideo/fastvideo_args.py` + - Added `batching_mode`, `batching_max_size`, `batching_delay_ms`, + `batching_config`, and `enable_batching_metrics`. + - Added CLI flags and basic validation. +- `fastvideo/api/compat.py` + - Mapped legacy flat kwargs to `EngineConfig.batching`. + - Emitted typed batching config back to `FastVideoArgs`. +- `fastvideo/entrypoints/video_generator.py` + - Allowed batching kwargs through `from_pretrained` convenience handling. +- `fastvideo/configs/pipelines/base.py` + - Added default `estimate_request_cost()` for admission budgets. +- `fastvideo/batching/` + - Added `admission.py` with SGLang-style rule parsing and cap logic. 
+ - Added `signature.py` with request-local exclusions and safe text-only + compatibility checks. +- `fastvideo/tests/batching/` + - Added focused admission/signature tests. +- `fastvideo/tests/api/test_compat_translation.py` + - Added batching config translation coverage. + +Stage 1 is complete. + +### Stage 2: Generator Merge/Split And Batch-Safe Standard Stages +Files changed/added: +- `fastvideo/entrypoints/video_generator.py` + - Added `_GenerationWorkItem`, forward execution helper, output + postprocessing helper, output split helper, work-item merge/grouping logic, + and `generate_video_batch()`. + - `prompt_txt` / `SamplingParam.prompt_path` now uses dynamic batching when + `batching_mode=dynamic` and `batching_max_size>1`; otherwise it keeps the + prior sequential behavior. +- `fastvideo/pipelines/stages/input_validation.py` + - Preserves `ForwardBatch.seeds` when already supplied by a merged batch. + - Generates one seed per prompt for prompt-list batches. +- `fastvideo/pipelines/stages/text_encoding.py` + - Expands a single negative prompt across a prompt list for CFG. +- `fastvideo/pipelines/stages/denoising.py` + - Removed the explicit standard-path `shape[0] == 1` assertion. +- `fastvideo/tests/entrypoints/test_video_generator.py` + - Added fake-forward tests for merged compatible requests and sequential + fallback on incompatible requests. +- `fastvideo/tests/stages/test_input_validation_batching.py` + - Added seed preservation and prompt-list seed fanout tests. + +Stage 2 is complete. + +Validation attempt: +- Modal app: `ap-vsU8ZgjMvfGgpRBWk4IfCV` +- Result: + - Tests failed before pre-commit: `1 failed, 50 passed, 14 warnings`. + - Failure was an existing entrypoint test using a `SimpleNamespace` test + double without new batching fields. +- Fix applied locally: + - `_dynamic_batching_enabled()` now defaults missing batching fields to + disabled / max size 1. 
+ +Validation rerun: +- Modal app: `ap-Y6B6cRSzW9frq4IIU7Tyqj` +- Result: + - Tests passed: `51 passed, 14 warnings`. + - Pre-commit failed only on Ruff `F841` for an unused `sampling_param` + local in `VideoGenerator._postprocess_generation_output()`. +- Fix applied locally: + - Removed the unused local. + +Validation clean rerun: +- Modal app: `ap-0LrnrxcMw4Ni3M9YJ6gtZv` +- Command: + `pytest fastvideo/tests/batching fastvideo/tests/api/test_compat_translation.py fastvideo/tests/entrypoints/test_video_generator.py fastvideo/tests/stages/test_input_validation_batching.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `51 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +### Stage 5: OpenAI Server Queue Integration +Files changed/added: +- `fastvideo/entrypoints/openai/batching.py` + - Added `VideoBatchScheduler`, an async FIFO queue with batching delay, + compatibility checks, background dispatch, and per-request futures. +- `fastvideo/entrypoints/openai/state.py` + - Added global scheduler storage and accessor. +- `fastvideo/entrypoints/openai/api_server.py` + - Starts the scheduler during lifespan when `batching_mode=dynamic` and + `batching_max_size>1`; stops it before generator shutdown. +- `fastvideo/entrypoints/openai/video_api.py` + - `_run_generation()` submits to the scheduler when enabled; otherwise keeps + the prior direct executor-thread path. +- `fastvideo/tests/entrypoints/test_openai_api.py` + - Added scheduler grouping and incompatible fallback tests using a fake + generator. + +Stage 5 is complete. Note: this work corresponds to step 4 of the Approved Plan (OpenAI server queue scheduler); this log and the commit messages label it Stage 5. + +Validation attempt: +- Modal app: `ap-2PJ4b5eKPnxR9HTEb9x3UM` +- Command: + `pytest fastvideo/tests/entrypoints/test_openai_api.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `61 passed, 14 warnings`. + - Pre-commit failed only on mypy for assigning to an `exc` variable outside + an `except` block in `openai/batching.py`. 
+- Fix applied locally: + - Renamed that local variable to `error`. + +Validation clean rerun: +- Modal app: `ap-sgV5gRsGJeHE4g9a1Dswk3` +- Command: + `pytest fastvideo/tests/entrypoints/test_openai_api.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `61 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +Stage 1 validation attempt (logged here out of order; covers the Stage 1 batching primitives): +- Modal L40S command: + `pytest fastvideo/tests/batching fastvideo/tests/api/test_compat_translation.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `27 passed, 14 warnings`. + - Pre-commit failed only on Ruff `SIM103` in the two new batching helper + files. +- Fix applied locally: + - Simplified the relevant boolean returns in `admission.py` and + `signature.py`. + +Stage 1 validation rerun: +- Modal app: `ap-q9SRzkYvwXzoPA5oWoaM3i` +- Command: + `pytest fastvideo/tests/batching fastvideo/tests/api/test_compat_translation.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `27 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +## Validation Plan +All validation must run remotely through: +`python -m modal run fastvideo/tests/modal/launch_l40s_job.py ...` + +Planned remote validation after implementation: +- CPU/light import and unit tests on Modal image, not local. +- Focused batching unit tests for compatibility/admission/merge/split. +- GPU parity tests on Modal L40S comparing sequential vs batched outputs with + fixed seeds, prompt sets, model, resolution, steps, backend, and save disabled. +- Existing SSIM regression for affected model(s) if references are available. +- Before/after benchmark suite with dynamic batching disabled vs enabled and + sequential prompt-file baseline vs dynamic server batching. 
+ +GPU validation helper: +- `fastvideo/tests/batching/run_dynamic_batching_parity.py` +- Supports: + - `--mode parity`: sequential `generate_video()` for each request vs one + `generate_video_batch()` call in the same checkout; compares latent tensors. + - `--mode sequential`: benchmark current sequential behavior. + - `--mode dynamic`: benchmark dynamic batching behavior. +- Defaults use Wan2.1 T2V 1.3B, latent output, 256x256, 9 frames, 2 steps, + batch size 2. Final parity/benchmark runs may override these if the model + requires a larger valid shape. + +Benchmark run, batch size 2: +- Modal app: `ap-V0fkFaplHrGYto9hFWvRmc` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode sequential --warmup-runs 1 --measurement-runs 3 --output-json /tmp/fastvideo_dynamic_batching/sequential.json && python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode dynamic --warmup-runs 1 --measurement-runs 3 --output-json /tmp/fastvideo_dynamic_batching/dynamic.json` +- Workload: + - Wan2.1 T2V 1.3B, one L40S, latent output, 256x256, 9 frames, 2 denoise + steps, two prompts, `guidance_scale=1.0`. +- Results: + - Sequential baseline times: `[2.1736583650, 2.1744417920, 2.1720341440]` + - Sequential average: `2.1733781003s`; throughput `0.9202264437 req/s` + - Dynamic times: `[2.1407064020, 2.1398537520, 2.1382482180]` + - Dynamic average: `2.1396027907s`; throughput `0.9347529405 req/s` + - Throughput improvement: about `1.6%`. + +Benchmark run, batch size 4: +- Modal app: `ap-K5tj5fWQKByZ0Sl9wn6OGC` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode sequential --batch-size 4 --warmup-runs 1 --measurement-runs 3 ... && python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode dynamic --batch-size 4 --warmup-runs 1 --measurement-runs 3 ...` +- Workload: + - Wan2.1 T2V 1.3B, one L40S, latent output, 256x256, 9 frames, 2 denoise + steps, four prompts, `guidance_scale=1.0`. 
+- Results: + - Sequential baseline times: `[4.3420497740, 4.3434355840, 4.3416078150]` + - Sequential average: `4.3423643910s`; throughput `0.9211571485 req/s` + - Dynamic times: `[4.2617743810, 4.4826704910, 4.2678945680]` + - Dynamic average: `4.3374464800s`; throughput `0.9222015807 req/s` + - Throughput improvement: about `0.1%`. + - Interpretation: with exact per-prompt text encoding and only two denoise + steps, this small synthetic benchmark is dominated by text/launch overhead, + so dynamic denoising has little room to help. + +Benchmark run, batch size 2, 8 denoise steps: +- Modal app: `ap-yPsbXeIc6YbCG7NvAueN4t` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode sequential --num-inference-steps 8 --warmup-runs 1 --measurement-runs 2 ... && python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode dynamic --num-inference-steps 8 --warmup-runs 1 --measurement-runs 2 ...` +- Workload: + - Wan2.1 T2V 1.3B, one L40S, latent output, 256x256, 9 frames, 8 denoise + steps, two prompts, `guidance_scale=1.0`. +- Results: + - Sequential baseline times: `[2.7835521810, 2.5315261900]` + - Sequential average: `2.6575391855s`; throughput `0.7525759210 req/s` + - Dynamic times: `[2.4052083240, 2.3961168300]` + - Dynamic average: `2.4006625770s`; throughput `0.8331033354 req/s` + - Throughput improvement: about `10.7%`. + +Final report: +- Created `.agents/exploration/multimodal-gen-batching-final-report.md`. +- Includes implementation summary, commit list, remote validation results, + dynamic parity results, benchmark tables, and limitations/follow-up. + +### Stage 6: GPU Parity And Benchmark Validation +Parity attempt: +- Modal app: `ap-UTIijf9LTsX3ua9Czvm5dq` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode parity --output-json /tmp/fastvideo_dynamic_batching/parity.json` +- Result: + - Failed before tensor comparison. 
+ - Sequential `generate_video()` accepted the legacy request kwarg + `embedded_cfg_scale`, but `generate_video_batch()` tried to apply all + request kwargs directly to `SamplingParam.update()` and rejected + `embedded_cfg_scale`. +- Fix applied locally: + - `VideoGenerator.generate_video_batch()` now routes each request through + `legacy_generate_call_to_request()`, `request_to_sampling_param()`, and + `request_to_pipeline_overrides()`, matching `generate_video()`. + - Identical pipeline overrides reuse one resolved `FastVideoArgs` object so + compatible requests can still merge; different overrides remain separated + by the existing object-identity compatibility guard. + - Added a unit test covering `generate_video_batch()` with + `embedded_cfg_scale`. + +Focused validation for the compat fix: +- Modal app: `ap-gmob40FWTO5knEPA39s3bd` +- Command: + `pytest fastvideo/tests/entrypoints/test_video_generator.py -q && pre-commit run --files fastvideo/entrypoints/video_generator.py fastvideo/tests/entrypoints/test_video_generator.py fastvideo/tests/batching/run_dynamic_batching_parity.py .agents/exploration/multimodal-gen-batching-port.md` +- Result: + - Tests passed: `23 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +Parity rerun: +- Modal app: `ap-C8JZ6i4R7mv6NOBTBDCRTc` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode parity --output-json /tmp/fastvideo_dynamic_batching/parity.json` +- Result: + - Failed in the batched forward path before tensor comparison. + - The Wan tokenizer received a prompt list with variable token lengths and + `return_tensors="pt"` but no padding, causing Hugging Face tokenization to + reject non-rectangular `input_ids`. +- Fix applied locally: + - `TextEncodingStage.encode_text()` now adds tokenizer `padding=True` when + encoding multiple processed prompts and no explicit padding mode is already + configured. 
+ - Added a unit test to cover default padding insertion for prompt-list text + encoding. + +Focused validation attempt for padding fix: +- Modal app: `ap-MFeCHJF61EuAgsZro5cAb4` +- Command: + `pytest fastvideo/tests/entrypoints/test_video_generator.py fastvideo/tests/stages/test_text_encoding.py -q && pre-commit run --files ...` +- Result: + - Product tests mostly passed, but the new test had a misplaced assertion + referencing `out2` outside its original test. +- Fix applied locally: + - Moved the prompt/negative attention-mask assertions back into + `test_forward_integration_cfg_off_and_on()` and left the padding test + scoped to tokenizer kwargs. + +Focused validation clean rerun: +- Modal app: `ap-IJIJjdLogSGkzeN4sCP46E` +- Command: + `pytest fastvideo/tests/entrypoints/test_video_generator.py fastvideo/tests/stages/test_text_encoding.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `28 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +Parity rerun after padding fix: +- Modal app: `ap-MMEOgUNlqYzeoYaD2blJVh` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode parity --output-json /tmp/fastvideo_dynamic_batching/parity.json` +- Result: + - Batched forward completed successfully. + - Tensor parity was not close enough: + - request 0 max abs diff `0.0625`, mean abs diff `0.0048387293` + - request 1 max abs diff `0.1708983183`, mean abs diff `0.0173878949` + - aggregate max abs diff `0.1708983183`, mean abs diff `0.0111133121` + - `torch.allclose(..., atol=1e-4, rtol=1e-4)` failed. +- Follow-up fix applied locally: + - `TextEncodingStage.forward()` now preserves the existing single-prompt text + encoding path for prompt-list batches by encoding each prompt separately + and concatenating postprocessed embeddings/masks. This keeps denoising + batched while removing tokenizer padding/sequence-length drift from the + parity path. 
+ - Added a unit test proving prompt-list `forward()` uses one tokenizer call + per prompt, while direct `encode_text(list)` still supports padded batched + tokenization. + +Focused validation for single-text-encode parity fix: +- Modal app: `ap-DIIEE6Wy0I728fqsc63C6s` +- Command: + `pytest fastvideo/tests/entrypoints/test_video_generator.py fastvideo/tests/stages/test_text_encoding.py -q && pre-commit run --files ...` +- Result: + - Tests passed: `29 passed, 14 warnings`. + - Pre-commit passed: yapf, ruff, codespell, mypy, filename spaces, and + suggestion hooks. + +Parity rerun after single-text-encode fix: +- Modal app: `ap-fYT25LrzvFbUZv50JhxmWC` +- Command: + `python fastvideo/tests/batching/run_dynamic_batching_parity.py --mode parity --output-json /tmp/fastvideo_dynamic_batching/parity.json` +- Result: + - Batched forward completed successfully. + - Tensor parity improved only slightly and is still not near bit-identical: + - request 0 max abs diff `0.0380859375`, mean abs diff + `0.0047932155` + - request 1 max abs diff `0.1457520127`, mean abs diff + `0.0196863897` + - aggregate max abs diff `0.1457520127`, mean abs diff + `0.0122398026` + - `torch.allclose(..., atol=1e-4, rtol=1e-4)` failed. +- Interpretation: + - Since the prompt-list path now reuses the same single-prompt text encoding + calls, the remaining drift is likely from batched Wan denoising/model math + rather than tokenization. + - Keep this as a documented limitation in the final report unless a later + exact-denoising mode is added. + +## Mistakes / Dead Ends +- First GPU parity attempt found the `generate_video_batch()` legacy-compat + gap described in Stage 6. This was a useful pre-parity functional bug, not a + numerical mismatch. +- Second GPU parity attempt found that prompt-list tokenizer calls need padding + for variable-length Wan prompts. +- Third GPU parity attempt completed but exposed non-negligible numerical drift. 
+ The hypothesis at that point was text-encoder sequence/padding drift; the local fix + preserves the single-prompt text-encoding path for merged requests. +- Fourth GPU parity attempt still showed non-negligible drift, so dynamic + batched denoising is not near bit-identical to sequential denoising for this + Wan latent test. + +## Proposed Standardization +If this port lands cleanly, create a runtime batching SOP covering: +- compatibility-signature design, +- deterministic grouped latent generation, +- per-stage grouped execution contracts, +- Modal-only parity and benchmark commands. + +## Hand-Off +Current branch: `multimodal-gen-batching`. + +Current local implementation state: +- Stage 1 code is committed as `3d3157fb`; state commit is `ce758820`. +- Stage 2 code is committed as `fe5572f3`. +- Stage 2 state commit is `39cb1d43`; branch is pushed to origin through + Stage 2. +- Stage 5 code is committed as `4c6a9dc2`. +- Stage 5 state commit is `1bea1dee`. +- GPU helper, parity hardening fixes, and benchmark state are committed as + `9c6eb355`. +- Final report is committed as `2304837e`. +- Branch was pushed to origin through `2304837e`. +- Final changed-file validation passed on Modal app + `ap-1mFqrE5eCwPkEKnffQcQou` from commit + `2304837ed1bc0e1cd733d61f864d6cb1e7682b26`: + `119 passed, 14 warnings`, and pre-commit passed. +- Post-report Wan T2V SSIM validation: + - H100 attempt `ap-KaJr2loSTefvmj8ijYwWOK` generated videos but failed + before comparison because H100 reference folders are missing. + - L40S run `ap-iWP6PA1IyZbXHDKtIE1LQH` passed: + `2 passed, 6 warnings`; mean SSIM `0.9786614696` for `FLASH_ATTN` + and `0.9743387236` for `TORCH_SDPA`. +- `pre-commit run --all-files` was attempted on Modal app + `ap-r20n8jCBwqQnh8I5Us1yTN`, but failed because yapf/ruff rewrote a + large set of pre-existing repository files. This was not committed because + it would introduce unrelated formatting churn. 
+- No code changes remain after the final validation report/state update. + +Important constraints: +- Do not edit unrelated untracked files already present in the worktree. +- User pre-approved tool/command use, commits to current branch, and push to + origin after implementation stages. +- Use Modal L40S jobs for tests; local machine should not be treated as a valid + test environment. + +Next step: +Push `multimodal-gen-batching` to origin if the final docs-only state update is +not already present there. diff --git a/fastvideo/api/compat.py b/fastvideo/api/compat.py index ea3ff10ac..ce558de56 100644 --- a/fastvideo/api/compat.py +++ b/fastvideo/api/compat.py @@ -159,6 +159,16 @@ def legacy_from_pretrained_to_config( preset_refine["guidance_scale"] = value elif key in {"enable_stage_verification", "use_fsdp_inference", "disable_autocast"}: engine[key] = value + elif key == "batching_mode": + engine.setdefault("batching", {})["mode"] = value + elif key == "batching_max_size": + engine.setdefault("batching", {})["max_size"] = value + elif key == "batching_delay_ms": + engine.setdefault("batching", {})["delay_ms"] = value + elif key == "batching_config": + engine.setdefault("batching", {})["config_path"] = value + elif key == "enable_batching_metrics": + engine.setdefault("batching", {})["enable_metrics"] = value elif key == "override_text_encoder_quant": quantization["text_encoder_quant"] = value elif key == "workload_type": @@ -244,6 +254,11 @@ def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, An "enable_stage_verification": engine.enable_stage_verification, "use_fsdp_inference": engine.use_fsdp_inference, "disable_autocast": engine.disable_autocast, + "batching_mode": engine.batching.mode, + "batching_max_size": engine.batching.max_size, + "batching_delay_ms": engine.batching.delay_ms, + "batching_config": engine.batching.config_path, + "enable_batching_metrics": engine.batching.enable_metrics, } if normalized.pipeline.workload_type 
is not None: kwargs["workload_type"] = normalized.pipeline.workload_type diff --git a/fastvideo/api/schema.py b/fastvideo/api/schema.py index c4fba6e17..29a904a27 100644 --- a/fastvideo/api/schema.py +++ b/fastvideo/api/schema.py @@ -71,6 +71,15 @@ class QuantizationConfig: transformer_quant: str | None = None +@dataclass +class BatchingConfig: + mode: Literal["disabled", "dynamic"] = "disabled" + max_size: int = 1 + delay_ms: float = 0.0 + config_path: str | None = None + enable_metrics: bool = False + + @dataclass class EngineConfig: num_gpus: int = 1 @@ -82,6 +91,7 @@ class EngineConfig: use_fsdp_inference: bool = False disable_autocast: bool = False quantization: QuantizationConfig | None = None + batching: BatchingConfig = field(default_factory=BatchingConfig) @dataclass @@ -280,6 +290,7 @@ class ServeConfig: __all__ = [ + "BatchingConfig", "CompileConfig", "ComponentConfig", "ContinuationState", diff --git a/fastvideo/batching/__init__.py b/fastvideo/batching/__init__.py new file mode 100644 index 000000000..1b55840dd --- /dev/null +++ b/fastvideo/batching/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Dynamic generation batching helpers.""" + +from fastvideo.batching.admission import ( + AdmissionLimit, + BatchAdmissionController, + BatchingRule, + load_batching_config, +) +from fastvideo.batching.signature import ( + BatchCompatibility, + can_dynamic_batch, + dynamic_batch_signature, + resolution_key, +) + +__all__ = [ + "AdmissionLimit", + "BatchAdmissionController", + "BatchCompatibility", + "BatchingRule", + "can_dynamic_batch", + "dynamic_batch_signature", + "load_batching_config", + "resolution_key", +] diff --git a/fastvideo/batching/admission.py b/fastvideo/batching/admission.py new file mode 100644 index 000000000..f200ad309 --- /dev/null +++ b/fastvideo/batching/admission.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import json +import os +from dataclasses import 
dataclass +from difflib import get_close_matches +from typing import Any + +from fastvideo.batching.signature import resolution_key +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + +_BYTES_PER_GB = 1024**3 +_BATCHING_RULE_KEYS = frozenset({ + "model", + "model_contains", + "resolution", + "device_memory_gb_min", + "device_memory_gb_max", + "offload", + "max_batch_size", + "max_cost", + "calibration", +}) + + +@dataclass(frozen=True) +class AdmissionLimit: + max_batch_size: int + max_cost: float | None = None + cap_reason: str | None = None + + def reject_reason(self, *, batch_size: int, batch_cost: float) -> str | None: + if batch_size > self.max_batch_size: + return self.cap_reason or f"config_cap:{self.max_batch_size}" + if self.max_cost is not None and batch_cost > self.max_cost: + return f"cost_budget:{batch_cost:.0f}>{self.max_cost:.0f}" + return None + + def stop_reason_for_next_cost(self, next_batch_cost: float) -> str | None: + if self.max_cost is not None and next_batch_cost > self.max_cost: + return f"cost_budget_next:{next_batch_cost:.0f}>{self.max_cost:.0f}" + return None + + +@dataclass(frozen=True) +class BatchingRule: + model: str | None = None + model_contains: str | None = None + resolution: str | None = None + device_memory_gb_min: float | None = None + device_memory_gb_max: float | None = None + offload: bool | None = None + max_batch_size: int = 1 + max_cost: float | None = None + source: str = "user" + + @classmethod + def from_dict(cls, data: dict[str, Any], *, source: str) -> BatchingRule: + if not isinstance(data, dict): + raise ValueError(f"batching config rule from {source} must be an object, got {type(data).__name__}") + _validate_rule_keys(data, source=source) + if "max_batch_size" not in data: + raise ValueError("batching config rule requires max_batch_size") + + rule = cls( + model=_optional_str(data.get("model")), + 
model_contains=_optional_str(data.get("model_contains")), + resolution=_optional_str(data.get("resolution")), + device_memory_gb_min=_optional_float(data.get("device_memory_gb_min")), + device_memory_gb_max=_optional_float(data.get("device_memory_gb_max")), + offload=_optional_bool(data.get("offload")), + max_batch_size=int(data["max_batch_size"]), + max_cost=_optional_float(data.get("max_cost")), + source=source, + ) + rule.validate() + return rule + + def validate(self) -> None: + if self.model is not None and self.model_contains is not None: + raise ValueError("batching config rule cannot set both model and model_contains") + if self.model is None and self.model_contains is None: + raise ValueError("batching config rule requires model or model_contains") + if self.max_batch_size < 1: + raise ValueError("batching config rule max_batch_size must be >= 1") + if self.max_cost is not None and self.max_cost <= 0.0: + raise ValueError("batching config rule max_cost must be > 0") + if (self.device_memory_gb_min is not None and self.device_memory_gb_max is not None + and self.device_memory_gb_min > self.device_memory_gb_max): + raise ValueError("batching config rule device_memory_gb_min must be <= device_memory_gb_max") + + def matches( + self, + *, + model_path: str, + resolution: str | None, + device_memory_gb: float | None, + offload: bool, + ) -> bool: + if self.model is not None and self.model != model_path: + return False + if self.model_contains is not None and self.model_contains not in model_path: + return False + if self.resolution not in (None, "*") and self.resolution != resolution: + return False + if self.offload is not None and self.offload != offload: + return False + if device_memory_gb is None: + return True + if self.device_memory_gb_min is not None and device_memory_gb < self.device_memory_gb_min: + return False + return not (self.device_memory_gb_max is not None and device_memory_gb > self.device_memory_gb_max) + + +class BatchAdmissionController: + 
+ def __init__(self, fastvideo_args: FastVideoArgs, *, gpu_id: int = 0): + self._mode = fastvideo_args.batching_mode + self._user_max_batch_size = max(1, int(fastvideo_args.batching_max_size)) + self._model_path = fastvideo_args.model_path + self._offload = bool(fastvideo_args.dit_cpu_offload or fastvideo_args.dit_layerwise_offload) + self._device_memory_gb = self._get_device_memory_gb(gpu_id) + self._rules = load_batching_config(fastvideo_args.batching_config) + self._pipeline_config = fastvideo_args.pipeline_config + + if self.enabled: + logger.info( + "Batch admission enabled: user_max=%d, device_memory=%.1fGiB, rules=%d", + self._user_max_batch_size, + self._device_memory_gb or 0.0, + len(self._rules), + ) + + @property + def enabled(self) -> bool: + return self._mode == "dynamic" and self._user_max_batch_size > 1 + + def reject_reason_for_candidate(self, current_requests: list[Any], candidate_request: Any) -> str | None: + if not self.enabled: + return None + proposed = current_requests + [candidate_request] + limit = self.limit_for(proposed[0]) + return limit.reject_reason( + batch_size=len(proposed), + batch_cost=self.estimate_batch_cost(proposed), + ) + + def batch_is_full(self, requests: list[Any]) -> bool: + if not self.enabled or not requests: + return len(requests) >= self._user_max_batch_size + + limit = self.limit_for(requests[0]) + if len(requests) >= limit.max_batch_size: + return True + + next_cost = self.estimate_batch_cost(requests + [requests[0]]) + return limit.max_cost is not None and next_cost > limit.max_cost + + def limit_reason_for_batch(self, requests: list[Any]) -> str | None: + if not self.enabled or not requests: + return None + + limit = self.limit_for(requests[0]) + if len(requests) >= limit.max_batch_size: + return limit.cap_reason or f"config_cap:{limit.max_batch_size}" + + next_cost = self.estimate_batch_cost(requests + [requests[0]]) + return limit.stop_reason_for_next_cost(next_cost) + + def max_admissible_batch_size(self, 
request: Any) -> int: + return self.limit_for(request).max_batch_size + + def limit_for(self, request: Any) -> AdmissionLimit: + rules = self._matching_rules(request) + if not rules: + return AdmissionLimit(max_batch_size=self._user_max_batch_size) + + config_cap = min(rule.max_batch_size for rule in rules) + max_batch_size = min(self._user_max_batch_size, config_cap) + cap_reason = f"config_cap:{max_batch_size}" if max_batch_size < self._user_max_batch_size else None + costs = [rule.max_cost for rule in rules if rule.max_cost is not None] + return AdmissionLimit( + max_batch_size=max(1, max_batch_size), + max_cost=min(costs) if costs else None, + cap_reason=cap_reason, + ) + + def estimate_batch_cost(self, requests: list[Any]) -> float: + return sum(float(self._pipeline_config.estimate_request_cost(request)) for request in requests) + + def _matching_rules(self, request: Any) -> list[BatchingRule]: + return [ + rule for rule in self._rules if rule.matches( + model_path=self._model_path, + resolution=resolution_key(request), + device_memory_gb=self._device_memory_gb, + offload=self._offload, + ) + ] + + @staticmethod + def _get_device_memory_gb(gpu_id: int) -> float | None: + try: + from fastvideo.platforms import current_platform + + return current_platform.get_device_total_memory(gpu_id) / _BYTES_PER_GB + except Exception: + return None + + +def load_batching_config(path: str | None) -> list[BatchingRule]: + if path is None: + return [] + + with open(path, encoding="utf-8") as f: + payload = json.load(f) + + source = os.path.abspath(path) + entries = _config_entries(payload) + rules = [BatchingRule.from_dict(entry, source=source) for entry in entries] + if not rules: + raise ValueError(f"batching config {source} does not contain any rules") + return rules + + +def _config_entries(payload: Any) -> list[dict[str, Any]]: + if isinstance(payload, dict) and payload.get("schema_version") not in (None, 1): + raise ValueError("batching config schema_version must be 1") + 
if isinstance(payload, dict) and isinstance(payload.get("rules"), list): + return payload["rules"] + if isinstance(payload, list): + return payload + if isinstance(payload, dict): + entries: list[dict[str, Any]] = [] + for key, value in payload.items(): + if key == "schema_version" or not isinstance(value, dict): + continue + model, _sep, resolution = key.partition("|") + entry = dict(value) + if model: + entry.setdefault("model", model) + if resolution: + entry.setdefault("resolution", resolution) + entries.append(entry) + return entries + raise ValueError("batching config must be a {'schema_version': 1, 'rules': [...]} object, " + "a list of rules, or a mapping keyed by model|resolution") + + +def _validate_rule_keys(data: dict[str, Any], *, source: str) -> None: + unknown = sorted(set(data) - _BATCHING_RULE_KEYS) + if not unknown: + return + + hints = [] + for key in unknown: + matches = get_close_matches(key, _BATCHING_RULE_KEYS, n=1) + if matches: + hints.append(f"{key!r} (did you mean {matches[0]!r}?)") + else: + hints.append(repr(key)) + raise ValueError(f"batching config rule from {source} contains unknown key(s): {', '.join(hints)}") + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def _optional_float(value: Any) -> float | None: + if value is None: + return None + return float(value) + + +def _optional_bool(value: Any) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in ("1", "true", "yes", "y", "on"): + return True + if lowered in ("0", "false", "no", "n", "off"): + return False + raise ValueError(f"cannot parse boolean batching config value: {value!r}") diff --git a/fastvideo/batching/signature.py b/fastvideo/batching/signature.py new file mode 100644 index 000000000..a17e214cc --- /dev/null +++ b/fastvideo/batching/signature.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: 
Apache-2.0 +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from fastvideo.api.sampling_param import SamplingParam + +_SIGNATURE_EXCLUDED_FIELDS = frozenset({ + "prompt", + "prompt_path", + "output_path", + "output_video_name", + "seed", + "save_video", + "return_frames", +}) + +_UNSUPPORTED_DYNAMIC_BATCH_FIELDS = frozenset({ + "image_path", + "pil_image", + "video_path", + "mouse_cond", + "keyboard_cond", + "grid_sizes", + "pose", + "camera_states", + "camera_trajectory", + "action_list", + "action_speed_list", + "gt_latents", + "conditioning_mask", + "c2ws_plucker_emb", + "refine_from", + "stage1_video", + "trajectory_type", + "movement_distance", + "camera_rotation", + "ltx2_images", + "ltx2_conditioning_latent_stage1", + "ltx2_conditioning_latent_stage2", + "ltx2_video_conditions", + "init_audio", + "inpaint_audio", + "inpaint_mask", + "continuation_state", +}) + +_UNSUPPORTED_EXTRA_KEYS = frozenset({ + "ltx2_audio_latents", + "ltx2_audio_clean_latent", + "ltx2_audio_denoise_mask", + "audio_num_frames", + "video_position_offset_sec", +}) + + +@dataclass(frozen=True) +class BatchCompatibility: + can_batch: bool + reason: str | None = None + + +def resolution_key(request: Any) -> str: + height = _first_scalar(getattr(request, "height", None)) + width = _first_scalar(getattr(request, "width", None)) + num_frames = _first_scalar(getattr(request, "num_frames", None)) + return f"{height}x{width}x{num_frames}" + + +def dynamic_batch_signature( + request: SamplingParam, + *, + extra: dict[str, Any] | None = None, +) -> tuple[tuple[str, Any], ...]: + """Build a hashable compatibility signature for a generation request.""" + signature_items: list[tuple[str, Any]] = [] + for field in dataclasses.fields(request): + if field.name in _SIGNATURE_EXCLUDED_FIELDS: + continue + signature_items.append((field.name, _freeze_signature_value(getattr(request, field.name, None)))) + if extra: 
+ signature_items.append(("extra", _freeze_signature_value(extra))) + return tuple(signature_items) + + +def can_dynamic_batch( + base: SamplingParam, + candidate: SamplingParam, + *, + base_extra: dict[str, Any] | None = None, + candidate_extra: dict[str, Any] | None = None, +) -> BatchCompatibility: + """Return whether two FastVideo generation requests can be merged.""" + base_ready = _request_is_batchable(base, extra=base_extra) + if not base_ready.can_batch: + return base_ready + candidate_ready = _request_is_batchable(candidate, extra=candidate_extra) + if not candidate_ready.can_batch: + return candidate_ready + + base_sig = dynamic_batch_signature(base, extra=base_extra) + candidate_sig = dynamic_batch_signature(candidate, extra=candidate_extra) + if base_sig == candidate_sig: + return BatchCompatibility(can_batch=True) + + mismatch = _first_mismatch(base_sig, candidate_sig) + return BatchCompatibility(can_batch=False, reason=mismatch or "signature_mismatch") + + +def _request_is_batchable( + request: SamplingParam, + *, + extra: dict[str, Any] | None = None, +) -> BatchCompatibility: + if not isinstance(request.prompt, str): + return BatchCompatibility(can_batch=False, reason="prompt_type") + if request.prompt_path is not None: + return BatchCompatibility(can_batch=False, reason="prompt_path") + if request.num_videos_per_prompt != 1: + return BatchCompatibility(can_batch=False, reason="num_videos_per_prompt") + if request.return_continuation_state: + return BatchCompatibility(can_batch=False, reason="return_continuation_state") + + for name in _UNSUPPORTED_DYNAMIC_BATCH_FIELDS: + value = getattr(request, name, None) + if _is_present(value): + return BatchCompatibility(can_batch=False, reason=name) + + if extra: + unsupported = sorted(set(extra) & _UNSUPPORTED_EXTRA_KEYS) + if unsupported: + return BatchCompatibility(can_batch=False, reason=f"extra.{unsupported[0]}") + + return BatchCompatibility(can_batch=True) + + +def _freeze_signature_value(value: Any) 
-> Any: + if isinstance(value, str | int | float | bool | type(None)): + return value + if isinstance(value, Enum): + return value.value + if isinstance(value, dict): + return tuple( + (str(key), _freeze_signature_value(item)) for key, item in sorted(value.items(), key=lambda kv: str(kv[0]))) + if isinstance(value, list | tuple): + return tuple(_freeze_signature_value(item) for item in value) + return repr(value) + + +def _is_present(value: Any) -> bool: + if value is None: + return False + if value is False: + return False + return not (isinstance(value, list | tuple | dict | set) and not value) + + +def _first_scalar(value: Any) -> Any: + if isinstance(value, list | tuple): + return value[0] if value else None + return value + + +def _first_mismatch( + base_sig: tuple[tuple[str, Any], ...], + candidate_sig: tuple[tuple[str, Any], ...], +) -> str | None: + if len(base_sig) != len(candidate_sig): + return "sampling_params" + for (name, base_value), (candidate_name, candidate_value) in zip(base_sig, candidate_sig, strict=True): + if name != candidate_name: + return "sampling_params" + if base_value != candidate_value: + return f"sampling_params.{name}" + return None diff --git a/fastvideo/configs/pipelines/base.py b/fastvideo/configs/pipelines/base.py index 80ddc04b3..f583accc6 100644 --- a/fastvideo/configs/pipelines/base.py +++ b/fastvideo/configs/pipelines/base.py @@ -255,6 +255,27 @@ def check_pipeline_config(self) -> None: f"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})" ) + def estimate_request_cost(self, request: Any) -> float: + """Estimate relative memory/compute cost for batching admission. + + The default is intentionally simple and model-agnostic: pixel count + times frame count. Pipeline subclasses can override this when they have + calibrated costs. 
+ """ + height = getattr(request, "height", None) + width = getattr(request, "width", None) + num_frames = getattr(request, "num_frames", None) + if isinstance(height, list): + height = height[0] if height else None + if isinstance(width, list): + width = width[0] if width else None + if isinstance(num_frames, list): + num_frames = num_frames[0] if num_frames else None + height = int(height or 1) + width = int(width or 1) + num_frames = int(num_frames or 1) + return float(max(1, height) * max(1, width) * max(1, num_frames)) + def dump_to_json(self, file_path: str): output_dict = shallow_asdict(self) del_keys = [] diff --git a/fastvideo/entrypoints/openai/api_server.py b/fastvideo/entrypoints/openai/api_server.py index 2b64087af..27faf67d4 100644 --- a/fastvideo/entrypoints/openai/api_server.py +++ b/fastvideo/entrypoints/openai/api_server.py @@ -10,6 +10,7 @@ from fastvideo.api.presets import validate_preset_selection from fastvideo.api.schema import GenerationRequest +from fastvideo.entrypoints.openai.batching import VideoBatchScheduler from fastvideo.entrypoints.openai.state import ( DEFAULT_OUTPUT_DIR, clear_state, @@ -59,11 +60,29 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: generator = VideoGenerator.from_fastvideo_args(args) logger.info("Model loaded successfully.") - set_state(generator, args, output_dir, default_request=default_request) + video_batch_scheduler: VideoBatchScheduler | None = None + if args.batching_mode == "dynamic" and args.batching_max_size > 1: + video_batch_scheduler = VideoBatchScheduler(generator, args) + await video_batch_scheduler.start() + logger.info( + "Started dynamic video batch scheduler: max_size=%d delay_ms=%.2f", + args.batching_max_size, + args.batching_delay_ms, + ) + + set_state( + generator, + args, + output_dir, + default_request=default_request, + video_batch_scheduler=video_batch_scheduler, + ) yield # server is running logger.info("Shutting down — releasing model resources ...") + if 
video_batch_scheduler is not None: + await video_batch_scheduler.stop() generator.shutdown() clear_state() logger.info("Shutdown complete.") diff --git a/fastvideo/entrypoints/openai/batching.py b/fastvideo/entrypoints/openai/batching.py new file mode 100644 index 000000000..68597b865 --- /dev/null +++ b/fastvideo/entrypoints/openai/batching.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import asyncio +import time +from collections import deque +from dataclasses import dataclass +from typing import Any + +from fastvideo.api.sampling_param import SamplingParam +from fastvideo.batching.signature import can_dynamic_batch +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class _VideoBatchJob: + request_id: str + kwargs: dict[str, Any] + future: asyncio.Future + enqueue_time: float + + +class VideoBatchScheduler: + """Async FIFO scheduler for OpenAI-compatible video generation.""" + + def __init__(self, generator: Any, fastvideo_args: FastVideoArgs) -> None: + self._generator = generator + self._fastvideo_args = fastvideo_args + self._queue: asyncio.Queue[_VideoBatchJob | None] = asyncio.Queue() + self._pending: deque[_VideoBatchJob] = deque() + self._task: asyncio.Task | None = None + self._stopped = False + + @property + def enabled(self) -> bool: + return self._fastvideo_args.batching_mode == "dynamic" and self._fastvideo_args.batching_max_size > 1 + + async def start(self) -> None: + if self._task is not None: + return + self._task = asyncio.create_task(self._run(), name="fastvideo-video-batch-scheduler") + + async def stop(self) -> None: + self._stopped = True + await self._queue.put(None) + if self._task is not None: + await self._task + self._task = None + + async def submit(self, request_id: str, kwargs: dict[str, Any]) -> Any: + loop = asyncio.get_running_loop() + future = loop.create_future() + await 
self._queue.put( + _VideoBatchJob( + request_id=request_id, + kwargs=dict(kwargs), + future=future, + enqueue_time=time.perf_counter(), + )) + return await future + + async def _run(self) -> None: + while not self._stopped: + job = await self._get_next_job() + if job is None: + break + batch = await self._collect_batch(job) + await self._dispatch(batch) + + while self._pending: + pending = self._pending.popleft() + if not pending.future.done(): + pending.future.set_exception(RuntimeError("Video batch scheduler stopped before dispatch")) + + async def _get_next_job(self) -> _VideoBatchJob | None: + if self._pending: + return self._pending.popleft() + return await self._queue.get() + + async def _collect_batch(self, first: _VideoBatchJob) -> list[_VideoBatchJob]: + batch = [first] + max_size = self._fastvideo_args.batching_max_size + delay_s = max(0.0, self._fastvideo_args.batching_delay_ms / 1000.0) + deadline = first.enqueue_time + delay_s + + while len(batch) < max_size: + timeout = deadline - time.perf_counter() + if timeout <= 0 and delay_s > 0: + break + try: + candidate = await asyncio.wait_for(self._get_next_job(), timeout=max(0.0, timeout)) + except TimeoutError: + break + if candidate is None: + await self._queue.put(None) + break + if self._jobs_are_compatible(batch[0], candidate): + batch.append(candidate) + continue + self._pending.append(candidate) + break + return batch + + async def _dispatch(self, batch: list[_VideoBatchJob]) -> None: + loop = asyncio.get_running_loop() + request_ids = [job.request_id for job in batch] + queue_wait_ms = (time.perf_counter() - min(job.enqueue_time for job in batch)) * 1000.0 + if self._fastvideo_args.enable_batching_metrics: + logger.info( + "Dispatching video batch: request_ids=%s size=%d queue_wait_ms=%.2f", + request_ids, + len(batch), + queue_wait_ms, + ) + + try: + results = await loop.run_in_executor( + None, + lambda: self._generator.generate_video_batch([job.kwargs for job in batch]), + ) + except Exception as 
exc: + for job in batch: + if not job.future.done(): + job.future.set_exception(exc) + return + + if len(results) != len(batch): + error = RuntimeError(f"Video batch returned {len(results)} results for {len(batch)} requests") + for job in batch: + if not job.future.done(): + job.future.set_exception(error) + return + + for job, result in zip(batch, results, strict=True): + if not job.future.done(): + job.future.set_result(result) + + def _jobs_are_compatible(self, base: _VideoBatchJob, candidate: _VideoBatchJob) -> bool: + try: + base_sampling, base_extra = self._sampling_param_from_kwargs(base.kwargs) + candidate_sampling, candidate_extra = self._sampling_param_from_kwargs(candidate.kwargs) + except Exception: + return False + return can_dynamic_batch( + base_sampling, + candidate_sampling, + base_extra=base_extra, + candidate_extra=candidate_extra, + ).can_batch + + def _sampling_param_from_kwargs(self, kwargs: dict[str, Any]) -> tuple[SamplingParam, dict[str, Any]]: + sampling_param = SamplingParam.from_pretrained(self._fastvideo_args.model_path) + updates = dict(kwargs) + prompt = updates.pop("prompt", None) + extra: dict[str, Any] = {} + for key in ( + "ltx2_audio_latents", + "ltx2_audio_clean_latent", + "ltx2_audio_denoise_mask", + "audio_num_frames", + "video_position_offset_sec", + ): + if key in updates: + extra[key] = updates.pop(key) + sampling_param.update(updates) + sampling_param.prompt = prompt + return sampling_param, extra diff --git a/fastvideo/entrypoints/openai/state.py b/fastvideo/entrypoints/openai/state.py index c95f84400..5630dcc37 100644 --- a/fastvideo/entrypoints/openai/state.py +++ b/fastvideo/entrypoints/openai/state.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from fastvideo.api.schema import GenerationRequest + from fastvideo.entrypoints.openai.batching import VideoBatchScheduler from fastvideo.entrypoints.video_generator import VideoGenerator from fastvideo.fastvideo_args import FastVideoArgs @@ -20,6 +21,7 @@ _fastvideo_args: 
FastVideoArgs | None = None _output_dir: str = DEFAULT_OUTPUT_DIR _default_request: GenerationRequest | None = None +_video_batch_scheduler: VideoBatchScheduler | None = None def get_generator() -> VideoGenerator: @@ -44,23 +46,31 @@ def get_default_request() -> GenerationRequest | None: return _default_request +def get_video_batch_scheduler() -> VideoBatchScheduler | None: + """Return the video batch scheduler when dynamic batching is enabled.""" + return _video_batch_scheduler + + def set_state( generator: VideoGenerator, fastvideo_args: FastVideoArgs, output_dir: str, default_request: GenerationRequest | None = None, + video_batch_scheduler: VideoBatchScheduler | None = None, ) -> None: """Set all server state at once (called from lifespan).""" - global _generator, _fastvideo_args, _output_dir, _default_request + global _generator, _fastvideo_args, _output_dir, _default_request, _video_batch_scheduler _generator = generator _fastvideo_args = fastvideo_args _output_dir = output_dir _default_request = default_request + _video_batch_scheduler = video_batch_scheduler def clear_state() -> None: """Clear server state on shutdown.""" - global _generator, _fastvideo_args, _default_request + global _generator, _fastvideo_args, _default_request, _video_batch_scheduler _generator = None _fastvideo_args = None _default_request = None + _video_batch_scheduler = None diff --git a/fastvideo/entrypoints/openai/video_api.py b/fastvideo/entrypoints/openai/video_api.py index 4134175f3..b02a748d2 100644 --- a/fastvideo/entrypoints/openai/video_api.py +++ b/fastvideo/entrypoints/openai/video_api.py @@ -26,6 +26,7 @@ get_generator, get_output_dir, get_server_args, + get_video_batch_scheduler, ) from fastvideo.entrypoints.openai.protocol import ( VideoGenerationsRequest, @@ -151,15 +152,19 @@ async def _run_generation(request_id: str, kwargs: dict[str, Any]) -> None: is synchronous) and update the store on completion or failure. 
""" generator = get_generator() + scheduler = get_video_batch_scheduler() loop = asyncio.get_running_loop() try: start = time.perf_counter() - result = await loop.run_in_executor( - None, - lambda: generator.generate_video(**kwargs), - ) + if scheduler is not None and scheduler.enabled: + result = await scheduler.submit(request_id, kwargs) + else: + result = await loop.run_in_executor( + None, + lambda: generator.generate_video(**kwargs), + ) elapsed = time.perf_counter() - start update: dict[str, Any] = { diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index d136766e7..12893907f 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -18,6 +18,7 @@ from collections.abc import Mapping from contextlib import suppress from copy import deepcopy +from dataclasses import dataclass from typing import Any import imageio @@ -26,6 +27,8 @@ import torchvision from einops import rearrange +from fastvideo.batching.admission import BatchAdmissionController +from fastvideo.batching.signature import can_dynamic_batch from fastvideo.api.compat import ( expand_request_prompt_batch, generator_config_to_fastvideo_args, @@ -94,6 +97,11 @@ "pin_cpu_memory", "enable_torch_compile", "torch_compile_kwargs", + "batching_mode", + "batching_max_size", + "batching_delay_ms", + "batching_config", + "enable_batching_metrics", "output_type", "nvfp4_fa4", }) @@ -112,6 +120,17 @@ def _infer_latent_batch_size(batch: ForwardBatch) -> int: return latent_batch_size +@dataclass +class _GenerationWorkItem: + prompt: str + sampling_param: SamplingParam + fastvideo_args: FastVideoArgs + batch: ForwardBatch + output_path: str + target_height: int + target_width: int + + class VideoGenerator: """ A unified class for generating videos using diffusion models. 
@@ -439,6 +458,68 @@ def generate_video( if log_queue: self.executor.clear_log_queue() + def generate_video_batch(self, request_kwargs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Generate multiple legacy video requests, batching compatible items.""" + work_items: list[_GenerationWorkItem] = [] + fastvideo_args_by_pipeline_override: dict[tuple[tuple[str, str], ...], FastVideoArgs] = { + (): self.fastvideo_args + } + for raw_kwargs in request_kwargs: + kwargs = dict(raw_kwargs) + prompt = kwargs.pop("prompt", None) + if prompt is None: + raise ValueError("Each batched generation request must include prompt") + if not isinstance(prompt, str): + raise TypeError(f"`prompt` must be a string, but got {type(prompt)}") + + sampling_param = kwargs.pop("sampling_param", None) + if sampling_param is None: + sampling_param = SamplingParam.from_pretrained(self.fastvideo_args.model_path) + else: + sampling_param = deepcopy(sampling_param) + + extra_overrides: dict[str, Any] = {} + for _ek in _BATCH_EXTRA_PASSTHROUGH_KEYS: + if _ek in kwargs: + extra_overrides[_ek] = kwargs.pop(_ek) + + request = legacy_generate_call_to_request( + prompt, + sampling_param, + legacy_kwargs=kwargs, + ) + if not isinstance(request.prompt, str): + raise TypeError(f"`prompt` must be a string, but got {type(request.prompt)}") + + fastvideo_args = self.fastvideo_args + pipeline_overrides = request_to_pipeline_overrides(request) + if pipeline_overrides: + override_key = tuple((key, repr(value)) for key, value in sorted(pipeline_overrides.items())) + fastvideo_args = fastvideo_args_by_pipeline_override.get(override_key) + if fastvideo_args is None: + fastvideo_args = deepcopy(self.fastvideo_args) + for key, value in pipeline_overrides.items(): + if not hasattr(fastvideo_args.pipeline_config, key): + raise ValueError(f"Request field {key!r} is not supported by pipeline config overrides") + setattr(fastvideo_args.pipeline_config, key, deepcopy(value)) + 
fastvideo_args_by_pipeline_override[override_key] = fastvideo_args + + resolved_sampling_param = request_to_sampling_param( + request, + model_path=self.fastvideo_args.model_path, + ) + output_path = self._prepare_output_path(resolved_sampling_param.output_path, request.prompt) + work_items.append( + self._prepare_generation_work_item( + prompt=request.prompt, + sampling_param=resolved_sampling_param, + fastvideo_args=fastvideo_args, + output_path=output_path, + _extra_overrides=extra_overrides, + )) + + return self._generate_prepared_work_items(work_items) + def _generate_request_impl( self, request: GenerationRequest, @@ -534,6 +615,26 @@ def _generate_video_impl( logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path) + if self._dynamic_batching_enabled(fastvideo_args): + work_items: list[_GenerationWorkItem] = [] + for batch_prompt in prompts: + item_kwargs = dict(kwargs) + item_kwargs["output_path"] = self._prepare_output_path(sampling_param.output_path, batch_prompt) + work_items.append( + self._prepare_generation_work_item( + prompt=batch_prompt, + sampling_param=sampling_param, + fastvideo_args=fastvideo_args, + **item_kwargs, + )) + + results = self._generate_prepared_work_items(work_items) + for i, (result, batch_prompt) in enumerate(zip(results, prompts, strict=True)): + result["prompt_index"] = i + result["prompt"] = batch_prompt + logger.info("Completed batch processing. 
Generated %d videos successfully.", len(results)) + return results + results = [] for i, batch_prompt in enumerate(prompts): logger.info("Processing prompt %d/%d: %s...", i + 1, len(prompts), batch_prompt[:100]) @@ -652,6 +753,379 @@ def _sanitize_filename_component(name: str) -> str: counter += 1 return new_output_path + def _dynamic_batching_enabled(self, fastvideo_args: FastVideoArgs) -> bool: + batching_mode = getattr(fastvideo_args, "batching_mode", "disabled") + batching_max_size = getattr(fastvideo_args, "batching_max_size", 1) + return batching_mode == "dynamic" and batching_max_size > 1 + + def _prepare_generation_work_item( + self, + prompt: str | list[str], + sampling_param: SamplingParam, + fastvideo_args: FastVideoArgs, + **kwargs, + ) -> _GenerationWorkItem: + if isinstance(prompt, str): + prompt_for_output = prompt.strip() + prompt_value: str | list[str] = prompt_for_output + elif isinstance(prompt, list) and all(isinstance(item, str) for item in prompt): + prompt_value = [item.strip() for item in prompt] + prompt_for_output = prompt_value[0] if prompt_value else "" + else: + raise TypeError(f"`prompt` must be a string or list of strings, but got {type(prompt)}") + + sampling_param = deepcopy(sampling_param) + output_path = kwargs["output_path"] + sampling_param.prompt = prompt_value + if sampling_param.negative_prompt is not None: + sampling_param.negative_prompt = sampling_param.negative_prompt.strip() + + if sampling_param.height <= 0 or sampling_param.width <= 0 or sampling_param.num_frames <= 0: + raise ValueError(f"Height, width, and num_frames must be positive integers, got " + f"height={sampling_param.height}, width={sampling_param.width}, " + f"num_frames={sampling_param.num_frames}") + + target_height = align_to(sampling_param.height, 16) + target_width = align_to(sampling_param.width, 16) + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8] + n_tokens = latents_size[0] * 
latents_size[1] * latents_size[2] + + debug_str = f""" + height: {target_height} + width: {target_width} + video_length: {sampling_param.num_frames} + prompt: {sampling_param.prompt} + image_path: {sampling_param.image_path} + neg_prompt: {sampling_param.negative_prompt} + seed: {sampling_param.seed} + infer_steps: {sampling_param.num_inference_steps} + num_videos_per_prompt: {sampling_param.num_videos_per_prompt} + guidance_scale: {sampling_param.guidance_scale} + n_tokens: {n_tokens} + flow_shift: {fastvideo_args.pipeline_config.flow_shift} + embedded_guidance_scale: {fastvideo_args.pipeline_config.embedded_cfg_scale} + save_video: {sampling_param.save_video} + output_path: {output_path} + """ # type: ignore[attr-defined] + logger.info(debug_str) + + batch = ForwardBatch( + **shallow_asdict(sampling_param), + eta=0.0, + n_tokens=n_tokens, + VSA_sparsity=fastvideo_args.VSA_sparsity, + ) + + extra_overrides = kwargs.get("_extra_overrides", {}) + for _ek, _ev in extra_overrides.items(): + batch.extra[_ek] = _ev + + return _GenerationWorkItem( + prompt=prompt_for_output, + sampling_param=sampling_param, + fastvideo_args=fastvideo_args, + batch=batch, + output_path=output_path, + target_height=target_height, + target_width=target_width, + ) + + def _run_forward_batch( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> tuple[ForwardBatch, float, float]: + start_time = time.perf_counter() + result_container = {"output_batch": ForwardBatch(data_type=batch.data_type)} + thread_error: dict[str, BaseException | None] = {"error": None} + thread_error_traceback: dict[str, str] = {"traceback": ""} + + def execute_forward_thread(): + import traceback + try: + result_container["output_batch"] = self.executor.execute_forward(batch, fastvideo_args) + except BaseException as error: # noqa: BLE001 + thread_error["error"] = error + thread_error_traceback["traceback"] = traceback.format_exc() + + thread = threading.Thread(target=execute_forward_thread) + 
thread.start() + thread.join() + + if thread_error["error"] is not None: + raise RuntimeError("Forward execution thread failed.\n" + f"{thread_error_traceback['traceback']}") from thread_error["error"] + + output_batch = result_container["output_batch"] + if output_batch.output is None: + raise RuntimeError("Forward execution returned no output tensor. " + "This usually means the executor/pipeline failed earlier.") + + gen_time = time.perf_counter() - start_time + logger.info("Generated successfully in %.2f seconds", gen_time) + return output_batch, gen_time, start_time + + def _samples_from_output( + self, + work_item: _GenerationWorkItem, + output_batch: ForwardBatch, + ) -> torch.Tensor: + output = output_batch.output + if output is None: + raise RuntimeError("Forward execution returned no output tensor.") + fastvideo_args = work_item.fastvideo_args + sampling_param = work_item.sampling_param + latent_batch_size = _infer_latent_batch_size(work_item.batch) + skip_pixel_prealloc = fastvideo_args.output_type == "latent" + expected_shape = ( + latent_batch_size, + 3, + sampling_param.num_frames, + sampling_param.height, + sampling_param.width, + ) + if skip_pixel_prealloc: + return output.cpu() + samples = torch.empty(expected_shape, device="cpu", pin_memory=fastvideo_args.pin_cpu_memory) + if output.shape == samples.shape: + samples.copy_(output) + return samples + logger.warning("Output shape %s does not match expected shape %s; use slow path", output.shape, samples.shape) + return output.cpu() + + def _postprocess_generation_output( + self, + work_item: _GenerationWorkItem, + output_batch: ForwardBatch, + gen_time: float, + start_time: float, + ) -> dict[str, Any]: + batch = work_item.batch + fastvideo_args = work_item.fastvideo_args + output_path = work_item.output_path + samples = self._samples_from_output(work_item, output_batch) + logging_info = output_batch.logging_info + + is_latent_output = fastvideo_args.output_type == "latent" + audio_only = 
bool(output_batch.extra.get("audio_only")) + + postprocess_start = time.perf_counter() + frames: list[np.ndarray] | None + if is_latent_output or audio_only: + frames = None if is_latent_output else [] + else: + videos = rearrange(samples, "b c t h w -> t b c h w") + frames = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.permute(1, 2, 0).squeeze(-1) + x = (x * 255).to(torch.uint8) + frames.append(x.contiguous().cpu().numpy()) + postprocess_time = time.perf_counter() - postprocess_start + logger.info("PostDecodeFrameProcessStage completed in %.3f s", postprocess_time) + if logging_info is not None: + logging_info.add_stage_execution_time("PostDecodeFrameProcessStage", postprocess_time) + + save_to_disk = batch.save_video and not is_latent_output + save_video_time = 0.0 + audio_mux_time = 0.0 + if save_to_disk: + if audio_only: + output_path = self._rewrite_extension(output_path, ".wav") + save_start = time.perf_counter() + self._write_pcm_wav( + output_path, + output_batch.extra["audio"], + int(output_batch.extra["audio_sample_rate"]), + ) + save_video_time = time.perf_counter() - save_start + logger.info("Saved audio to %s", output_path) + elif self._is_image_workload(): + assert frames is not None + save_start = time.perf_counter() + imageio.imwrite(output_path, frames[0]) + save_video_time = time.perf_counter() - save_start + logger.info("Saved image to %s", output_path) + else: + assert frames is not None + audio = output_batch.extra.get("audio") + audio_sample_rate = output_batch.extra.get("audio_sample_rate") + if audio is not None and audio_sample_rate is not None: + save_start = time.perf_counter() + save_ok = self._save_video_with_audio_ffmpeg_pipe( + output_path=output_path, + frames=frames, + fps=batch.fps, + audio=audio, + sample_rate=int(audio_sample_rate), + ) + if not save_ok: + logger.warning("ffmpeg pipe save failed; trying PyAV single-pass save.") + save_ok = self._save_video_with_audio_single_pass( + 
output_path=output_path, + frames=frames, + fps=batch.fps, + audio=audio, + sample_rate=int(audio_sample_rate), + ) + save_video_time = time.perf_counter() - save_start + if save_ok: + audio_mux_time = 0.0 + else: + logger.warning("Single-pass save failed; falling back to two-step save/mux.") + save_start = time.perf_counter() + imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4") + save_video_time = time.perf_counter() - save_start + mux_start = time.perf_counter() + mux_ok = self._mux_audio(output_path, audio, int(audio_sample_rate)) + audio_mux_time = time.perf_counter() - mux_start + if not mux_ok: + logger.warning("Audio mux failed; saved video without audio.") + else: + save_start = time.perf_counter() + imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4") + save_video_time = time.perf_counter() - save_start + audio_mux_time = 0.0 + logger.info("Saved video to %s", output_path) + + logger.info("VideoSaveStage completed in %.3f s", save_video_time) + if logging_info is not None: + logging_info.add_stage_execution_time("VideoSaveStage", save_video_time) + logger.info("AudioMuxStage completed in %.3f s", audio_mux_time) + if logging_info is not None: + logging_info.add_stage_execution_time("AudioMuxStage", audio_mux_time) + + e2e_time = time.perf_counter() - start_time + logger.info("End-to-end latency: %.2f seconds", e2e_time) + + return { + "prompts": work_item.prompt, + "samples": samples if batch.return_frames else None, + "frames": frames if batch.return_frames else None, + "audio": output_batch.extra.get("audio"), + "audio_sample_rate": output_batch.extra.get("audio_sample_rate"), + "ltx2_audio_latents": output_batch.extra.get("ltx2_audio_latents"), + "size": (work_item.target_height, work_item.target_width, batch.num_frames), + "generation_time": gen_time, + "e2e_latency": e2e_time, + "logging_info": logging_info, + "trajectory": output_batch.trajectory_latents, + "trajectory_timesteps": output_batch.trajectory_timesteps, + 
"trajectory_decoded": output_batch.trajectory_decoded, + "video_path": output_path if save_to_disk else None, + "peak_memory_mb": output_batch.extra.get("peak_memory_mb"), + } + + def _split_output_batch( + self, + output_batch: ForwardBatch, + *, + index: int, + batch_size: int, + ) -> ForwardBatch: + extra = {} + for key, value in (output_batch.extra or {}).items(): + if torch.is_tensor(value) and value.ndim > 0 and value.shape[0] == batch_size: + extra[key] = value[index:index + 1] + elif isinstance(value, list) and len(value) == batch_size: + extra[key] = value[index] + else: + extra[key] = value + + result = ForwardBatch( + data_type=output_batch.data_type, + output=(output_batch.output[index:index + 1] if output_batch.output is not None else None), + logging_info=output_batch.logging_info, + extra=extra, + ) + if output_batch.trajectory_latents is not None: + result.trajectory_latents = output_batch.trajectory_latents[index:index + 1] + result.trajectory_timesteps = output_batch.trajectory_timesteps + if output_batch.trajectory_decoded is not None: + result.trajectory_decoded = [ + decoded[index:index + 1] if torch.is_tensor(decoded) and decoded.shape[0] == batch_size else decoded + for decoded in output_batch.trajectory_decoded + ] + return result + + def _merge_work_items(self, work_items: list[_GenerationWorkItem]) -> _GenerationWorkItem: + first = work_items[0] + sampling_param = deepcopy(first.sampling_param) + prompts = [item.prompt for item in work_items] + sampling_param.prompt = prompts + sampling_param.seed = work_items[0].sampling_param.seed + + merged = self._prepare_generation_work_item( + prompts, + sampling_param, + first.fastvideo_args, + output_path=first.output_path, + _extra_overrides=first.batch.extra, + ) + merged.batch.seeds = [int(item.sampling_param.seed) for item in work_items] + merged.batch.extra["dynamic_batch_size"] = len(work_items) + merged.batch.extra["dynamic_batch_output_paths"] = [item.output_path for item in work_items] + 
return merged + + def _can_merge_work_items( + self, + base: _GenerationWorkItem, + candidate: _GenerationWorkItem, + admission: BatchAdmissionController, + current_group: list[_GenerationWorkItem], + ) -> bool: + if candidate.fastvideo_args is not base.fastvideo_args: + return False + compatibility = can_dynamic_batch( + base.sampling_param, + candidate.sampling_param, + base_extra=base.batch.extra, + candidate_extra=candidate.batch.extra, + ) + if not compatibility.can_batch: + return False + current_requests = [item.sampling_param for item in current_group] + return admission.reject_reason_for_candidate(current_requests, candidate.sampling_param) is None + + def _generate_prepared_work_items( + self, + work_items: list[_GenerationWorkItem], + ) -> list[dict[str, Any]]: + if not work_items: + return [] + fastvideo_args = work_items[0].fastvideo_args + if not self._dynamic_batching_enabled(fastvideo_args): + return [self._execute_single_work_item(item) for item in work_items] + + admission = BatchAdmissionController(fastvideo_args) + results: list[dict[str, Any]] = [] + index = 0 + while index < len(work_items): + group = [work_items[index]] + index += 1 + while index < len(work_items) and len(group) < fastvideo_args.batching_max_size: + candidate = work_items[index] + if not self._can_merge_work_items(group[0], candidate, admission, group): + break + group.append(candidate) + index += 1 + + if len(group) == 1: + results.append(self._execute_single_work_item(group[0])) + continue + + merged = self._merge_work_items(group) + output_batch, gen_time, start_time = self._run_forward_batch(merged.batch, merged.fastvideo_args) + batch_size = len(group) + for item_index, item in enumerate(group): + split_batch = self._split_output_batch(output_batch, index=item_index, batch_size=batch_size) + results.append(self._postprocess_generation_output(item, split_batch, gen_time, start_time)) + return results + + def _execute_single_work_item(self, work_item: _GenerationWorkItem) 
-> dict[str, Any]: + output_batch, gen_time, start_time = self._run_forward_batch(work_item.batch, work_item.fastvideo_args) + return self._postprocess_generation_output(work_item, output_batch, gen_time, start_time) + def _generate_single_video( self, prompt: str, diff --git a/fastvideo/fastvideo_args.py b/fastvideo/fastvideo_args.py index cb1391a84..666b75e8c 100644 --- a/fastvideo/fastvideo_args.py +++ b/fastvideo/fastvideo_args.py @@ -169,6 +169,14 @@ class FastVideoArgs: # Prompt text file for batch processing prompt_txt: str | None = None + # Dynamic multimodal generation batching. Defaults preserve the historical + # one-request-at-a-time execution path. + batching_mode: str = "disabled" + batching_max_size: int = 1 + batching_delay_ms: float = 0.0 + batching_config: str | None = None + enable_batching_metrics: bool = False + # LTX-2 VAE tiling overrides ltx2_vae_tiling: bool | None = None ltx2_vae_spatial_tile_size_in_pixels: int | None = None @@ -440,6 +448,37 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=FastVideoArgs.prompt_txt, help="Path to a text file containing prompts (one per line) for batch processing", ) + parser.add_argument( + "--batching-mode", + type=str, + choices=["disabled", "dynamic"], + default=FastVideoArgs.batching_mode, + help="Request batching mode for inference serving.", + ) + parser.add_argument( + "--batching-max-size", + type=int, + default=FastVideoArgs.batching_max_size, + help="Maximum number of compatible generation requests to execute as one batch.", + ) + parser.add_argument( + "--batching-delay-ms", + type=float, + default=FastVideoArgs.batching_delay_ms, + help="Maximum queue delay in milliseconds before dispatching a dynamic batch.", + ) + parser.add_argument( + "--batching-config", + type=str, + default=FastVideoArgs.batching_config, + help="Optional JSON batching admission rule file.", + ) + parser.add_argument( + "--enable-batching-metrics", + action=StoreBoolean, + 
default=FastVideoArgs.enable_batching_metrics, + help="Log dynamic batching utilization and rejection metrics.", + ) # LTX-2 VAE tiling overrides parser.add_argument( @@ -747,6 +786,13 @@ def check_fastvideo_args(self) -> None: WorkloadType), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}" assert self.workload_type in WorkloadType.choices(), f"Invalid workload type: {self.workload_type}" + if self.batching_mode not in {"disabled", "dynamic"}: + raise ValueError(f"batching_mode must be 'disabled' or 'dynamic', got {self.batching_mode!r}") + if self.batching_max_size < 1: + raise ValueError("batching_max_size must be >= 1") + if self.batching_delay_ms < 0: + raise ValueError("batching_delay_ms must be >= 0") + if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING] and self.inference_mode: logger.warning("Mode is 'training' but inference_mode is True. Setting inference_mode to False.") self.inference_mode = False diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index c112f48ca..8ddf92a1d 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -175,7 +175,6 @@ def forward( boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None latent_model_input = latents.to(target_dtype) - assert latent_model_input.shape[0] == 1, "only support batch size 1" if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None: # TI2V directly replaces the first frame of the latent with diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index 0490954e3..885e1bf30 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -35,7 +35,11 @@ def _generate_seeds(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs): num_videos_per_prompt = batch.num_videos_per_prompt assert seed 
is not None - seeds = [seed + i for i in range(num_videos_per_prompt)] + if batch.seeds is not None: + seeds = batch.seeds + else: + prompt_count = len(batch.prompt) if isinstance(batch.prompt, list) else 1 + seeds = [seed + i for i in range(prompt_count * num_videos_per_prompt)] batch.seeds = seeds # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index f86fe60d0..54457803b 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -61,12 +61,20 @@ def forward( assert batch.prompt is not None prompt_text: str | list[str] = batch.prompt all_indices: list[int] = list(range(len(self.text_encoders))) - prompt_embeds_list, prompt_masks_list = self.encode_text( - prompt_text, - fastvideo_args, - encoder_index=all_indices, - return_attention_mask=True, - ) + if isinstance(prompt_text, list): + prompt_embeds_list, prompt_masks_list = self._encode_prompt_list_individually( + prompt_text, + fastvideo_args, + encoder_index=all_indices, + return_attention_mask=True, + ) + else: + prompt_embeds_list, prompt_masks_list = self.encode_text( + prompt_text, + fastvideo_args, + encoder_index=all_indices, + return_attention_mask=True, + ) if self._last_audio_embeds is not None: batch.extra["ltx2_audio_prompt_embeds"] = self._last_audio_embeds @@ -78,13 +86,24 @@ def forward( # Encode negative prompt if CFG is enabled if batch.do_classifier_free_guidance: - assert isinstance(batch.negative_prompt, str) - neg_embeds_list, neg_masks_list = self.encode_text( - batch.negative_prompt, - fastvideo_args, - encoder_index=all_indices, - return_attention_mask=True, - ) + assert isinstance(batch.negative_prompt, str | list) + negative_prompt: str | list[str] = batch.negative_prompt + if isinstance(batch.prompt, list) and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(batch.prompt) 
+ if isinstance(negative_prompt, list): + neg_embeds_list, neg_masks_list = self._encode_prompt_list_individually( + negative_prompt, + fastvideo_args, + encoder_index=all_indices, + return_attention_mask=True, + ) + else: + neg_embeds_list, neg_masks_list = self.encode_text( + negative_prompt, + fastvideo_args, + encoder_index=all_indices, + return_attention_mask=True, + ) if self._last_audio_embeds is not None: batch.extra["ltx2_audio_negative_embeds"] = self._last_audio_embeds @@ -97,6 +116,63 @@ def forward( return batch + def _encode_prompt_list_individually( + self, + texts: list[str], + fastvideo_args: FastVideoArgs, + *, + encoder_index: list[int], + return_attention_mask: bool, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + per_prompt_embeds: list[list[torch.Tensor]] = [] + per_prompt_masks: list[list[torch.Tensor]] = [] + per_prompt_audio_embeds: list[list[torch.Tensor] | None] = [] + + for text in texts: + embeds, masks = self.encode_text( + text, + fastvideo_args, + encoder_index=encoder_index, + return_attention_mask=return_attention_mask, + ) + per_prompt_embeds.append(embeds) + per_prompt_masks.append(masks) + per_prompt_audio_embeds.append(self._last_audio_embeds) + + merged_embeds = [ + torch.cat([prompt_embeds[encoder_pos] for prompt_embeds in per_prompt_embeds], dim=0) + for encoder_pos in range(len(per_prompt_embeds[0])) + ] + merged_masks = [ + self._cat_attention_masks([prompt_masks[encoder_pos] for prompt_masks in per_prompt_masks]) + for encoder_pos in range(len(per_prompt_masks[0])) + ] + if per_prompt_audio_embeds and all(audio_embeds is not None for audio_embeds in per_prompt_audio_embeds): + audio_embed_lists = [audio_embeds for audio_embeds in per_prompt_audio_embeds if audio_embeds is not None] + self._last_audio_embeds = [ + torch.cat([audio_embeds[encoder_pos] for audio_embeds in audio_embed_lists], dim=0) + for encoder_pos in range(len(audio_embed_lists[0])) + ] + else: + self._last_audio_embeds = None + return 
merged_embeds, merged_masks + + @staticmethod + def _cat_attention_masks(masks: list[torch.Tensor]) -> torch.Tensor: + base_shape = masks[0].shape[1:] + if all(mask.shape[1:] == base_shape for mask in masks): + return torch.cat(masks, dim=0) + if all(mask.ndim == 2 for mask in masks): + max_length = max(mask.shape[1] for mask in masks) + padded_masks = [] + for mask in masks: + pad_width = max_length - mask.shape[1] + if pad_width > 0: + mask = torch.nn.functional.pad(mask, (0, pad_width), value=0) + padded_masks.append(mask) + return torch.cat(padded_masks, dim=0) + raise ValueError(f"Cannot concatenate attention masks with shapes: {[list(mask.shape) for mask in masks]}") + def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: """Verify text encoding stage inputs.""" result = VerificationResult() @@ -230,6 +306,9 @@ def encode_text( attn_masks_list.append(attention_mask) return self.return_embeds(embeds_list, attn_masks_list, return_type, return_attention_mask, indices) + if len(processed_texts) > 1 and "padding" not in tok_kwargs: + tok_kwargs["padding"] = True + # If tokenizer is a multimodal processor (e.g. Qwen2_5_VLProcessor), # use its inner tokenizer for text-only encoding. 
tok = getattr(tokenizer, "tokenizer", tokenizer) diff --git a/fastvideo/tests/api/test_compat_translation.py b/fastvideo/tests/api/test_compat_translation.py index 759965136..19c8d56e3 100644 --- a/fastvideo/tests/api/test_compat_translation.py +++ b/fastvideo/tests/api/test_compat_translation.py @@ -9,7 +9,7 @@ generator_config_to_fastvideo_args, legacy_from_pretrained_to_config, ) -from fastvideo.api.schema import CompileConfig, GeneratorConfig +from fastvideo.api.schema import BatchingConfig, CompileConfig, GeneratorConfig class TestLegacyTorchCompileKwargsTranslation: @@ -200,6 +200,48 @@ def test_reverse_unset_skips_key(self, monkeypatch) -> None: assert "enable_torch_compile_text_encoder" not in args.kwargs +class TestBatchingTranslation: + + def test_flat_kwargs_promote_to_engine_batching(self) -> None: + config = legacy_from_pretrained_to_config( + "/models/wan", + { + "batching_mode": "dynamic", + "batching_max_size": 4, + "batching_delay_ms": 25.0, + "batching_config": "/tmp/batching.json", + "enable_batching_metrics": True, + }, + ) + + assert config.engine.batching.mode == "dynamic" + assert config.engine.batching.max_size == 4 + assert config.engine.batching.delay_ms == 25.0 + assert config.engine.batching.config_path == "/tmp/batching.json" + assert config.engine.batching.enable_metrics is True + + def test_typed_batching_emits_fastvideo_args_kwargs(self, monkeypatch) -> None: + _stub_fastvideo_args_from_kwargs(monkeypatch) + config = GeneratorConfig( + model_path="/models/wan", + engine=_engine_with_batching(BatchingConfig( + mode="dynamic", + max_size=3, + delay_ms=10.0, + config_path="/tmp/batching.json", + enable_metrics=True, + )), + ) + + args = generator_config_to_fastvideo_args(config) + + assert args.kwargs["batching_mode"] == "dynamic" + assert args.kwargs["batching_max_size"] == 3 + assert args.kwargs["batching_delay_ms"] == 10.0 + assert args.kwargs["batching_config"] == "/tmp/batching.json" + assert args.kwargs["enable_batching_metrics"] 
is True + + # ------------------------------------------------------------------- # Helpers # ------------------------------------------------------------------- @@ -213,6 +255,13 @@ def _engine_with_compile(compile_config): return engine +def _engine_with_batching(batching_config): + from fastvideo.api.schema import EngineConfig + engine = EngineConfig() + engine.batching = batching_config + return engine + + def _stub_fastvideo_args_from_kwargs(monkeypatch): """Swap ``FastVideoArgs.from_kwargs`` for a capture-only stub so translation tests don't need to construct a valid FastVideoArgs.""" diff --git a/fastvideo/tests/batching/run_dynamic_batching_parity.py b/fastvideo/tests/batching/run_dynamic_batching_parity.py new file mode 100644 index 000000000..b495285fc --- /dev/null +++ b/fastvideo/tests/batching/run_dynamic_batching_parity.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import argparse +import json +import os +import time +from pathlib import Path +from typing import Any + +import torch + +from fastvideo import VideoGenerator + +DEFAULT_PROMPTS = ( + "A small robot sketches a city skyline at sunrise, cinematic lighting.", + "A glass teapot steams on a wooden table while rain falls outside.", +) + + +def _build_init_kwargs(args: argparse.Namespace, *, dynamic: bool) -> dict[str, Any]: + return { + "num_gpus": args.num_gpus, + "sp_size": args.sp_size, + "tp_size": args.tp_size, + "use_fsdp_inference": args.use_fsdp_inference, + "dit_cpu_offload": False, + "dit_layerwise_offload": False, + "flow_shift": args.flow_shift, + "text_encoder_precisions": ("fp32",), + "output_type": "latent", + "batching_mode": "dynamic" if dynamic else "disabled", + "batching_max_size": args.batch_size if dynamic else 1, + "batching_delay_ms": 0.0, + } + + +def _request_kwargs(args: argparse.Namespace, prompt_index: int) -> dict[str, Any]: + return { + "prompt": args.prompts[prompt_index], + "height": args.height, + "width": 
args.width, + "num_frames": args.num_frames, + "num_inference_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "embedded_cfg_scale": args.embedded_cfg_scale, + "seed": args.seed + prompt_index, + "fps": 24, + "save_video": False, + "return_frames": True, + "output_path": str(Path(args.output_dir) / f"request_{prompt_index}.mp4"), + } + + +def _sync() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def _run_sequential(generator: VideoGenerator, args: argparse.Namespace) -> tuple[list[dict[str, Any]], float]: + _sync() + start = time.perf_counter() + results = [] + for index in range(args.batch_size): + kwargs = _request_kwargs(args, index) + prompt = kwargs.pop("prompt") + results.append(generator.generate_video(prompt=prompt, **kwargs)) + _sync() + return results, time.perf_counter() - start + + +def _run_dynamic(generator: VideoGenerator, args: argparse.Namespace) -> tuple[list[dict[str, Any]], float]: + if not hasattr(generator, "generate_video_batch"): + raise RuntimeError("VideoGenerator.generate_video_batch is unavailable in this checkout") + requests = [_request_kwargs(args, index) for index in range(args.batch_size)] + _sync() + start = time.perf_counter() + results = generator.generate_video_batch(requests) + _sync() + return results, time.perf_counter() - start + + +def _tensor_metrics(sequential: list[dict[str, Any]], dynamic: list[dict[str, Any]]) -> dict[str, Any]: + per_request = [] + for index, (seq_result, dyn_result) in enumerate(zip(sequential, dynamic, strict=True)): + seq = seq_result["samples"].detach().cpu().to(torch.float32) + dyn = dyn_result["samples"].detach().cpu().to(torch.float32) + diff = (seq - dyn).abs() + per_request.append({ + "index": index, + "shape": list(seq.shape), + "max_abs_diff": float(diff.max().item()), + "mean_abs_diff": float(diff.mean().item()), + "allclose_atol_1e_5": bool(torch.allclose(seq, dyn, atol=1e-5, rtol=1e-5)), + "allclose_atol_1e_4": 
bool(torch.allclose(seq, dyn, atol=1e-4, rtol=1e-4)), + }) + return { + "per_request": per_request, + "max_abs_diff": max(item["max_abs_diff"] for item in per_request), + "mean_abs_diff": sum(item["mean_abs_diff"] for item in per_request) / len(per_request), + "allclose_atol_1e_5": all(item["allclose_atol_1e_5"] for item in per_request), + "allclose_atol_1e_4": all(item["allclose_atol_1e_4"] for item in per_request), + } + + +def run_parity(args: argparse.Namespace) -> dict[str, Any]: + generator = VideoGenerator.from_pretrained(args.model_path, **_build_init_kwargs(args, dynamic=True)) + try: + sequential, sequential_s = _run_sequential(generator, args) + dynamic, dynamic_s = _run_dynamic(generator, args) + metrics = _tensor_metrics(sequential, dynamic) + finally: + generator.shutdown() + return { + "mode": "parity", + "model_path": args.model_path, + "num_gpus": args.num_gpus, + "shape": { + "height": args.height, + "width": args.width, + "num_frames": args.num_frames, + "num_inference_steps": args.num_inference_steps, + }, + "batch_size": args.batch_size, + "sequential_time_s": sequential_s, + "dynamic_time_s": dynamic_s, + "speedup": sequential_s / dynamic_s if dynamic_s > 0 else None, + "tensor_metrics": metrics, + } + + +def run_benchmark(args: argparse.Namespace, *, dynamic: bool) -> dict[str, Any]: + generator = VideoGenerator.from_pretrained(args.model_path, **_build_init_kwargs(args, dynamic=dynamic)) + run = _run_dynamic if dynamic else _run_sequential + try: + for _ in range(args.warmup_runs): + run(generator, args) + times = [] + for _ in range(args.measurement_runs): + _results, elapsed = run(generator, args) + times.append(elapsed) + finally: + generator.shutdown() + avg = sum(times) / len(times) + return { + "mode": "dynamic" if dynamic else "sequential", + "model_path": args.model_path, + "num_gpus": args.num_gpus, + "batch_size": args.batch_size, + "measurement_runs": args.measurement_runs, + "times_s": times, + "avg_time_s": avg, + 
"requests_per_second": args.batch_size / avg if avg > 0 else None, + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=("parity", "sequential", "dynamic"), required=True) + parser.add_argument("--model-path", default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument("--sp-size", type=int, default=1) + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument("--use-fsdp-inference", action="store_true") + parser.add_argument("--height", type=int, default=256) + parser.add_argument("--width", type=int, default=256) + parser.add_argument("--num-frames", type=int, default=9) + parser.add_argument("--num-inference-steps", type=int, default=2) + parser.add_argument("--guidance-scale", type=float, default=1.0) + parser.add_argument("--embedded-cfg-scale", type=float, default=6.0) + parser.add_argument("--flow-shift", type=float, default=7.0) + parser.add_argument("--seed", type=int, default=1024) + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--warmup-runs", type=int, default=1) + parser.add_argument("--measurement-runs", type=int, default=3) + parser.add_argument("--output-dir", default="/tmp/fastvideo_dynamic_batching") + parser.add_argument("--output-json", required=True) + parser.add_argument("--prompts", nargs="+", default=list(DEFAULT_PROMPTS)) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if len(args.prompts) < args.batch_size: + raise ValueError("--prompts must contain at least --batch-size prompts") + os.makedirs(args.output_dir, exist_ok=True) + if args.mode == "parity": + result = run_parity(args) + else: + result = run_benchmark(args, dynamic=args.mode == "dynamic") + output_path = Path(args.output_json) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(result, indent=2), encoding="utf-8") + 
print(json.dumps(result, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/fastvideo/tests/batching/test_admission.py b/fastvideo/tests/batching/test_admission.py new file mode 100644 index 000000000..b3b9a75b0 --- /dev/null +++ b/fastvideo/tests/batching/test_admission.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from fastvideo.batching.admission import ( + AdmissionLimit, + BatchAdmissionController, + BatchingRule, + load_batching_config, +) +from fastvideo.configs.pipelines.base import PipelineConfig + + +def test_admission_limit_rejects_batch_size_and_cost() -> None: + limit = AdmissionLimit(max_batch_size=2, max_cost=10.0) + + assert limit.reject_reason(batch_size=3, batch_cost=1.0) == "config_cap:2" + assert limit.reject_reason(batch_size=2, batch_cost=11.0) == "cost_budget:11>10" + assert limit.reject_reason(batch_size=2, batch_cost=10.0) is None + + +def test_batching_rule_validates_unknown_keys() -> None: + with pytest.raises(ValueError, match="did you mean 'max_batch_size'"): + BatchingRule.from_dict( + { + "model_contains": "wan", + "max_batch_siz": 2, + }, + source="unit", + ) + + +def test_load_batching_config_supports_mapping_form(tmp_path) -> None: + path = tmp_path / "batching.json" + path.write_text( + '{"schema_version": 1, "wan|720x1280x81": {"max_batch_size": 3, "max_cost": 9}}', + encoding="utf-8", + ) + + rules = load_batching_config(str(path)) + + assert len(rules) == 1 + assert rules[0].model == "wan" + assert rules[0].resolution == "720x1280x81" + assert rules[0].max_batch_size == 3 + assert rules[0].max_cost == 9.0 + + +def test_admission_controller_applies_user_and_config_caps(tmp_path, monkeypatch) -> None: + path = tmp_path / "batching.json" + path.write_text( + '{"rules": [{"model_contains": "wan", "resolution": "720x1280x81", "max_batch_size": 3}]}', + encoding="utf-8", + ) + 
monkeypatch.setattr(BatchAdmissionController, "_get_device_memory_gb", staticmethod(lambda gpu_id: 48.0)) + + args = SimpleNamespace( + batching_mode="dynamic", + batching_max_size=4, + batching_config=str(path), + model_path="/models/wan", + dit_cpu_offload=False, + dit_layerwise_offload=False, + pipeline_config=PipelineConfig(), + ) + request = SimpleNamespace(height=720, width=1280, num_frames=81) + + controller = BatchAdmissionController(args) + + assert controller.enabled is True + assert controller.max_admissible_batch_size(request) == 3 + assert controller.batch_is_full([request, request, request]) is True diff --git a/fastvideo/tests/batching/test_signature.py b/fastvideo/tests/batching/test_signature.py new file mode 100644 index 000000000..bdf09c9ed --- /dev/null +++ b/fastvideo/tests/batching/test_signature.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from fastvideo.api.sampling_param import SamplingParam +from fastvideo.batching.signature import ( + can_dynamic_batch, + dynamic_batch_signature, + resolution_key, +) + + +def _request(prompt: str = "a prompt", **overrides) -> SamplingParam: + request = SamplingParam(prompt=prompt, height=256, width=384, num_frames=17, num_inference_steps=4) + for key, value in overrides.items(): + setattr(request, key, value) + return request + + +def test_dynamic_batch_signature_excludes_request_local_fields() -> None: + first = _request(seed=1, output_path="/tmp/a.mp4", save_video=True, return_frames=False) + second = _request(seed=2, output_path="/tmp/b.mp4", save_video=False, return_frames=True) + + assert dynamic_batch_signature(first) == dynamic_batch_signature(second) + + +def test_can_dynamic_batch_accepts_matching_text_requests() -> None: + first = _request("first", seed=1) + second = _request("second", seed=2) + + result = can_dynamic_batch(first, second) + + assert result.can_batch is True + assert result.reason is None + + +def 
test_can_dynamic_batch_rejects_sampling_mismatch() -> None: + first = _request(guidance_scale=1.0) + second = _request(guidance_scale=3.0) + + result = can_dynamic_batch(first, second) + + assert result.can_batch is False + assert result.reason == "sampling_params.guidance_scale" + + +def test_can_dynamic_batch_rejects_image_conditioning() -> None: + first = _request() + second = _request(image_path="/tmp/image.png") + + result = can_dynamic_batch(first, second) + + assert result.can_batch is False + assert result.reason == "image_path" + + +def test_resolution_key_uses_generation_shape() -> None: + assert resolution_key(_request(height=720, width=1280, num_frames=81)) == "720x1280x81" diff --git a/fastvideo/tests/entrypoints/test_openai_api.py b/fastvideo/tests/entrypoints/test_openai_api.py index e4f7179a6..edb50357d 100644 --- a/fastvideo/tests/entrypoints/test_openai_api.py +++ b/fastvideo/tests/entrypoints/test_openai_api.py @@ -1,10 +1,14 @@ """Unit tests for the OpenAI-compatible API server helpers (no GPU needed).""" +import asyncio import os +from types import SimpleNamespace from unittest.mock import patch import pytest +from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.entrypoints.openai.batching import VideoBatchScheduler from fastvideo.api.parser import parse_config from fastvideo.api.schema import GenerationRequest from fastvideo.entrypoints.openai.protocol import ( @@ -21,6 +25,112 @@ parse_size, ) + +class _FakeBatchGenerator: + + def __init__(self): + self.calls = [] + + def generate_video_batch(self, request_kwargs): + self.calls.append([dict(item) for item in request_kwargs]) + return [{"prompts": item["prompt"], "video_path": item["output_path"]} for item in request_kwargs] + + +def _batch_scheduler_args(**overrides): + defaults = dict( + model_path="test-model", + batching_mode="dynamic", + batching_max_size=2, + batching_delay_ms=25.0, + enable_batching_metrics=False, + pipeline_config=PipelineConfig(), + ) + 
defaults.update(overrides) + return SimpleNamespace(**defaults) + + +def test_video_batch_scheduler_groups_compatible_requests(tmp_path): + async def run(): + generator = _FakeBatchGenerator() + scheduler = VideoBatchScheduler(generator, _batch_scheduler_args()) + await scheduler.start() + try: + first = { + "prompt": "first", + "height": 256, + "width": 256, + "num_frames": 1, + "num_inference_steps": 2, + "seed": 1, + "output_path": str(tmp_path / "first.mp4"), + "save_video": False, + } + second = { + "prompt": "second", + "height": 256, + "width": 256, + "num_frames": 1, + "num_inference_steps": 2, + "seed": 2, + "output_path": str(tmp_path / "second.mp4"), + "save_video": False, + } + results = await asyncio.gather( + scheduler.submit("req-1", first), + scheduler.submit("req-2", second), + ) + finally: + await scheduler.stop() + return generator.calls, results + + calls, results = asyncio.run(run()) + + assert len(calls) == 1 + assert [item["prompt"] for item in calls[0]] == ["first", "second"] + assert [result["prompts"] for result in results] == ["first", "second"] + + +def test_video_batch_scheduler_keeps_incompatible_requests_separate(tmp_path): + async def run(): + generator = _FakeBatchGenerator() + scheduler = VideoBatchScheduler(generator, _batch_scheduler_args()) + await scheduler.start() + try: + text_only = { + "prompt": "first", + "height": 256, + "width": 256, + "num_frames": 1, + "num_inference_steps": 2, + "seed": 1, + "output_path": str(tmp_path / "first.mp4"), + "save_video": False, + } + image_conditioned = { + "prompt": "second", + "height": 256, + "width": 256, + "num_frames": 1, + "num_inference_steps": 2, + "seed": 2, + "image_path": str(tmp_path / "input.png"), + "output_path": str(tmp_path / "second.mp4"), + "save_video": False, + } + results = await asyncio.gather( + scheduler.submit("req-1", text_only), + scheduler.submit("req-2", image_conditioned), + ) + finally: + await scheduler.stop() + return generator.calls, results + + calls, 
results = asyncio.run(run()) + + assert len(calls) == 2 + assert [[item["prompt"] for item in call] for call in calls] == [["first"], ["second"]] + assert [result["prompts"] for result in results] == ["first", "second"] + # --------------------------------------------------------------------------- # parse_size # --------------------------------------------------------------------------- diff --git a/fastvideo/tests/entrypoints/test_video_generator.py b/fastvideo/tests/entrypoints/test_video_generator.py index 2789f071d..c15f17ce2 100644 --- a/fastvideo/tests/entrypoints/test_video_generator.py +++ b/fastvideo/tests/entrypoints/test_video_generator.py @@ -3,6 +3,7 @@ import warnings import pytest +import torch from fastvideo.api import ( GenerationRequest, @@ -13,8 +14,10 @@ load_run_config, ) from fastvideo.api.sampling_param import SamplingParam +from fastvideo.configs.pipelines.base import PipelineConfig from fastvideo.entrypoints.video_generator import VideoGenerator from fastvideo.fastvideo_args import WorkloadType +from fastvideo.pipelines import ForwardBatch def _new_video_generator() -> VideoGenerator: @@ -37,6 +40,25 @@ def _new_runtime_video_generator() -> VideoGenerator: return generator +def _batching_fastvideo_args(**overrides): + defaults = dict( + model_path="test-model", + prompt_txt=None, + workload_type=SimpleNamespace(value="t2v"), + batching_mode="dynamic", + batching_max_size=4, + batching_config=None, + dit_cpu_offload=False, + dit_layerwise_offload=False, + output_type="latent", + pin_cpu_memory=False, + VSA_sparsity=0.0, + pipeline_config=PipelineConfig(), + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + def _patch_from_fastvideo_args(monkeypatch): captured = {} @@ -151,6 +173,117 @@ def test_prepare_output_path_empty_prompt_fallback(tmp_path): assert os.path.basename(result) == "output.mp4" +def test_generate_prepared_work_items_merges_compatible_latent_requests(monkeypatch, tmp_path): + vg = _new_video_generator() 
+ vg.fastvideo_args = _batching_fastvideo_args() + calls = [] + + def fake_device_memory(gpu_id): + return 48.0 + + def fake_run_forward(batch, fastvideo_args): + calls.append(batch) + batch_size = len(batch.prompt) if isinstance(batch.prompt, list) else 1 + output = torch.arange(batch_size * 4, dtype=torch.float32).reshape(batch_size, 4, 1, 1, 1) + return ForwardBatch(data_type=batch.data_type, output=output, extra={"peak_memory_mb": 1.0}), 0.5, 10.0 + + monkeypatch.setattr( + "fastvideo.batching.admission.BatchAdmissionController._get_device_memory_gb", + staticmethod(fake_device_memory), + ) + monkeypatch.setattr(vg, "_run_forward_batch", fake_run_forward) + + first = SamplingParam(prompt="one", height=8, width=8, num_frames=1, seed=11, return_frames=True, save_video=False) + second = SamplingParam(prompt="two", height=8, width=8, num_frames=1, seed=22, return_frames=True, save_video=False) + work_items = [ + vg._prepare_generation_work_item("one", first, vg.fastvideo_args, output_path=str(tmp_path / "one.mp4")), + vg._prepare_generation_work_item("two", second, vg.fastvideo_args, output_path=str(tmp_path / "two.mp4")), + ] + + results = vg._generate_prepared_work_items(work_items) + + assert len(calls) == 1 + assert calls[0].prompt == ["one", "two"] + assert calls[0].seeds == [11, 22] + assert [result["prompts"] for result in results] == ["one", "two"] + assert [result["samples"].shape for result in results] == [(1, 4, 1, 1, 1), (1, 4, 1, 1, 1)] + + +def test_generate_prepared_work_items_falls_back_for_incompatible_requests(monkeypatch, tmp_path): + vg = _new_video_generator() + vg.fastvideo_args = _batching_fastvideo_args() + calls = [] + + def fake_run_forward(batch, fastvideo_args): + calls.append(batch) + output = torch.zeros((1, 4, 1, 1, 1), dtype=torch.float32) + return ForwardBatch(data_type=batch.data_type, output=output), 0.5, 10.0 + + monkeypatch.setattr(vg, "_run_forward_batch", fake_run_forward) + + first = SamplingParam(prompt="one", height=8, 
width=8, num_frames=1, guidance_scale=1.0, save_video=False) + second = SamplingParam(prompt="two", height=8, width=8, num_frames=1, guidance_scale=3.0, save_video=False) + work_items = [ + vg._prepare_generation_work_item("one", first, vg.fastvideo_args, output_path=str(tmp_path / "one.mp4")), + vg._prepare_generation_work_item("two", second, vg.fastvideo_args, output_path=str(tmp_path / "two.mp4")), + ] + + results = vg._generate_prepared_work_items(work_items) + + assert len(calls) == 2 + assert all(isinstance(call.prompt, str) for call in calls) + assert [result["prompts"] for result in results] == ["one", "two"] + + +def test_generate_video_batch_routes_compat_kwargs(monkeypatch, tmp_path): + vg = _new_video_generator() + vg.fastvideo_args = _batching_fastvideo_args() + calls = [] + + def fake_device_memory(gpu_id): + return 48.0 + + def fake_run_forward(batch, fastvideo_args): + calls.append((batch, fastvideo_args)) + output = torch.zeros((len(batch.prompt), 4, 1, 1, 1), dtype=torch.float32) + return ForwardBatch(data_type=batch.data_type, output=output), 0.5, 10.0 + + monkeypatch.setattr( + "fastvideo.batching.admission.BatchAdmissionController._get_device_memory_gb", + staticmethod(fake_device_memory), + ) + monkeypatch.setattr(vg, "_run_forward_batch", fake_run_forward) + + results = vg.generate_video_batch([ + { + "prompt": "one", + "height": 8, + "width": 8, + "num_frames": 1, + "embedded_cfg_scale": 7.5, + "save_video": False, + "return_frames": True, + "output_path": str(tmp_path / "one.mp4"), + }, + { + "prompt": "two", + "height": 8, + "width": 8, + "num_frames": 1, + "embedded_cfg_scale": 7.5, + "save_video": False, + "return_frames": True, + "output_path": str(tmp_path / "two.mp4"), + }, + ]) + + assert len(calls) == 1 + batch, fastvideo_args = calls[0] + assert batch.prompt == ["one", "two"] + assert fastvideo_args.pipeline_config.embedded_cfg_scale == 7.5 + assert [result["prompts"] for result in results] == ["one", "two"] + + def 
test_from_config_normalizes_and_translates(monkeypatch): captured = _patch_from_fastvideo_args(monkeypatch) _patch_fastvideo_args_from_kwargs(monkeypatch) diff --git a/fastvideo/tests/stages/test_input_validation_batching.py b/fastvideo/tests/stages/test_input_validation_batching.py new file mode 100644 index 000000000..fe176d63a --- /dev/null +++ b/fastvideo/tests/stages/test_input_validation_batching.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +from types import SimpleNamespace + +from fastvideo.pipelines import ForwardBatch +from fastvideo.pipelines.stages.input_validation import InputValidationStage + + +def test_input_validation_preserves_explicit_dynamic_batch_seeds() -> None: + batch = ForwardBatch( + data_type="video", + prompt=["one", "two"], + seed=100, + seeds=[17, 23], + height=8, + width=8, + num_frames=1, + num_inference_steps=1, + ) + + InputValidationStage()._generate_seeds(batch, SimpleNamespace()) + + assert batch.seeds == [17, 23] + assert [generator.initial_seed() for generator in batch.generator] == [17, 23] + + +def test_input_validation_generates_one_seed_per_prompt() -> None: + batch = ForwardBatch( + data_type="video", + prompt=["one", "two"], + seed=100, + height=8, + width=8, + num_frames=1, + num_inference_steps=1, + ) + + InputValidationStage()._generate_seeds(batch, SimpleNamespace()) + + assert batch.seeds == [100, 101] + assert [generator.initial_seed() for generator in batch.generator] == [100, 101] diff --git a/fastvideo/tests/stages/test_text_encoding.py b/fastvideo/tests/stages/test_text_encoding.py index 244a01d3b..1ba38204f 100644 --- a/fastvideo/tests/stages/test_text_encoding.py +++ b/fastvideo/tests/stages/test_text_encoding.py @@ -13,7 +13,13 @@ def to(self, device): return TensorDict({k: v.to(device) for k, v in self.items()}) class FakeTokenizer: + def __init__(self): + self.calls = [] + self.texts = [] + def __call__(self, texts, **kwargs): + self.calls.append(kwargs) + self.texts.append(list(texts)) B = 
len(texts) seq_len = int(kwargs.get("max_length", 4)) return TensorDict({ @@ -131,6 +137,35 @@ def test_forward_integration_cfg_off_and_on(): assert len(out2.prompt_attention_mask) == 2 assert len(out2.negative_attention_mask) == 2 +def test_encode_text_adds_padding_for_prompt_lists(): + fastvideo_args, hidden = make_args(num_encoders=1, text_len=4, hidden_size=8) + stage = make_stage(num_encoders=1, hidden_size=hidden) + + stage.encode_text(["short", "a longer prompt"], fastvideo_args, encoder_index=[0]) + + assert stage.tokenizers[0].calls[-1]["padding"] is True + + +def test_forward_prompt_list_preserves_single_prompt_text_encoding_path(): + fastvideo_args, hidden = make_args(num_encoders=1, text_len=4, hidden_size=8) + stage = make_stage(num_encoders=1, hidden_size=hidden) + batch = ForwardBatch( + data_type="video", + prompt=["short", "a longer prompt"], + negative_prompt="", + do_classifier_free_guidance=False, + prompt_embeds=[], + negative_prompt_embeds=None, + prompt_attention_mask=[], + negative_attention_mask=None, + ) + + out = stage.forward(batch, fastvideo_args) + + assert stage.tokenizers[0].texts == [["short"], ["a longer prompt"]] + assert out.prompt_embeds[0].shape == (2, hidden) + assert out.prompt_attention_mask[0].shape == (2, 4) + def test_encode_text_hidden_state_flag_follows_encoder_config(): fastvideo_args, hidden = make_args(num_encoders=1, text_len=4, hidden_size=8)