[C] NVFP4 quantization for GroupedTensor #2655
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview

Greptile Summary

This PR adds NVFP4 quantization support for GroupedTensor.

Key Changes
Implementation Details

The graph-safe variants differ from existing implementations by:
The fusion kernel (
Notes
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant API as nvte_group_hadamard_transform_*_graph_safe
participant Wrapper as GroupedTensorWrapper
participant Kernel as CUDA Kernel (SM100+)
participant Device as Device Memory
User->>Wrapper: Create GroupedTensorWrapper
Wrapper->>Wrapper: nvte_create_grouped_tensor()
User->>Wrapper: set_rowwise_data(), set_columnwise_data(), etc.
Wrapper->>Device: Set device pointers for tensor params
User->>API: nvte_group_hadamard_transform_amax_graph_safe()
API->>API: Convert NVTEGroupedTensor
API->>API: Validate num_tensors > 0
alt NVFP4 Quantization with Fusion
User->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
API->>Kernel: group_row_col_rht_gemm_ntt_w_sfc_graph_safe()
Kernel->>Device: TMA load input tensors
Kernel->>Kernel: Hadamard transform (RHT)
Kernel->>Kernel: NVFP4 quantization (stochastic rounding)
Kernel->>Device: Store rowwise/columnwise quantized data
Kernel->>Device: Store scale factors and amax
else Simple Amax Computation
API->>Kernel: GraphSafeGroupHadamardAmaxTmaKernel()
Kernel->>Device: TMA load input tensors
Kernel->>Kernel: Compute amax (with/without RHT)
Kernel->>Device: Store rowwise/columnwise amax
end
Kernel-->>API: Kernel completion
API-->>User: Return
```
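The sketch below mirrors the call flow in the diagram with plain C++ stand-ins. The struct and function names here are placeholders for illustration only; they are not the real NVTE / GroupedTensorWrapper API, and the actual signatures may differ.

```cpp
// Self-contained, hypothetical sketch of the diagram's flow: build a grouped
// tensor, attach per-tensor data pointers, then call a graph-safe entry point
// that validates the group before launching the amax / RHT kernel.
#include <cstdio>
#include <vector>

struct FakeGroupedTensor {
  std::vector<const float *> rowwise_data;  // device pointers in the real API
  int num_tensors = 0;
};

// Stand-in for nvte_group_hadamard_transform_amax_graph_safe(): validates the
// group (num_tensors > 0) and would dispatch the SM100+ kernel.
void fake_group_hadamard_amax(const FakeGroupedTensor &t) {
  if (t.num_tensors <= 0) {
    std::puts("error: empty tensor group");
    return;
  }
  std::printf("launching amax kernel over %d tensors\n", t.num_tensors);
}

int main() {
  FakeGroupedTensor group;                     // "Create GroupedTensorWrapper"
  std::vector<float> a(16), b(32);
  group.rowwise_data = {a.data(), b.data()};   // "set_rowwise_data()"
  group.num_tensors = 2;
  fake_group_hadamard_amax(group);             // graph-safe entry point
}
```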
```cpp
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
```
Multiple TODO comments request logic verification in critical group-index calculation and tensor-layout code. Verify that the group_idx calculation and the tensor layout logic are correct before merging.
Suggested change:

```diff
-  // TODO(zhongbo): double check the logic here
+  // Determine the current tensor group index based on tile offset
   int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                         (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                         packed_N, M, offsets);
   // Determine quantization scale factor layouts/output splits for this group
   TSFDLayout sfd_layout;
   int cur_N = static_cast<int>(first_dims[group_idx]);
   if constexpr (kEnableSwizzleSFOutput) {
     sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
   } else {
     sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                              make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
   }
   // Build output tensors for columns and their quant scales
-  // TODO(zhongbo): double check the logic here
   Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                               reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                           make_shape(M, cur_N), DStride{});  // (M,packed_N)
```
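For context on what a group-index lookup of this kind typically does, here is a minimal, self-contained sketch (an assumption, not the actual get_current_tensor_id helper): it maps a flattened element offset to a group index via cumulative per-tensor offsets, whereas the real helper also accounts for the packed_N / M tiling and the shape representation.

```cpp
// Hypothetical sketch of mapping a flattened offset to its tensor group.
// Assumes offsets has num_tensors + 1 entries, with offsets[i] being the
// starting element offset of tensor i and the last entry the total count.
#include <cassert>
#include <vector>

int find_group_index(long long flat_offset, const std::vector<long long> &offsets,
                     int num_tensors) {
  for (int i = 0; i < num_tensors; ++i) {
    if (flat_offset < offsets[i + 1]) return i;  // first group whose range contains the offset
  }
  return num_tensors - 1;  // clamp to the last group
}

int main() {
  std::vector<long long> offsets = {0, 1024, 1536, 4096};  // 3 tensors
  assert(find_group_index(0, offsets, 3) == 0);
  assert(find_group_index(1200, offsets, 3) == 1);
  assert(find_group_index(4000, offsets, 3) == 2);
}
```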
```cpp
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
  group_idx = cur_group_idx;
  c_global_amax_val = shared_storage.global_d_amax[group_idx];
  // update amax
  global_encode_scale = c_global_amax_val > 0.0f
                            ? cutlass::minimum_with_nan_propagation<float>{}(
                                  (fp8_max * fp4_max) / c_global_amax_val,
                                  cutlass::platform::numeric_limits<float>::max())
                            : 1.0f;
  global_decode_scale = 1.0f / global_encode_scale;
  if constexpr (kUseFastMath) {
    global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
  }
  // TODO(zhongbo): double check the logic here
```
More TODO comments in the epilogue loop; verify the group-index recalculation and the amax scaling logic.
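For reference, the per-group scaling logic in the hunk above can be restated in plain C++ as below. This is a sketch only: the constants assume FP8 E4M3 (max 448) and FP4 E2M1 (max 6), which is the usual NVFP4 convention but is not confirmed by this diff, and the real kernel uses a NaN-propagating minimum rather than std::min.

```cpp
// Sketch of the per-group encode/decode scale update when the epilogue
// crosses into a new tensor group: scale so that amax maps onto the
// representable range of (FP8 scale factor) x (FP4 value).
#include <algorithm>
#include <cstdio>
#include <limits>

struct GroupScales { float encode; float decode; };

GroupScales compute_group_scales(float group_amax) {
  constexpr float fp8_max = 448.0f;  // assumed E4M3 max
  constexpr float fp4_max = 6.0f;    // assumed E2M1 max
  float encode = group_amax > 0.0f
                     ? std::min((fp8_max * fp4_max) / group_amax,
                                std::numeric_limits<float>::max())
                     : 1.0f;  // amax == 0: fall back to identity scaling
  return {encode, 1.0f / encode};
}

int main() {
  GroupScales s = compute_group_scales(12.5f);
  std::printf("encode=%g decode=%g\n", s.encode, s.decode);
}
```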
```cpp
using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
```
Add the input sanity checks noted in the TODO before merging.
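As a rough illustration of the kind of validation the TODO refers to (hypothetical; the real code would check the NVTEGroupedTensor metadata and follow the project's error-checking conventions):

```cpp
// Hypothetical sketch of input sanity checks for a grouped-tensor entry point.
// The struct and field names are illustrative placeholders, not the real API.
#include <stdexcept>

struct FakeGroupedInput {
  void *base_ptr = nullptr;
  int num_tensors = 0;
};

void validate_grouped_input(const FakeGroupedInput &in) {
  if (in.base_ptr == nullptr) throw std::invalid_argument("input data pointer is null");
  if (in.num_tensors <= 0) throw std::invalid_argument("num_tensors must be > 0");
}

int main() {
  int dummy = 0;
  FakeGroupedInput in{&dummy, 2};
  validate_grouped_input(in);  // throws on invalid metadata
}
```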
Description
Pieces taken from #2600.
Type of change
Changes
Checklist: