Skip to content

[feat] QAD 5090: FP8 linear layer inference#1465

Open
kevin314 wants to merge 2 commits into
mainfrom
fp8-stacked
Open

[feat] QAD 5090: FP8 linear layer inference#1465
kevin314 wants to merge 2 commits into
mainfrom
fp8-stacked

Conversation

@kevin314

@kevin314 kevin314 commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Purpose

Continues the QAD 5090 stack (#1464). Adds a generic FP8 inference quantization path for DiT linear layers (attention projections + MLP), so models can be served in FP8 after training.

Changes

  • layers/quantization/fp8_config.py: a new FP8Config and FP8QuantizeMethod that matches Wan's to_q/k/v/out + ffn.fc_in/fc_out layers by suffix
  • Supports per-tensor (fast, default) and per-channel (higher accuracy) granularity via _scaled_mm
  • Falls back to bf16 dequant on pre-sm89 GPUs
  • layers/quantization/__init__.py: registers FP8 in QuantizationMethods and wires FP8Config into the quant registry.

Test Results (sm89 / RTX 4090) and RTX 5090

Test output
End-to-end video generation completes successfully with fp8 quant

Checklist

  • I ran pre-commit run --all-files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified SSIM regression tests pass
  • I updated the support matrix if adding a new model

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a generic FP8 quantization configuration and method (FP8Config and FP8QuantizeMethod) backed by torch._scaled_mm, supporting both per-tensor and per-channel granularities with a bf16 fallback for older GPUs. The review feedback focuses on robustifying the implementation, including handling empty input tensors, adding fallbacks for unquantized layers, preventing quantization on meta devices, ensuring correct GPU capability detection in multi-GPU setups, and casting weights/scales to the input device and dtype during fallback dequantization.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +47 to +52
def _quantize_tensorwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [1] float32)``."""
x_absmax = x_2d.abs().amax().float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale.view(1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the input tensor x_2d is empty (e.g., due to dynamic batching or zero-length sequences), calling .amax() without specifying dimensions will raise a RuntimeError: amax(): Expected reduction dim to be non-empty. Adding a guard for empty tensors prevents runtime crashes.

Suggested change
def _quantize_tensorwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [1] float32)``."""
x_absmax = x_2d.abs().amax().float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale.view(1)
def _quantize_tensorwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [1] float32)``."""
if x_2d.numel() == 0:
return torch.empty_like(x_2d, dtype=FP8_DTYPE), torch.ones(1, dtype=torch.float32, device=x_2d.device)
x_absmax = x_2d.abs().amax().float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale.view(1)

Comment on lines +55 to +60
def _quantize_rowwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [M, 1] float32)``."""
x_absmax = x_2d.abs().amax(dim=-1, keepdim=True).float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the input tensor x_2d is empty (e.g., shape [5, 0]), calling .amax(dim=-1) will raise a RuntimeError because the reduction dimension is empty. Adding a guard for empty tensors prevents runtime crashes.

Suggested change
def _quantize_rowwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [M, 1] float32)``."""
x_absmax = x_2d.abs().amax(dim=-1, keepdim=True).float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale
def _quantize_rowwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns ``(x_fp8 [M, K], x_scale [M, 1] float32)``."""
if x_2d.numel() == 0:
return torch.empty_like(x_2d, dtype=FP8_DTYPE), torch.ones(x_2d.shape[0], 1, dtype=torch.float32, device=x_2d.device)
x_absmax = x_2d.abs().amax(dim=-1, keepdim=True).float()
x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE)
x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
return x_fp8, x_scale

Comment on lines +121 to +125
out_dim = layer._fp8_weight.shape[0]
original_shape = x.shape

if not _supports_fp8_compute():
return self._apply_dequant(layer, x, bias)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the model is run before convert_model_to_fp8 is called (e.g., during initialization, shape tracking, or dry runs), layer._fp8_weight will not exist, causing an immediate AttributeError. Adding a fallback to standard unquantized linear execution when _fp8_weight is missing prevents hard crashes. Additionally, we pass x.device to _supports_fp8_compute to ensure correct capability detection on multi-GPU setups.

Suggested change
out_dim = layer._fp8_weight.shape[0]
original_shape = x.shape
if not _supports_fp8_compute():
return self._apply_dequant(layer, x, bias)
if not hasattr(layer, "_fp8_weight"):
return F.linear(x, layer.weight, bias)
out_dim = layer._fp8_weight.shape[0]
original_shape = x.shape
if not _supports_fp8_compute(x.device):
return self._apply_dequant(layer, x, bias)

Comment on lines +217 to +220
weight = getattr(mod, "weight", None)
if weight is None:
continue
weight_local = weight.to_local() if isinstance(weight, DTensor) else weight # type: ignore[arg-type]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If convert_model_to_fp8 is called while the model is still on the meta device (which is common during lazy loading or when using frameworks like Hugging Face Accelerate), trying to compute scales and quantize weights will fail or register meta buffers. Furthermore, popping "weight" from _parameters will prevent the actual weights from being loaded later. Raising a clear error message when weights are on the meta device improves developer experience and prevents silent failures.

            weight = getattr(mod, "weight", None)
            if weight is None:
                continue
            if weight.device.type == "meta":
                raise ValueError(
                    "Cannot convert model to FP8 while weights are on the 'meta' device. "
                    "Please load the model weights before calling convert_model_to_fp8."
                )
            weight_local = weight.to_local() if isinstance(weight, DTensor) else weight  # type: ignore[arg-type]

Comment on lines +39 to +44
def _supports_fp8_compute() -> bool:
"""Whether the active device supports FP8 ``_scaled_mm`` (sm89+)."""
if not torch.cuda.is_available():
return False
cap = torch.cuda.get_device_capability()
return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In heterogeneous multi-GPU systems, different GPUs may have different compute capabilities (e.g., a mix of RTX 4090 and RTX 3090). Querying torch.cuda.get_device_capability() without arguments returns the capability of the current active device, which might not match the device where the tensor is actually allocated. Passing the target device as an argument ensures correct capability detection.

Suggested change
def _supports_fp8_compute() -> bool:
"""Whether the active device supports FP8 ``_scaled_mm`` (sm89+)."""
if not torch.cuda.is_available():
return False
cap = torch.cuda.get_device_capability()
return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9)
def _supports_fp8_compute(device: torch.device | None = None) -> bool:
"""Whether the active device supports FP8 ``_scaled_mm`` (sm89+)."""
if not torch.cuda.is_available():
return False
cap = torch.cuda.get_device_capability(device)
return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9)

Comment on lines +164 to +166
w_fp8 = layer._fp8_weight
w_scale = layer._fp8_weight_scale.to(x.dtype)
weight = w_fp8.to(x.dtype) * w_scale.unsqueeze(1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential device or dtype mismatch errors during fallback dequantization (especially when weights or scales are on CPU or a different GPU), explicitly cast both _fp8_weight and _fp8_weight_scale to the device and dtype of the input tensor x.

Suggested change
w_fp8 = layer._fp8_weight
w_scale = layer._fp8_weight_scale.to(x.dtype)
weight = w_fp8.to(x.dtype) * w_scale.unsqueeze(1)
w_fp8 = layer._fp8_weight.to(device=x.device, dtype=x.dtype)
w_scale = layer._fp8_weight_scale.to(device=x.device, dtype=x.dtype)
weight = w_fp8 * w_scale.unsqueeze(1)

@mergify mergify Bot added type: feat New feature or capability scope: model Model architecture (DiTs, encoders, VAEs) labels Jun 16, 2026
@kevin314 kevin314 marked this pull request as ready for review June 16, 2026 10:31
@mergify

mergify Bot commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

This PR has merge conflicts with the base branch. Please rebase:

git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease

@mergify mergify Bot added the needs-rebase PR has merge conflicts label Jun 19, 2026
Base automatically changed from pr1225_s14 to main June 19, 2026 08:03
An error occurred while trying to automatically change base from pr1225_s14 to main June 19, 2026 08:03
@mergify

mergify Bot commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@mergify mergify Bot removed the needs-rebase PR has merge conflicts label Jun 19, 2026
@mergify

mergify Bot commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

Pre-commit checks failed

Hi @kevin314, the pre-commit checks have failed. To fix them locally:

# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files

Common fixes:

  • yapf: yapf -i <file> (formatting)
  • ruff: ruff check --fix <file> (linting)
  • codespell: codespell --write-changes <file> (spelling)

After fixing, commit and push the changes. The checks will re-run automatically.

For future commits, pre-commit will run automatically on changed files before each commit.

@SolitaryThinker SolitaryThinker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kevin314 — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.

Open PR · scope: model · reviewed @ 10b4bf18. The post-rebase diff is clean (2 files, +247/-1); the FP8 QAT-training path is #1464 (already on main) and isn't re-reviewed here.

Verdict: REQUEST_CHANGES

The FP8 inference path as committed is unreachable — the weight-conversion pass is never called, so the first FP8 linear forward raises AttributeError. Also missing SSIM/quality evidence for a model-scope numerics change.

Blocking

  • BLOCKER fastvideo/layers/quantization/fp8_config.py:121 (root cause :207) — the FP8 path is dead: apply() reads a buffer nothing registers. FP8QuantizeMethod.apply opens with out_dim = layer._fp8_weight.shape[0] (and _apply_dequant at :162), but _fp8_weight / _fp8_weight_scale are registered only by convert_model_to_fp8 (:229-230), which has zero callers in the whole tree:

    • grep -rn "convert_model_to_fp8\|FP8QuantizeMethod" fastvideo/ --include=*.py matches only inside fp8_config.py (self-defs + __all__) — no caller, no dynamic dispatch.
    • The sibling NVFP4 / NVFP4-QAT methods wire their conversion into the loader at fastvideo/models/loader/fsdp_load.py:31-66 (_maybe_convert_model_to_nvfp4, invoked :206 from the inference DiT load), dispatching to convert_model_to_nvfp4 / convert_model_to_fp4. There is no FP8QuantizeMethod branch.
    • The base seam QuantizeMethodBase.process_weights_after_loading (base_config.py:41) exists but is never auto-invoked by the loader, so it isn't the implicit hook either.
    • Wan (the target) builds to_q/k/v/out + ffn via ReplicatedLinear(..., quant_config=...) (models/dits/wanvideo.py:124-127) and calls them through ReplicatedLinear.forwardquant_method.apply(self, x, bias) (linear.py:296), so the missing buffer crashes on the first projection.

    Not a corner case — it's the only runtime path: with FP8 selected, every targeted linear raises AttributeError: ... has no attribute '_fp8_weight' on the first forward. Fix: add an FP8QuantizeMethod branch to _maybe_convert_model_to_nvfp4 (or a generalized post-load hook) calling convert_model_to_fp8(model), mirroring the FP4 wiring, plus a runnable example like examples/inference/optimizations/nvfp4_qat_wan2_1_1_3b.py. (Gemini flagged the same AttributeError surface at :125 as a conditional edge case — the sharper point is that convert is invoked never, so it fails 100% of the time.)

Major

  • MAJOR no SSIM / quality evidence for a scope: model precision change. The body's Test Results are smoke-only ("End-to-end video generation completes successfully with fp8 quant") and the template "I verified SSIM regression tests pass" box is unchecked. A dynamic FP8 (e4m3) activation+weight quant on Wan's linears needs an SSIM run attached. Note: that smoke claim can't have exercised this diff given the BLOCKER above — please confirm what was actually run (possibly an earlier local version that wired the convert pass).

Minor

  • MINOR fp8_config.py:98-112,127-132 — the prequant plumbing (quantize_input / wants_prequantized_input / pre_quantized) is dead for this PR: Wan never calls quantize_input, and only LTX-2's attention forward uses that path (LTX-2 isn't wired to FP8 here). Wire a consumer or drop the surface until one exists.
  • MINOR naming — there are now three FP8-family methods (AbsMaxFP8, new FP8, fp8_qat_train). They're genuinely distinct (AbsMaxFP8 = static checkpoint scales; FP8 = dynamic scales at load+runtime; fp8_qat_train = STE training), but the bare name FP8 is easy to confuse with AbsMaxFP8 — consider fp8_dynamic + a contrasting docstring.

Nit

  • NIT fp8_config.py:188get_min_capability() returns 89 but nothing enforces it (the real gate is the runtime _supports_fp8_compute() + bf16 fallback). Consistent with the sibling configs; documentation-only.

The registry wiring (__init__.py: Literal + import + method_to_config) is internally consistent ✅. The pre-commit CI failure is a pure yapf line-wrap reflow — style is owned by pre-commit, not flagged here. Gemini's robustness nits (empty-tensor amax() guard, multi-GPU get_device_capability(device), meta-device guard, fallback device/dtype cast) are reasonable and not re-litigated.

— Gob (@SolitaryThinker's AI reviewer).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

scope: model Model architecture (DiTs, encoders, VAEs) type: feat New feature or capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants