[C] NVFP4 quantization for GroupedTensor#2655

Draft
ksivaman wants to merge 1 commit into NVIDIA:main from ksivaman:nvfp4_grouped_quantize

Conversation

@ksivaman
Member

@ksivaman ksivaman commented Feb 6, 2026

Description

Pieces taken from #2600.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • NVFP4 quantization for grouped tensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
@ksivaman ksivaman added the MoE label Feb 6, 2026
@ksivaman ksivaman marked this pull request as draft February 6, 2026 06:38
@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds NVFP4 quantization support for GroupedTensor with graph-safe Hadamard transform operations. The implementation introduces new CUDA kernels that support device-managed tensor grouping, enabling CUDA graph capture.

Key Changes

  • New graph-safe APIs: Added nvte_group_hadamard_transform_amax_graph_safe() and nvte_group_hadamard_transform_cast_fusion_graph_safe() that accept NVTEGroupedTensor instead of host-side split info
  • GroupedTensorWrapper class: New C++ wrapper with fluent API for setting/getting tensor parameters (rowwise/columnwise data, scales, amax, etc.)
  • SM100+ kernels: Two new CUDA files implementing TMA-based kernels for graph-safe grouped hadamard transforms with NVFP4 quantization
  • Build integration: Added new source files to CMakeLists.txt arch-specific sources

Implementation Details

The graph-safe variants differ from existing implementations by:

  • Accepting grouped tensor metadata on device (offsets, dimensions) instead of host arrays
  • Using device-side binary search (get_current_tensor_id) to determine tensor boundaries
  • Supporting CUDA graph capture since tensor split info is not host-dependent

The fusion kernel (graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu) combines:

  1. Hadamard transform (RHT)
  2. Row-wise and column-wise NVFP4 quantization
  3. Scale factor computation with optional stochastic rounding

Notes

  • Multiple TODO comments exist in the fusion kernel requesting logic verification (lines 709, 724, 778, 795, 1352)
  • Requires CUDA 12.8+ and SM100+ architecture
  • Only supports constant last dimension currently (checked at runtime)
  • Functionality marked as incomplete in PR checklist

Confidence Score: 3/5

  • This PR is moderately safe to merge with caveats - the functionality is incomplete and contains unverified logic
  • Score reflects incomplete functionality (per PR checklist), multiple unresolved TODOs in critical quantization logic, and lack of tests. The code appears well-structured and follows existing patterns, but the TODOs requesting logic verification in tensor indexing and quantization paths are concerning for correctness.
  • Pay close attention to graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu - contains multiple TODO comments for logic verification in group index calculations and tensor layout code

Important Files Changed

Filename Overview
transformer_engine/common/include/transformer_engine/transformer_engine.h Added GroupedTensorWrapper C++ class with parameter setters/getters and moved deprecated enum documentation
transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu New graph-safe grouped hadamard transform implementation using TMA with CUDA 12.8+ and SM100+ requirements
transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu New graph-safe NVFP4 quantization with hadamard transform fusion, includes multiple TODOs for logic verification

Sequence Diagram

sequenceDiagram
    participant User
    participant API as nvte_group_hadamard_transform_*_graph_safe
    participant Wrapper as GroupedTensorWrapper
    participant Kernel as CUDA Kernel (SM100+)
    participant Device as Device Memory

    User->>Wrapper: Create GroupedTensorWrapper
    Wrapper->>Wrapper: nvte_create_grouped_tensor()
    User->>Wrapper: set_rowwise_data(), set_columnwise_data(), etc.
    Wrapper->>Device: Set device pointers for tensor params
    
    User->>API: nvte_group_hadamard_transform_amax_graph_safe()
    API->>API: Convert NVTEGroupedTensor
    API->>API: Validate num_tensors > 0
    
    alt NVFP4 Quantization with Fusion
        User->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
        API->>Kernel: group_row_col_rht_gemm_ntt_w_sfc_graph_safe()
        Kernel->>Device: TMA load input tensors
        Kernel->>Kernel: Hadamard transform (RHT)
        Kernel->>Kernel: NVFP4 quantization (stochastic rounding)
        Kernel->>Device: Store rowwise/columnwise quantized data
        Kernel->>Device: Store scale factors and amax
    else Simple Amax Computation
        API->>Kernel: GraphSafeGroupHadamardAmaxTmaKernel()
        Kernel->>Device: TMA load input tensors
        Kernel->>Kernel: Compute amax (with/without RHT)
        Kernel->>Device: Store rowwise/columnwise amax
    end
    
    Kernel-->>API: Kernel completion
    API-->>User: Return


@greptile-apps greptile-apps bot left a comment


3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +709 to +724

    // TODO(zhongbo): double check the logic here
    int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                          packed_N, M, offsets);

    // Determine quantization scale factor layouts/output splits for this group
    TSFDLayout sfd_layout;
    int cur_N = static_cast<int>(first_dims[group_idx]);
    if constexpr (kEnableSwizzleSFOutput) {
      sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
    } else {
      sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                               make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
    }
    // Build output tensors for columns and their quant scales
    // TODO(zhongbo): double check the logic here
Multiple TODO comments request logic verification in critical group index calculation and tensor layout code. Verify that the group_idx calculation and tensor layout logic are correct before merging.

Suggested change

Before:

    // TODO(zhongbo): double check the logic here
    int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                          packed_N, M, offsets);
    // Determine quantization scale factor layouts/output splits for this group
    TSFDLayout sfd_layout;
    int cur_N = static_cast<int>(first_dims[group_idx]);
    if constexpr (kEnableSwizzleSFOutput) {
      sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
    } else {
      sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                               make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
    }
    // Build output tensors for columns and their quant scales
    // TODO(zhongbo): double check the logic here

After:

    // Determine the current tensor group index based on tile offset
    int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                          packed_N, M, offsets);
    // Determine quantization scale factor layouts/output splits for this group
    TSFDLayout sfd_layout;
    int cur_N = static_cast<int>(first_dims[group_idx]);
    if constexpr (kEnableSwizzleSFOutput) {
      sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
    } else {
      sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                               make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
    }
    // Build output tensors for columns and their quant scales
    Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                                reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                            make_shape(M, cur_N), DStride{});  // (M,packed_N)

Comment on lines +778 to +795

    // TODO(zhongbo): double check the logic here
    int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                              global_tile_n_offset * M, packed_N, M, offsets);

    if (cur_group_idx != group_idx) {
      group_idx = cur_group_idx;
      c_global_amax_val = shared_storage.global_d_amax[group_idx];
      // update amax
      global_encode_scale = c_global_amax_val > 0.0f
                                ? cutlass::minimum_with_nan_propagation<float>{}(
                                      (fp8_max * fp4_max) / c_global_amax_val,
                                      cutlass::platform::numeric_limits<float>::max())
                                : 1.0f;
      global_decode_scale = 1.0f / global_encode_scale;
      if constexpr (kUseFastMath) {
        global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
      }
      // TODO(zhongbo): double check the logic here
More TODO comments in the epilogue loop: verify the group index recalculation and amax scaling logic.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

    using transformer_engine::detail::ShapeRepresentation;

    void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
    // TODO(zhongbo): add input sanity checks here
Add input sanity checks, as noted in the TODO.
