[feat] Add Kandinsky-5 pipeline support #1471
Fixed 2 file(s) based on 2 unresolved review comments. Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
This PR has merge conflicts with the base branch. Please rebase:

    git fetch origin main
    git rebase origin/main
    # Resolve any conflicts, then:
    git push --force-with-lease
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🔴 PR merge requirements
Waiting for
This rule is failing.
Code Review
This pull request adds support for the Kandinsky-5.0 Lite text-to-video pipeline, introducing the necessary configurations, presets, pipeline stages (latent preparation, denoising, and decoding), and updates to the text encoding and model loading components. Feedback focuses on improving robustness and compatibility: preserving the original tensor dtype in _apply_rotary instead of hardcoding bfloat16, breaking early from the denoising loop on interruption, validating the shape of custom latents, verifying the lengths of text encoder precisions and max lengths, supporting asymmetric patch sizes for divisibility checks, and ensuring prompt_embeds contains at least two elements during input verification.
    def _apply_rotary(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
        x_out = (rope * x_).sum(dim=-1)
        return x_out.reshape(*x.shape).to(orig_dtype)
        return x_out.reshape(*x.shape).to(torch.bfloat16)
Hardcoding torch.bfloat16 in _apply_rotary breaks compatibility when running the model in other precisions (such as float16 or float32). Additionally, since _apply_rotary is immediately followed by .type_as(query) in Kandinsky5Attention.forward, this hardcoded cast causes redundant casting and precision loss. Preserving the input tensor's original dtype is more generic and correct.
Suggested change:

    def _apply_rotary(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
        x_out = (rope * x_).sum(dim=-1)
        return x_out.reshape(*x.shape).to(orig_dtype)
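To make the precision-loss point concrete, here is a minimal sketch that emulates the dtype round trip in pure Python rather than with torch. The helpers `to_bfloat16` and `to_float16` are hypothetical names used only for illustration and rely on the standard `struct` module. The sketch assumes a float16 model: bfloat16 stores 7 explicit mantissa bits versus 10 for float16, so an intermediate cast to bfloat16 can discard bits that the later `.type_as(query)` cannot restore.

```python
import struct


def to_bfloat16(value: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    Hypothetical helper: bfloat16 is the top 16 bits of the float32 encoding.
    """
    bits = struct.unpack(">I", struct.pack(">f", value))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFFFFFF))[0]


def to_float16(value: float) -> float:
    """Round a float to the nearest float16 value using struct's 'e' format."""
    return struct.unpack(">e", struct.pack(">e", value))[0]


# 1 + 2**-10 is exactly representable in float16 but not in bfloat16.
x = 1.0 + 2.0 ** -10

direct = to_float16(x)                 # dtype preserved end to end
via_bf16 = to_float16(to_bfloat16(x))  # hardcoded bfloat16 in the middle

print(direct)    # 1.0009765625
print(via_bf16)  # 1.0
```

The second result shows the lowest float16 mantissa bit disappearing once the value passes through bfloat16, which is the loss the comment above describes.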
    if hasattr(self, "interrupt") and self.interrupt:
        continue
Using continue when self.interrupt is set will still iterate through all remaining timesteps in the loop, wasting CPU cycles. Replacing it with break will immediately terminate the denoising loop, which is the expected behavior for an interruption.
Suggested change:

    if hasattr(self, "interrupt") and self.interrupt:
        break
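The difference between the two statements can be seen in a toy loop. This is only a sketch: `run_denoising_loop` and its parameters are hypothetical stand-ins and not the stage's real code. It counts how many iterations are entered after an interrupt is raised at a given step.

```python
def run_denoising_loop(num_steps: int, interrupt_at: int, use_break: bool):
    """Toy stand-in for a denoising loop.

    Returns (denoise steps actually executed, loop iterations entered).
    """
    steps_run = 0
    iterations = 0
    interrupted = False
    for i in range(num_steps):
        iterations += 1
        if i == interrupt_at:
            interrupted = True  # simulates an external interrupt flag being set
        if interrupted:
            if use_break:
                break           # leave the loop immediately
            continue            # skip the work but keep iterating
        steps_run += 1          # placeholder for the real denoising step
    return steps_run, iterations


with_continue = run_denoising_loop(50, interrupt_at=10, use_break=False)
with_break = run_denoising_loop(50, interrupt_at=10, use_break=True)

print(with_continue)  # (10, 50)
print(with_break)     # (10, 11)
```

Both variants perform the same ten denoising steps, but only `break` stops visiting the remaining timesteps.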
    else:
        latents = batch.latents.to(device=device, dtype=dtype)
When custom or pre-computed latents are provided via batch.latents, it is important to validate that their shape matches the expected latent shape to prevent runtime shape mismatch errors later in the pipeline.
Suggested change:

    else:
        if list(batch.latents.shape) != list(shape):
            raise ValueError(f"Provided latents shape {list(batch.latents.shape)} does not match expected shape {list(shape)}.")
        latents = batch.latents.to(device=device, dtype=dtype)
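A standalone version of the same check might look like the sketch below. `check_latents_shape` is a hypothetical helper, and the 5-D shapes are made-up values in an assumed (batch, channels, frames, height, width) layout. Plain tuples are used so the example runs without torch.

```python
from typing import Sequence


def check_latents_shape(actual: Sequence[int], expected: Sequence[int]) -> None:
    """Raise ValueError when user-provided latents do not match the expected shape."""
    if tuple(actual) != tuple(expected):
        raise ValueError(
            f"Provided latents shape {list(actual)} does not match "
            f"expected shape {list(expected)}."
        )


expected_shape = (1, 16, 31, 64, 96)  # hypothetical expected latent shape

# A matching shape passes silently.
check_latents_shape((1, 16, 31, 64, 96), expected_shape)

# A mismatched width is reported up front instead of failing deep in the pipeline.
try:
    check_latents_shape((1, 16, 31, 64, 64), expected_shape)
except ValueError as err:
    message = str(err)
    print(message)
```

Failing at latent preparation keeps the error message close to the input that caused it.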
    def __post_init__(self) -> None:
        if len(self.text_encoder_configs) != 2:
            raise ValueError(
                f"Kandinsky5 pipeline requires exactly 2 text encoders (qwen and clip), "
                f"but got {len(self.text_encoder_configs)} encoder(s)."
            )
In addition to validating the number of text encoder configs, we should also validate that text_encoder_precisions and text_encoder_max_lengths have exactly 2 elements to prevent potential out-of-bounds index errors during loading or encoding.
Suggested change:

    def __post_init__(self) -> None:
        if len(self.text_encoder_configs) != 2:
            raise ValueError(
                f"Kandinsky5 pipeline requires exactly 2 text encoders (qwen and clip), "
                f"but got {len(self.text_encoder_configs)} encoder(s)."
            )
        if len(self.text_encoder_precisions) != 2:
            raise ValueError(
                f"Kandinsky5 pipeline requires exactly 2 text encoder precisions, "
                f"but got {len(self.text_encoder_precisions)}."
            )
        if len(self.text_encoder_max_lengths) != 2:
            raise ValueError(
                f"Kandinsky5 pipeline requires exactly 2 text encoder max lengths, "
                f"but got {len(self.text_encoder_max_lengths)}."
            )

    required_divisor = spatial_ratio * patch_size[1]
    if height % required_divisor != 0 or width % required_divisor != 0:
        raise ValueError(f"Kandinsky5 height/width must be divisible by {required_divisor}; "
                         f"got height={height}, width={width}.")
Checking only patch_size[1] for both height and width divisibility assumes that height and width patch sizes are always identical. To support potential asymmetric patch sizes correctly, patch_size[1] should be used for height and patch_size[2] for width.
Suggested change:

    required_divisor_h = spatial_ratio * patch_size[1]
    required_divisor_w = spatial_ratio * patch_size[2]
    if height % required_divisor_h != 0 or width % required_divisor_w != 0:
        raise ValueError(f"Kandinsky5 height must be divisible by {required_divisor_h} and width by {required_divisor_w}; "
                         f"got height={height}, width={width}.")
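A small worked example shows where the single-divisor check and the per-axis check disagree. This is a sketch under assumptions: `check_symmetric` and `check_per_axis` are hypothetical helpers, `patch_size` is assumed to be ordered (temporal, height, width), and the values `spatial_ratio=8` and `patch_size=(1, 2, 4)` are invented to exercise the asymmetric case.

```python
def check_symmetric(height: int, width: int, spatial_ratio: int, patch_size) -> bool:
    """Original behaviour: one divisor, derived from patch_size[1], for both axes."""
    divisor = spatial_ratio * patch_size[1]
    return height % divisor == 0 and width % divisor == 0


def check_per_axis(height: int, width: int, spatial_ratio: int, patch_size) -> bool:
    """Suggested behaviour: patch_size[1] for height, patch_size[2] for width."""
    divisor_h = spatial_ratio * patch_size[1]
    divisor_w = spatial_ratio * patch_size[2]
    return height % divisor_h == 0 and width % divisor_w == 0


spatial_ratio = 8
patch_size = (1, 2, 4)  # hypothetical asymmetric patch: height divisor 16, width divisor 32

# 784 is a multiple of 16 but not of 32.
old_result = check_symmetric(512, 784, spatial_ratio, patch_size)
new_result = check_per_axis(512, 784, spatial_ratio, patch_size)

print(old_result)  # True: the single-divisor check lets the bad width through
print(new_result)  # False: the per-axis check rejects it
```

With symmetric patch sizes the two checks agree, so the per-axis form only changes behaviour for configurations the original check handled incorrectly.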
    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
        return result
Since Kandinsky5DenoisingStage.forward accesses batch.prompt_embeds[0] and batch.prompt_embeds[1], we should validate that batch.prompt_embeds contains at least 2 elements in verify_input to prevent unhandled IndexError crashes.
Suggested change:

    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        result.add_check("prompt_embeds", batch.prompt_embeds, [V.is_list, lambda x: len(x) >= 2])
        return result
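One way to express the length requirement as a reusable predicate is sketched below. `has_at_least` is a hypothetical helper and is not part of FastVideo's validator module; it only illustrates a check that returns `False`, rather than raising, for inputs that are missing or too short.

```python
from typing import Any, Callable


def has_at_least(n: int) -> Callable[[Any], bool]:
    """Build a predicate that accepts lists or tuples with at least n elements."""

    def check(value: Any) -> bool:
        # Guard the len() call so None or a non-sequence fails cleanly.
        return isinstance(value, (list, tuple)) and len(value) >= n

    return check


needs_two_embeds = has_at_least(2)

print(needs_two_embeds(["qwen_embeds", "clip_embeds"]))  # True
print(needs_two_embeds(["qwen_embeds"]))                 # False
print(needs_two_embeds(None))                            # False
```

A named predicate like this can also produce a clearer failure message than an anonymous lambda, if the verification framework reports the check's name.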
Summary
Adds first-class Kandinsky-5 Lite T2V support through the normal FastVideo model-support path:
- basic/kandinsky5 composed pipeline wiring
- kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers

Validation
- pre-commit run --files <12 implementation files>: passed locally
- uv run pytest tests/local_tests/kandinsky5/test_kandinsky5_lite_transformer_parity.py -q -s -rs: passed on B200 GPU, 1 passed, 14 warnings
- CUDA_VISIBLE_DEVICES=0 uv run python tests/local_tests/kandinsky5/run_kandinsky5_lite_pipeline_smoke.py: passed, one-step latent smoke generated successfully
- fuczhqid, 768x512, 121 frames, 80 inference steps, guidance scale 5.0, output outputs/kandinsky5_validation/kandinsky5_red_motorcycle_best_quality_512x768_121f_80s.mp4