
Conversation

@ksivaman
Member

Description

#2388 introduced the GroupedTensor class in the core library. This PR partially integrates that functionality into the PyTorch bindings.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Expose a python GroupedTensor class.
  • Integrate GroupedTensor into GroupedLinear such that the parameters are contiguous.
  • Expose a C++ grouped_quantize API to Python, similar to split_quantize, which returns a quantized GroupedTensor that can be directly consumed by the GEMMs ([common] Add support for cuBLASLt GEMM for GroupedTensor #2502); a usage sketch follows below.
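
A rough usage sketch of the exposed pieces. The call pattern is taken from the test snippets and sequence diagram further down this thread; the exact argument names and order may differ, so treat this as illustrative only:

```python
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor

num_tensors = 4
shapes = [(512, 512) for _ in range(num_tensors)]

# Passing no quantizers takes the no-quantization path shown later in this thread;
# passing per-tensor quantizers instead yields a quantized GroupedTensor.
grouped = GroupedTensor.make_grouped_tensor(num_tensors, shapes, None)

# Per-tensor views that share the underlying contiguous storage.
members = grouped.split_into_quantized_tensors()
```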

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman marked this pull request as draft January 15, 2026 14:58
@ksivaman ksivaman requested a review from ptrendx January 15, 2026 14:58

greptile-apps bot commented Jan 15, 2026

Greptile Summary

This PR integrates the GroupedTensor class from #2388 into PyTorch bindings, enabling contiguous memory storage for multiple weight tensors in GroupedLinear.

Key Changes

  • New GroupedTensor class (918 lines): Stores multiple tensors with different shapes in contiguous memory, supporting all quantization recipes (FP8, MXFP8, NVFP4, block scaling)
  • GroupedLinear integration: Added make_grouped_weights() method that converts individual weight parameters into views of a single contiguous GroupedTensor storage
  • Recipe API refactoring: Changed type-checking methods from instance methods to classmethods (isinstance → issubclass) to align with _get_compatible_recipe() returning class types; see the sketch after this list
  • Quantizer enhancements: Added get_columnwise_shape() and get_scale_shape() methods for proper memory layout calculations
  • Comprehensive tests: 430-line test suite verifying contiguous memory layout and quantization correctness across all recipes
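
A minimal sketch of the classmethod pattern described in the recipe bullet above; the class and method names below are simplified stand-ins for the code in transformer_engine/common/recipe/__init__.py, not the exact implementation:

```python
class Recipe:
    @classmethod
    def mxfp8(cls) -> bool:
        # issubclass works whether the caller has a class or an instance,
        # since cls is always the class object for a classmethod.
        return issubclass(cls, MXFP8BlockScaling)


class MXFP8BlockScaling(Recipe):
    pass


recipe_cls = MXFP8BlockScaling      # e.g. what _get_compatible_recipe() returns
assert recipe_cls.mxfp8()           # works on the class type
assert MXFP8BlockScaling().mxfp8()  # still works when called on an instance
```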

Implementation Notes

The implementation allocates all weight data in a single contiguous buffer, then creates individual parameter views that share the underlying storage. This improves memory locality and enables future optimizations like grouped GEMMs (#2502).
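
A minimal plain-PyTorch sketch of the storage idea (shapes are illustrative, not from the PR): one flat buffer is allocated up front, and each parameter is a reshaped view into it.

```python
import torch

shapes = [(256, 512), (512, 512), (768, 512)]
total = sum(rows * cols for rows, cols in shapes)
buffer = torch.empty(total)  # single contiguous allocation

views, offset = [], 0
for rows, cols in shapes:
    n = rows * cols
    views.append(buffer[offset:offset + n].view(rows, cols))  # shares buffer's storage
    offset += n

# All views alias the same underlying storage.
assert all(v.untyped_storage().data_ptr() == buffer.untyped_storage().data_ptr() for v in views)
```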

Confidence Score: 4/5

  • This PR is safe to merge with minor caveats that should be verified through testing
  • The implementation is well-designed and comprehensive, with extensive tests covering all quantization recipes. The core GroupedTensor logic is sound, and the integration into GroupedLinear follows established patterns. However, there are two acknowledged areas needing verification: (1) the copy operation from regular tensors to quantized tensors in make_grouped_weights() has a TODO comment indicating uncertainty about correctness across all recipes, and (2) the assumption that all quantizers in a group are "effectively the same" is not strongly enforced. The recipe API change from instance methods to classmethods is correct but represents a subtle behavioral change that could affect code calling these methods on instances.
  • Pay close attention to transformer_engine/pytorch/module/grouped_linear.py (copy operation in make_grouped_weights) and transformer_engine/common/recipe/__init__.py (instance method to classmethod change)

Important Files Changed

Filename - Overview
transformer_engine/pytorch/tensor/storage/grouped_tensor.py - New 918-line file implementing the GroupedTensor class for contiguous storage of multiple tensors with different shapes. Comprehensive implementation with quantization support for the FP8, MXFP8, NVFP4, and block scaling recipes.
transformer_engine/pytorch/module/grouped_linear.py - Added the make_grouped_weights() method to convert weight parameters into contiguous GroupedTensor storage. Weights are copied and re-registered to share the underlying storage.
transformer_engine/common/recipe/__init__.py - Changed recipe type-checking methods from instance methods to classmethods, using issubclass() instead of isinstance(). This aligns with _get_compatible_recipe() returning class types.
tests/pytorch/test_grouped_tensor.py - New comprehensive test file with 430 lines covering GroupedTensor construction, splitting, quantization for all supported recipes, and verification of contiguous memory layout.

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear
    participant GroupedTensor
    participant Quantizer
    participant Storage

    Note over User,Storage: Initialization Phase
    User->>GroupedLinear: __init__(num_gemms, in_features, out_features)
    GroupedLinear->>GroupedLinear: register_parameter(weight0...weightN)
    GroupedLinear->>GroupedLinear: reset_parameters()
    GroupedLinear->>GroupedLinear: make_grouped_weights()
    
    Note over GroupedLinear,Storage: Weight Consolidation
    GroupedLinear->>Quantizer: _get_weight_quantizers()
    Quantizer-->>GroupedLinear: [quantizer0...quantizerN]
    GroupedLinear->>GroupedTensor: make_grouped_tensor(num_tensors, shapes, quantizers)
    
    Note over GroupedTensor,Storage: Allocate Contiguous Storage
    GroupedTensor->>GroupedTensor: analyze shape patterns
    GroupedTensor->>GroupedTensor: calculate logical_shape, offsets
    GroupedTensor->>Storage: allocate contiguous buffers (data, scale_inv, etc)
    GroupedTensor->>GroupedTensor: split_into_quantized_tensors()
    GroupedTensor-->>GroupedLinear: grouped_weights with quantized_tensors
    
    Note over GroupedLinear: Copy & Re-register Weights
    loop for each weight i
        GroupedLinear->>GroupedTensor: quantized_tensors[i].copy_(weights[i])
        GroupedLinear->>GroupedLinear: register_parameter(weightI, quantized_tensors[i])
    end
    
    Note over User,Storage: Forward Pass
    User->>GroupedLinear: forward(inp, m_splits)
    GroupedLinear->>GroupedLinear: _get_weight_tensors()
    GroupedLinear->>GroupedLinear: prepare quantizers
    GroupedLinear->>GroupedLinear: _GroupedLinear.apply()
    Note over GroupedLinear: All weights share contiguous storage
    GroupedLinear->>GroupedLinear: general_grouped_gemm(weights, inputs)
    GroupedLinear-->>User: output tensor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 2 comments


Comment on lines 771 to 774
# TODO(ksivamani): Verify correctness of copy for all recipes.
with torch.no_grad():
    for i in range(self.num_gemms):
        grouped_weights.quantized_tensors[i].copy_(weights[i])

style: check that the copy operation works correctly for all quantization recipes (FP8, MXFP8, NVFP4, block scaling). the TODO comment on line 771 acknowledges this needs verification.
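
A hedged sketch of the kind of round-trip check being requested here; the helper name and tolerances are assumptions, not part of the PR:

```python
import torch

def check_copy_roundtrip(grouped_weights, weights, atol=0.1, rtol=0.1):
    """Dequantizing each quantized view should approximately reproduce the original weight."""
    for qt, w in zip(grouped_weights.quantized_tensors, weights):
        torch.testing.assert_close(qt.dequantize().float(), w.float(), atol=atol, rtol=rtol)
```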

Comment on lines 382 to 386
# TODO(ksivaman): (Do we need multiple quantizers?)
# Current implementation assumes all tensors have different quantizer
# instances but effectively the same quantizer.
rowwise_usage = quantizers[0].rowwise_usage if not no_quantization else True
columnwise_usage = quantizers[0].columnwise_usage if not no_quantization else False

style: check that all quantizers in the group are compatible. the comment acknowledges uncertainty about whether multiple quantizers are needed, but the implementation assumes they're "effectively the same" - mixed quantization schemes could cause issues.
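
A possible guard for the concern above, using only the usage flags visible in the quoted snippet (a sketch, not code from the PR):

```python
def check_quantizers_compatible(quantizers):
    """Reject groups that mix quantization schemes or usage flags."""
    ref = quantizers[0]
    for q in quantizers[1:]:
        assert type(q) is type(ref), "mixed quantizer types in one GroupedTensor"
        assert q.rowwise_usage == ref.rowwise_usage
        assert q.columnwise_usage == ref.columnwise_usage
```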

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
@ksivaman ksivaman force-pushed the grouped_tensor_python branch from 2b7ea40 to 40c619e on January 16, 2026 07:36
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizers(quantization, num_tensors, shape)

grouped_tensor = GroupedTensor.make_grouped_tensor(
@zhongbozhu zhongbozhu Jan 21, 2026

This is not graph-safe for MoE activation inputs. Maybe we should make a make_grouped_tensor_graph_safe API?

def __init__(
    self,
    num_tensors: int,
    shape: List[Tuple[int, int]],
@zhongbozhu zhongbozhu Jan 21, 2026

We cannot do this; shape is supposed to be a tensor on device, and we should have a tiny kernel that computes the offsets.
Edit: It's okay for weights, where shapes are static, but for MoE activations this needs to be changed.
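
A sketch of what the device-side offset computation could look like once shapes live on the GPU, using torch.cumsum as a stand-in for the "tiny kernel" mentioned above (names and values are illustrative):

```python
import torch

# Shapes as an on-device [num_tensors, 2] tensor, e.g. produced by the MoE router.
shapes = torch.tensor([[256, 512], [512, 512], [768, 512]], device="cuda")
numels = shapes.prod(dim=1)
# Element offset of tensor i into the flat buffer = cumulative numel of tensors 0..i-1.
tensor_offsets = torch.cat([numels.new_zeros(1), torch.cumsum(numels, dim=0)[:-1]])
```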


no_quantization = quantizers is None or len(quantizers) == 0 or quantizers[0] is None

# TODO(ksivaman): (Do we need multiple quantizers?)
Collaborator

I think we shouldn't have multiple quantizers. Multiple quantizers only make sense for recipes like delayed scaling, where the quantizer holds amax values. For blockwise current-scaling recipes like MXFP8 and NVFP4, the quantizer only holds configuration such as whether we are doing RHT, so a single quantizer is enough, and it also reduces CPU overhead since pybind doesn't need to convert multiple objects.

Having separate quantizers also doesn't make sense for CUDA graphs: multiple quantizers implicitly assume splitting the input, which in turn assumes the host knows the input shape for each expert, and that breaks CUDA graph capture.

Member

Agree. One customer has reviewed this PR and commented that having multiple quantizers would introduce CPU overhead when calling quantizer.update_usage/set_usage/... if the number of tensors is very large, for example in FSDP.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
if defer_init:
    return

weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
Collaborator

How about the case where we do FP8 param gather? Each weight in this case is already a Float8Tensor.

/*! \enum Float8BlockScaleTensorFormat
 *  \brief Data format for an FP8 block-scaled tensor
 */
enum class Float8BlockScaleTensorFormat {
Collaborator

This flag has been deprecated. Also, since there is no TP for grouped tensors, the format is always GEMM_READY.

ksivaman and others added 7 commits February 3, 2026 00:20
* changes for pytorch extension; but everything seems to be broken, probably unrelated to my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix the issues

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* comment nvte API since Oleg's PR is not merged

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* test for all cases:

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* tensor attributes should be set later

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman mentioned this pull request Feb 6, 2026
@ksivaman ksivaman added the MoE label Feb 6, 2026
ksivaman and others added 4 commits February 9, 2026 03:02
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

quantized_tensors = self.split_into_quantized_tensors()
for i in range(self.num_tensors):
    self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag)
Collaborator

Shouldn't this be calling group_quantize from the C side? Why is it still in a loop of num_tensors iterations? Thanks.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
GetTransformerEngineDType(tensor_offsets.scalar_type()),
getTensorShape(tensor_offsets));
}

Collaborator

Could we add the with_gemm_swizzled_scales support here too please?

Collaborator

I think for grouped tensors, since there is no TP, we can always assume with_gemm_swizzled_scales = True, because the current multi-tensor swizzle doesn't work with grouped tensors anyway.

Although I agree that having a toggle would be good.


def split_into_quantized_tensors(
    self,
) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
@cyanguwa cyanguwa Feb 9, 2026

Could this function be extended to support any number n of tensors that satisfies "num_tensors is divisible by n"? This would help with reshaping efforts, for example in attention.

from .nvfp4_tensor_storage import NVFP4TensorStorage


class GroupedTensor:
Collaborator

Should this be GroupedTensor or GroupedTensorStorage, like with the other low-precision tensor types?

f"GroupedTensor(num_tensors={self.num_tensors}, "
f"shape={self.shape}, "
f"logical_shape={self.logical_shape}, "
f"dtype={self.get_dtype()})"
Collaborator

Should we add the quantizer information here?

- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode.
Collaborator

This seems to duplicate the first sentence in the docstring.

# Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
# tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
# Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size
# If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
Collaborator

At first glance, these two tensors are confusing. Maybe we could explain the difference as one being the offset at the element level and the other at the byte level?

@staticmethod
def make_grouped_tensor_with_shapes(
    num_tensors: int,
    shape: List[Tuple[int, int]],
Collaborator

Could this be named "shapes" please?

    dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
    """
    Create a GroupedTensor for storing multiple weight tensors of the same shape.
Collaborator

Why "weight" tensors?

        )
        result.append(tensor_data)
    else:
        raise RuntimeError("GroupedTensor has no data to split")
@cyanguwa cyanguwa Feb 10, 2026

The if/elif/else code seems to be the same for both the "offsets is None" and "offsets is not None" cases, so maybe we can combine the two? Similarly, can the "data" part of the split be combined for the "no_quantization" and "quantization" paths, so there's only one "for i in range(self.num_tensors)" loop?

]


def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
Collaborator

We don't need "num_tensors" as an argument here anymore, I think, because we assume all tensors in the group use the same kind of quantizer.

Quantize the GroupedTensor inplace.
"""

quantized_tensors = self.split_into_quantized_tensors()
Collaborator

Should "quantized_tensors" here be "self.quantized_tensors"?

ksivaman and others added 5 commits February 10, 2026 03:17
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

@pytest.mark.parametrize(
    "shape",
    [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Collaborator

Please add edge cases to the test (a possible parametrization sketch follows the list):

  1. no tokens at all
  2. zero tokens at the beginning
  3. zero tokens at the end
  4. zero tokens in the middle
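
A hypothetical parametrization covering those edge cases; the shape values are placeholders, not from the PR:

```python
ZERO_TOKEN_SHAPES = [
    [(0, 512), (0, 512), (0, 512)],      # no tokens at all
    [(0, 512), (512, 512), (768, 512)],  # zero tokens in the first group
    [(256, 512), (512, 512), (0, 512)],  # zero tokens in the last group
    [(256, 512), (0, 512), (768, 512)],  # zero tokens in a middle group
]
```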

py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
logical_last_dim);

NVTE_SCOPED_GIL_RELEASE({
Collaborator

Let's build a switch over the recipe here, like split_quantize has.

m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
py::arg("otype"));

m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"),
Collaborator

For future grouped quantization of weights, let's add a slot for an optional noop tensor for CUDA graphs.

}

if (columnwise_usage) {
  columnwise_data = at::empty({total_elements}, uint8_opts);
@vthumbe1503 vthumbe1503 Feb 10, 2026

A lot of the code is duplicated for each and every quantizer. Maybe we should have common code in the base Quantizer implementation.

Collaborator

Or maybe call the utility function from type_converters.cpp.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>