[feat] QAD 5090: FP8 linear layer inference #1465
Code Review
This pull request introduces a generic FP8 quantization configuration and method (FP8Config and FP8QuantizeMethod) backed by torch._scaled_mm, supporting both per-tensor and per-channel granularities with a bf16 fallback for older GPUs. The review feedback focuses on making the implementation more robust: handling empty input tensors, adding a fallback for unquantized layers, preventing quantization on the meta device, detecting GPU capability correctly in multi-GPU setups, and casting weights and scales to the input device and dtype during fallback dequantization.
| def _quantize_tensorwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Returns ``(x_fp8 [M, K], x_scale [1] float32)``.""" | ||
| x_absmax = x_2d.abs().amax().float() | ||
| x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE) | ||
| x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) | ||
| return x_fp8, x_scale.view(1) |
If the input tensor x_2d is empty (e.g., due to dynamic batching or zero-length sequences), calling .amax() without specifying dimensions will raise a RuntimeError: amax(): Expected reduction dim to be non-empty. Adding a guard for empty tensors prevents runtime crashes.
| def _quantize_tensorwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Returns ``(x_fp8 [M, K], x_scale [1] float32)``.""" | |
| if x_2d.numel() == 0: | |
| return torch.empty_like(x_2d, dtype=FP8_DTYPE), torch.ones(1, dtype=torch.float32, device=x_2d.device) | |
| x_absmax = x_2d.abs().amax().float() | |
| x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE) | |
| x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) | |
| return x_fp8, x_scale.view(1) |
| def _quantize_rowwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Returns ``(x_fp8 [M, K], x_scale [M, 1] float32)``.""" | ||
| x_absmax = x_2d.abs().amax(dim=-1, keepdim=True).float() | ||
| x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE) | ||
| x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) | ||
| return x_fp8, x_scale |
If the input tensor x_2d is empty (e.g., shape [5, 0]), calling .amax(dim=-1) will raise a RuntimeError because the reduction dimension is empty. Adding a guard for empty tensors prevents runtime crashes.
| def _quantize_rowwise(x_2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Returns ``(x_fp8 [M, K], x_scale [M, 1] float32)``.""" | |
| if x_2d.numel() == 0: | |
| return torch.empty_like(x_2d, dtype=FP8_DTYPE), torch.ones(x_2d.shape[0], 1, dtype=torch.float32, device=x_2d.device) | |
| x_absmax = x_2d.abs().amax(dim=-1, keepdim=True).float() | |
| x_scale = (x_absmax / FP8_MAX).clamp(min=FP8_MIN_SCALE) | |
| x_fp8 = (x_2d / x_scale.to(x_2d.dtype)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) | |
| return x_fp8, x_scale |
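The scale arithmetic behind both guards can be checked without a GPU. The following is a minimal pure-Python sketch, not the PR's code: it uses nested lists instead of tensors, skips the actual FP8 cast, assumes FP8_DTYPE is float8_e4m3fn (largest finite value 448), and uses a made-up FP8_MIN_SCALE because the PR's constant is not shown in this thread.

```python
# Pure-Python sketch of absmax scaling; lists stand in for [M, K] tensors.
FP8_MAX = 448.0        # largest finite float8_e4m3fn value (assumed dtype)
FP8_MIN_SCALE = 1e-12  # assumed floor; the PR's real constant is not shown


def tensorwise_scale(rows: list[list[float]]) -> float:
    """One scale for the whole input (per-tensor granularity)."""
    values = [abs(v) for row in rows for v in row]
    if not values:  # empty input: nothing to reduce, return a neutral scale
        return 1.0
    return max(max(values) / FP8_MAX, FP8_MIN_SCALE)


def rowwise_scales(rows: list[list[float]]) -> list[float]:
    """One scale per row (the [M, 1] scale of the row-wise variant)."""
    if not any(len(row) for row in rows):  # mirrors the numel() == 0 guard
        return [1.0] * len(rows)
    return [max(max(abs(v) for v in row) / FP8_MAX, FP8_MIN_SCALE) for row in rows]


def scale_and_clamp(rows: list[list[float]], scales: list[float]) -> list[list[float]]:
    """Divide each row by its scale, then clamp into the representable range."""
    return [
        [min(max(v / s, -FP8_MAX), FP8_MAX) for v in row]
        for row, s in zip(rows, scales)
    ]


x = [[1.0, -2.0], [0.5, 896.0]]
per_tensor = tensorwise_scale(x)  # driven by the single largest magnitude
per_row = rowwise_scales(x)       # each row keeps its own dynamic range
```

The FP8_MIN_SCALE floor is what keeps an all-zero input from producing a zero scale and a division by zero. The empty-input branch only has to return something of the right shape, since no element is ever divided by it.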
| out_dim = layer._fp8_weight.shape[0] | ||
| original_shape = x.shape | ||
|
|
||
| if not _supports_fp8_compute(): | ||
| return self._apply_dequant(layer, x, bias) |
If the model is run before convert_model_to_fp8 is called (e.g., during initialization, shape tracking, or dry runs), layer._fp8_weight will not exist, causing an immediate AttributeError. Adding a fallback to standard unquantized linear execution when _fp8_weight is missing prevents hard crashes. Additionally, we pass x.device to _supports_fp8_compute to ensure correct capability detection on multi-GPU setups.
| if not hasattr(layer, "_fp8_weight"): | |
| return F.linear(x, layer.weight, bias) | |
| out_dim = layer._fp8_weight.shape[0] | |
| original_shape = x.shape | |
| if not _supports_fp8_compute(x.device): | |
| return self._apply_dequant(layer, x, bias) |
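The order of the two guards in this suggestion matters: the attribute check has to come first, because the capability fallback also reads `_fp8_weight`. A minimal sketch of that branch order follows, using `SimpleNamespace` objects as stand-ins for layers; `choose_path` and the returned labels are hypothetical names for illustration, not part of the PR.

```python
from types import SimpleNamespace


def choose_path(layer, fp8_supported: bool) -> str:
    """Return which branch an apply() with both guards would take."""
    if not hasattr(layer, "_fp8_weight"):
        # Conversion has not run yet: only the original weight exists.
        return "unquantized_linear"
    if not fp8_supported:
        # FP8 buffers exist, but the device cannot run the FP8 matmul.
        return "dequant_fallback"
    return "fp8_scaled_mm"


# Stand-in layers: one before conversion, one after.
unconverted = SimpleNamespace(weight=[[1.0]])
converted = SimpleNamespace(_fp8_weight=[[1.0]], _fp8_weight_scale=[1.0])

paths = [
    choose_path(unconverted, fp8_supported=True),
    choose_path(converted, fp8_supported=False),
    choose_path(converted, fp8_supported=True),
]
```

One trade-off to weigh: a silent unquantized fallback keeps dry runs working, but it can also hide a conversion pass that was never invoked, so logging a one-time warning in that branch may be worth considering.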
| weight = getattr(mod, "weight", None) | ||
| if weight is None: | ||
| continue | ||
| weight_local = weight.to_local() if isinstance(weight, DTensor) else weight # type: ignore[arg-type] |
If convert_model_to_fp8 is called while the model is still on the meta device (which is common during lazy loading or when using frameworks like Hugging Face Accelerate), trying to compute scales and quantize weights will fail or register meta buffers. Furthermore, popping "weight" from _parameters will prevent the actual weights from being loaded later. Raising a clear error message when weights are on the meta device improves developer experience and prevents silent failures.
weight = getattr(mod, "weight", None)
if weight is None:
    continue
if weight.device.type == "meta":
    raise ValueError(
        "Cannot convert model to FP8 while weights are on the 'meta' device. "
        "Please load the model weights before calling convert_model_to_fp8."
    )
weight_local = weight.to_local() if isinstance(weight, DTensor) else weight  # type: ignore[arg-type]
| def _supports_fp8_compute() -> bool: | ||
| """Whether the active device supports FP8 ``_scaled_mm`` (sm89+).""" | ||
| if not torch.cuda.is_available(): | ||
| return False | ||
| cap = torch.cuda.get_device_capability() | ||
| return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9) |
In heterogeneous multi-GPU systems, different GPUs may have different compute capabilities (e.g., a mix of RTX 4090 and RTX 3090). Querying torch.cuda.get_device_capability() without arguments returns the capability of the current active device, which might not match the device where the tensor is actually allocated. Passing the target device as an argument ensures correct capability detection.
| def _supports_fp8_compute(device: torch.device | None = None) -> bool: | |
| """Whether the active device supports FP8 ``_scaled_mm`` (sm89+).""" | |
| if not torch.cuda.is_available(): | |
| return False | |
| cap = torch.cuda.get_device_capability(device) | |
| return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9) |
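The sm89 predicate itself is easy to sanity-check in isolation. The sketch below takes a `(major, minor)` pair directly instead of querying CUDA, so it runs anywhere. The helper names are illustrative, and the capability table lists the published compute capabilities of a few GPUs rather than anything taken from this PR.

```python
def supports_fp8_explicit(cap: tuple[int, int]) -> bool:
    """The predicate as written in the reviewed code."""
    return cap[0] > 8 or (cap[0] == 8 and cap[1] >= 9)


def supports_fp8(cap: tuple[int, int]) -> bool:
    """Equivalent form: Python compares tuples lexicographically."""
    return cap >= (8, 9)


# (major, minor) compute capabilities of a few NVIDIA GPUs.
CAPABILITIES = {
    "RTX 3090": (8, 6),
    "RTX 4090": (8, 9),
    "H100": (9, 0),
    "RTX 5090": (12, 0),
}

# In a mixed 3090 + 4090 box, the answer depends on which device is asked.
per_device = {name: supports_fp8(cap) for name, cap in CAPABILITIES.items()}
```

This is why the device argument matters in the mixed RTX 4090 / RTX 3090 case described above: the predicate is the same, but its input must come from the device that holds the tensor.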
| w_fp8 = layer._fp8_weight | ||
| w_scale = layer._fp8_weight_scale.to(x.dtype) | ||
| weight = w_fp8.to(x.dtype) * w_scale.unsqueeze(1) |
To prevent potential device or dtype mismatch errors during fallback dequantization (especially when weights or scales are on CPU or a different GPU), explicitly cast both _fp8_weight and _fp8_weight_scale to the device and dtype of the input tensor x.
| w_fp8 = layer._fp8_weight.to(device=x.device, dtype=x.dtype) | |
| w_scale = layer._fp8_weight_scale.to(device=x.device, dtype=x.dtype) | |
| weight = w_fp8 * w_scale.unsqueeze(1) |
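For readers unfamiliar with the `unsqueeze(1)` step, here is a pure-Python sketch of what the fallback computes, assuming the weight is laid out as [N, K] with one scale per output row (or a single scale for per-tensor granularity). The function names are illustrative, and plain floats stand in for FP8 values.

```python
from typing import Optional


def dequantize_rows(w_q: list[list[float]], w_scale: list[float]) -> list[list[float]]:
    """weight[n][k] = w_q[n][k] * w_scale[n]; a single scale is broadcast to all rows."""
    scales = w_scale * len(w_q) if len(w_scale) == 1 else w_scale
    if len(scales) != len(w_q):
        raise ValueError("expected one scale, or one scale per output row")
    return [[q * s for q in row] for row, s in zip(w_q, scales)]


def linear(x: list[float], weight: list[list[float]],
           bias: Optional[list[float]] = None) -> list[float]:
    """y = weight @ x + bias for a single input vector x of length K."""
    out = [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


w_q = [[448.0, -224.0], [112.0, 0.0]]  # stand-in for the stored FP8 values
w_scale = [0.5, 2.0]                   # one scale per output row
weight = dequantize_rows(w_q, w_scale)
y = linear([1.0, 1.0], weight, bias=[1.0, -1.0])
```

The fallback therefore rebuilds a full-precision weight on every call, which is why the buffers need to sit on the same device and in the same dtype as the input.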
dc6e882 to 7d3a3b2
This PR has merge conflicts with the base branch. Please rebase:
git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🔴 PR merge requirements
Waiting for
This rule is failing.
dffb633 to 10b4bf1
Pre-commit checks failed
Hi @kevin314, the pre-commit checks have failed. To fix them locally:
# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-files
Common fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits,
SolitaryThinker left a comment
Hi @kevin314 — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.
Open PR · scope: model · reviewed @ 10b4bf18. The post-rebase diff is clean (2 files, +247/-1); the FP8 QAT-training path is #1464 (already on main) and isn't re-reviewed here.
Verdict: REQUEST_CHANGES
The FP8 inference path as committed is unreachable — the weight-conversion pass is never called, so the first FP8 linear forward raises AttributeError. Also missing SSIM/quality evidence for a model-scope numerics change.
Blocking
- BLOCKER fastvideo/layers/quantization/fp8_config.py:121 (root cause :207). The FP8 path is dead: apply() reads a buffer nothing registers. FP8QuantizeMethod.apply opens with out_dim = layer._fp8_weight.shape[0] (and _apply_dequant at :162), but _fp8_weight / _fp8_weight_scale are registered only by convert_model_to_fp8 (:229-230), which has zero callers in the whole tree:
  - grep -rn "convert_model_to_fp8\|FP8QuantizeMethod" fastvideo/ --include=*.py matches only inside fp8_config.py (self-defs + __all__): no caller, no dynamic dispatch.
  - The sibling NVFP4 / NVFP4-QAT methods wire their conversion into the loader at fastvideo/models/loader/fsdp_load.py:31-66 (_maybe_convert_model_to_nvfp4, invoked :206 from the inference DiT load), dispatching to convert_model_to_nvfp4 / convert_model_to_fp4. There is no FP8QuantizeMethod branch.
  - The base seam QuantizeMethodBase.process_weights_after_loading (base_config.py:41) exists but is never auto-invoked by the loader, so it isn't the implicit hook either.
  - Wan (the target) builds to_q/k/v/out + ffn via ReplicatedLinear(..., quant_config=...) (models/dits/wanvideo.py:124-127) and calls them through ReplicatedLinear.forward → quant_method.apply(self, x, bias) (linear.py:296), so the missing buffer crashes on the first projection.
  Not a corner case; it's the only runtime path: with FP8 selected, every targeted linear raises AttributeError: ... has no attribute '_fp8_weight' on the first forward. Fix: add an FP8QuantizeMethod branch to _maybe_convert_model_to_nvfp4 (or a generalized post-load hook) calling convert_model_to_fp8(model), mirroring the FP4 wiring, plus a runnable example like examples/inference/optimizations/nvfp4_qat_wan2_1_1_3b.py. (Gemini flagged the same AttributeError surface at :125 as a conditional edge case; the sharper point is that convert is never invoked, so it fails 100% of the time.)
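One possible shape for the "generalized post-load hook" option is a small registry keyed by quantize-method type. The sketch below is an assumption about how that could look, not FastVideo's actual loader code: the classes and converters are stubs defined here for illustration (the real `convert_model_to_fp8` and `convert_model_to_nvfp4` live in the repository and operate on modules, not dicts), `NVFP4QuantizeMethod` is a hypothetical class name, and `maybe_convert_model` and `POST_LOAD_CONVERTERS` are invented names.

```python
from typing import Callable


# Stub stand-ins for the real quantize-method classes (illustration only).
class NVFP4QuantizeMethod: ...
class FP8QuantizeMethod: ...
class UnquantizedMethod: ...


def convert_model_to_nvfp4(model: dict) -> None:
    model["converted_by"] = "nvfp4"  # placeholder for the real conversion


def convert_model_to_fp8(model: dict) -> None:
    model["converted_by"] = "fp8"    # placeholder for the real conversion


# One table entry per method that needs a weight-conversion pass after load.
POST_LOAD_CONVERTERS: dict[type, Callable[[dict], None]] = {
    NVFP4QuantizeMethod: convert_model_to_nvfp4,
    FP8QuantizeMethod: convert_model_to_fp8,
}


def maybe_convert_model(model: dict, quant_method: object) -> bool:
    """Run the registered converter, if any; report whether one ran."""
    converter = POST_LOAD_CONVERTERS.get(type(quant_method))
    if converter is None:
        return False
    converter(model)
    return True


model: dict = {}
ran = maybe_convert_model(model, FP8QuantizeMethod())
```

With a table like this, adding a new quantization method means adding one entry rather than another branch in the loader, which makes a missing conversion pass harder to overlook.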
Major
- MAJOR no SSIM / quality evidence for a scope: model precision change. The body's Test Results are smoke-only ("End-to-end video generation completes successfully with fp8 quant") and the template "I verified SSIM regression tests pass" box is unchecked. A dynamic FP8 (e4m3) activation+weight quant on Wan's linears needs an SSIM run attached. Note: that smoke claim can't have exercised this diff given the BLOCKER above; please confirm what was actually run (possibly an earlier local version that wired the convert pass).
Minor
- MINOR fp8_config.py:98-112,127-132: the prequant plumbing (quantize_input / wants_prequantized_input / pre_quantized) is dead for this PR: Wan never calls quantize_input, and only LTX-2's attention forward uses that path (LTX-2 isn't wired to FP8 here). Wire a consumer or drop the surface until one exists.
- MINOR naming: there are now three FP8-family methods (AbsMaxFP8, new FP8, fp8_qat_train). They're genuinely distinct (AbsMaxFP8 = static checkpoint scales; FP8 = dynamic scales at load+runtime; fp8_qat_train = STE training), but the bare name FP8 is easy to confuse with AbsMaxFP8; consider fp8_dynamic + a contrasting docstring.
Nit
- NIT fp8_config.py:188: get_min_capability() returns 89 but nothing enforces it (the real gate is the runtime _supports_fp8_compute() + bf16 fallback). Consistent with the sibling configs; documentation-only.
The registry wiring (__init__.py: Literal + import + method_to_config) is internally consistent ✅. The pre-commit CI failure is a pure yapf line-wrap reflow — style is owned by pre-commit, not flagged here. Gemini's robustness nits (empty-tensor amax() guard, multi-GPU get_device_capability(device), meta-device guard, fallback device/dtype cast) are reasonable and not re-litigated.
— Gob (@SolitaryThinker's AI reviewer).
Purpose
Continues the QAD 5090 stack (#1464). Adds a generic FP8 inference quantization path for DiT linear layers (attention projections + MLP), so models can be served in FP8 after training.
Changes
layers/quantization/fp8_config.py: a new FP8Config and FP8QuantizeMethod that matches Wan's to_q/k/v/out + ffn.fc_in/fc_out layers by suffix
layers/quantization/__init__.py: registers FP8 in QuantizationMethods and wires FP8Config into the quant registry.
Test Results (sm89 / RTX 4090) and RTX 5090
Test output
Checklist
pre-commit run --all-files and fixed all issues
For model/pipeline changes, also check: