[PyTorch] GroupedTensor integration
#2600
base: main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary
This PR integrates the `GroupedTensor` class into the PyTorch `GroupedLinear` module so that its weights live in contiguous storage.
Key Changes
Implementation Notes
The implementation allocates all weight data in a single contiguous buffer, then creates individual parameter views that share the underlying storage. This improves memory locality and enables future optimizations like grouped GEMMs (#2502).
Confidence Score: 4/5
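For readers unfamiliar with the pattern, here is a minimal, self-contained PyTorch sketch of the "one contiguous buffer plus per-weight views" idea described above; the shapes, dtype, and variable names are illustrative and not taken from this PR:

```python
import torch

# Illustrative shapes for three grouped weights (not the PR's actual shapes).
shapes = [(256, 128), (256, 128), (512, 128)]
numels = [h * w for h, w in shapes]

# One contiguous allocation backs every weight in the group.
storage = torch.empty(sum(numels), dtype=torch.bfloat16)

# Carve out per-weight views that alias the shared buffer.
views, offset = [], 0
for (h, w), n in zip(shapes, numels):
    views.append(storage[offset:offset + n].view(h, w))
    offset += n

# Every view shares the same underlying storage as the flat buffer.
assert all(
    v.untyped_storage().data_ptr() == storage.untyped_storage().data_ptr()
    for v in views
)
```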
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant GroupedLinear
participant GroupedTensor
participant Quantizer
participant Storage
Note over User,Storage: Initialization Phase
User->>GroupedLinear: __init__(num_gemms, in_features, out_features)
GroupedLinear->>GroupedLinear: register_parameter(weight0...weightN)
GroupedLinear->>GroupedLinear: reset_parameters()
GroupedLinear->>GroupedLinear: make_grouped_weights()
Note over GroupedLinear,Storage: Weight Consolidation
GroupedLinear->>Quantizer: _get_weight_quantizers()
Quantizer-->>GroupedLinear: [quantizer0...quantizerN]
GroupedLinear->>GroupedTensor: make_grouped_tensor(num_tensors, shapes, quantizers)
Note over GroupedTensor,Storage: Allocate Contiguous Storage
GroupedTensor->>GroupedTensor: analyze shape patterns
GroupedTensor->>GroupedTensor: calculate logical_shape, offsets
GroupedTensor->>Storage: allocate contiguous buffers (data, scale_inv, etc)
GroupedTensor->>GroupedTensor: split_into_quantized_tensors()
GroupedTensor-->>GroupedLinear: grouped_weights with quantized_tensors
Note over GroupedLinear: Copy & Re-register Weights
loop for each weight i
GroupedLinear->>GroupedTensor: quantized_tensors[i].copy_(weights[i])
GroupedLinear->>GroupedLinear: register_parameter(weightI, quantized_tensors[i])
end
Note over User,Storage: Forward Pass
User->>GroupedLinear: forward(inp, m_splits)
GroupedLinear->>GroupedLinear: _get_weight_tensors()
GroupedLinear->>GroupedLinear: prepare quantizers
GroupedLinear->>GroupedLinear: _GroupedLinear.apply()
Note over GroupedLinear: All weights share contiguous storage
GroupedLinear->>GroupedLinear: general_grouped_gemm(weights, inputs)
GroupedLinear-->>User: output tensor
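To complement the diagram, here is a simplified, hypothetical paraphrase of the weight-consolidation phase in plain PyTorch: parameters are first registered individually, then copied into views over one contiguous buffer and re-registered. The class and helper names (`GroupedLinearSketch`, `make_grouped_views`) are stand-ins rather than the actual `GroupedLinear` / `make_grouped_tensor` code, and quantization is omitted.

```python
import torch


def make_grouped_views(shapes, dtype=torch.float32):
    """Allocate one flat buffer and return per-weight views into it."""
    numels = [h * w for h, w in shapes]
    buf = torch.empty(sum(numels), dtype=dtype)
    views, off = [], 0
    for (h, w), n in zip(shapes, numels):
        views.append(buf[off:off + n].view(h, w))
        off += n
    return buf, views


class GroupedLinearSketch(torch.nn.Module):
    def __init__(self, num_gemms, in_features, out_features):
        super().__init__()
        self.num_gemms = num_gemms
        # Phase 1: register independent parameters (mirrors __init__ / reset_parameters()).
        for i in range(num_gemms):
            self.register_parameter(
                f"weight{i}", torch.nn.Parameter(torch.empty(out_features, in_features))
            )
            torch.nn.init.xavier_uniform_(getattr(self, f"weight{i}"))
        # Phase 2: consolidate into contiguous storage (mirrors make_grouped_weights()).
        shapes = [(out_features, in_features)] * num_gemms
        _, views = make_grouped_views(shapes)
        with torch.no_grad():
            for i in range(num_gemms):
                # Copy the existing weight into its slot of the shared buffer ...
                views[i].copy_(getattr(self, f"weight{i}"))
                # ... and re-register the view as the parameter, so all weights
                # now alias one contiguous allocation.
                self.register_parameter(f"weight{i}", torch.nn.Parameter(views[i]))


module = GroupedLinearSketch(num_gemms=4, in_features=16, out_features=32)
```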
9 files reviewed, 2 comments
# TODO(ksivamani): Verify correctness of copy for all recipes.
with torch.no_grad():
    for i in range(self.num_gemms):
        grouped_weights.quantized_tensors[i].copy_(weights[i])
style: check that the copy operation works correctly for all quantization recipes (FP8, MXFP8, NVFP4, block scaling). the TODO comment on line 771 acknowledges this needs verification.
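One hedged way to discharge that TODO (a sketch only: `quantized_tensors`, `copy_`, and `dequantize()` follow the naming in this PR and Transformer Engine, while the helper name and tolerances are assumptions) is to dequantize each grouped view after the copy and compare it with its source weight:

```python
import torch


def check_grouped_copy(grouped_weights, weights, rtol=0.0, atol=0.125):
    """Hypothetical check: each dequantized grouped view should match its source
    weight within a recipe-dependent tolerance (atol/rtol here are placeholders)."""
    for i, ref in enumerate(weights):
        got = grouped_weights.quantized_tensors[i].dequantize().to(ref.dtype)
        torch.testing.assert_close(got, ref, rtol=rtol, atol=atol)
```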
# TODO(ksivaman): (Do we need multiple quantizers?)
# Current implementation assumes all tensors have different quantizer
# instances but effectively the same quantizer.
rowwise_usage = quantizers[0].rowwise_usage if not no_quantization else True
columnwise_usage = quantizers[0].columnwise_usage if not no_quantization else False
style: check that all quantizers in the group are compatible. the comment acknowledges uncertainty about whether multiple quantizers are needed, but the implementation assumes they're "effectively the same" - mixed quantization schemes could cause issues.
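A lightweight guard along the lines the reviewer suggests (a hypothetical helper; the attribute names mirror the snippet above) could assert that all quantizers in a group are effectively interchangeable before `quantizers[0]` is used on their behalf:

```python
def assert_compatible_quantizers(quantizers):
    """Hypothetical guard: all quantizers in a group must agree on type and usage."""
    if not quantizers:
        return
    first = quantizers[0]
    for q in quantizers[1:]:
        assert type(q) is type(first), "mixed quantizer types in one group"
        assert q.rowwise_usage == first.rowwise_usage, "mismatched rowwise_usage"
        assert q.columnwise_usage == first.columnwise_usage, "mismatched columnwise_usage"
```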
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
2b7ea40 to 40c619e
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
#2388 introduced the `GroupedTensor` class in the core library. This PR partially integrates this functionality into the PyTorch bindings.
Type of change
Changes
- Expose the `GroupedTensor` class.
- Integrate `GroupedTensor` into `GroupedLinear` such that the parameters are contiguous.
- Expose the `grouped_quantize` API to Python, similar to `split_quantize`, which returns a quantized `GroupedTensor` that can be directly consumed by the GEMMs ([common] Add support for cuBLASLt GEMM for GroupedTensor #2502); a hypothetical usage sketch follows at the end of this description.
Checklist:
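For the `grouped_quantize` bullet above, a usage sketch under stated assumptions: only the name `grouped_quantize` and its `split_quantize`-like behavior come from this PR, so the import path, argument order, and the way quantizers are constructed are guesses and may not match the merged API.

```python
def quantize_group_example(tensors, quantizers):
    # Assumed import path and signature; grouped_quantize is the new binding
    # named in this PR, but its exact interface is not shown here.
    import transformer_engine_torch as tex

    grouped = tex.grouped_quantize(tensors, quantizers)
    # The resulting GroupedTensor's quantized_tensors[i] would be views over one
    # contiguous allocation, ready for the grouped GEMM path from #2502.
    return grouped
```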