
Conversation


@jberchtold-nvidia (Collaborator) commented Jan 15, 2026

Description

Depends on #2502

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng and others added 30 commits December 3, 2025 13:07
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@jberchtold-nvidia marked this pull request as draft January 15, 2026 19:52

greptile-apps bot commented Jan 15, 2026

Greptile Summary

This PR adds support for batched einsum operations and grouped GEMM without device-to-host memory copies, enabling efficient Mixture-of-Experts (MoE) implementations with per-expert FP8 quantization in JAX.
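
The per-expert semantics can be sketched in plain JAX; the shapes and names below are illustrative, and the real entry point (transformer_engine/jax/einsum.py) additionally threads per-expert quantizer sets through each dense call:

    import jax
    import jax.numpy as jnp

    E, B, C, M, H = 4, 2, 8, 16, 32  # experts, batch, capacity, model, hidden

    x = jnp.ones((E, B, C, M), jnp.bfloat16)  # per-expert activations
    w = jnp.ones((E, M, H), jnp.bfloat16)     # per-expert weights

    # einsum("EBCM,EMH->EBCH", x, w): one independent GEMM per expert,
    # expressed as a dense matmul vmapped over the leading E axis.
    per_expert = lambda xi, wi: jnp.einsum("BCM,MH->BCH", xi, wi)
    out = jax.vmap(per_expert)(x, w)
    assert out.shape == (E, B, C, H)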

Key Changes

  • Grouped GEMM Implementation: New nvte_grouped_gemm C++ API using cuBLAS 13.1+ for batched matrix operations, with a GPU-side setup kernel to avoid D2H memcpy overhead
  • Einsum Support: JAX einsum implementation using the vmap+dense pattern for MoE workloads, supporting per-expert quantization with QuantizerSet arrays
  • Quantization Updates: Refactored grouped quantization to use batched primitives; changed scale initialization from jnp.empty to jnp.ones to prevent uninitialized-memory issues (see the sketch after this list)
  • Architecture Requirements: Grouped GEMM requires Blackwell (SM100+) and cuBLAS 13.1+, with version checks in place
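
On the scale-initialization point, a minimal sketch of why ones is the safer default (shapes are illustrative; the actual change lives in transformer_engine/jax/cpp_extensions/quantization.py):

    import jax.numpy as jnp

    num_experts = 4
    # Before: scale = jnp.empty((num_experts,), jnp.float32)
    # A scale read before the quantize kernel writes it is not guaranteed
    # to hold anything meaningful.
    # After: 1.0 is the multiplicative identity, so an unwritten scale
    # behaves as a no-op instead of corrupting dequantized values.
    scale = jnp.ones((num_experts,), jnp.float32)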

Testing

Comprehensive test coverage includes:

  • C++ unit tests for grouped GEMM with various shape configurations (uniform, varying first/last dims)
  • JAX tests for einsum with MoE patterns, gradients, and multiple FP8 recipes (a gradient check of this shape is sketched after this list)
  • Validation of forward/backward passes with numerical accuracy checks
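
The gradient checks follow the usual JAX pattern; in this sketch jnp.einsum stands in for the PR's einsum so the snippet runs standalone:

    import jax
    import jax.numpy as jnp

    def loss(x, w):
        # stand-in for the PR's einsum("EBCM,EMH->EBCH", ...)
        return jnp.einsum("EBCM,EMH->EBCH", x, w).sum()

    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(k1, (2, 3, 8, 16))
    w = jax.random.normal(k2, (2, 16, 32))

    # Exercises the backward pass through both operands, as the JAX tests do.
    gx, gw = jax.grad(loss, argnums=(0, 1))(x, w)
    assert gx.shape == x.shape and gw.shape == w.shape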

Temporary Workarounds

  • quantization.cpp:405-406: Memset to zero uninitialized buffer portions when over-allocated (noted as temporary fix needing investigation)
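
The idea of the workaround, in illustrative JAX terms (the real fix is the byte-offset cudaMemsetAsync call quoted in the review comment below):

    import jax.numpy as jnp

    max_rows, used_rows, n = 8, 5, 4        # illustrative sizes
    # NaN stands in for whatever an uninitialized over-allocated buffer holds.
    buf = jnp.full((max_rows, n), jnp.nan)
    # Zero only the unwritten tail; the valid rows are left untouched.
    buf = buf.at[used_rows:].set(0.0)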

Confidence Score: 4/5

  • This PR is safe to merge with minor style improvements recommended
  • The implementation is well-tested, with comprehensive unit tests for both the C++ and JAX layers. Architecture checks enforce the SM100+ requirement for grouped GEMM. The main concerns are the temporary workarounds (memset for over-allocated buffers), which should be monitored, and a few opportunities for code-style improvement. The core logic appears sound, with proper validation and error handling.
  • Pay close attention to transformer_engine/jax/csrc/extensions/quantization.cpp for the temporary memset workaround

Important Files Changed

  • transformer_engine/jax/csrc/extensions/quantization.cpp: memset for the quantization buffer now zeroes only the uninitialized portion, avoiding unnecessary overhead
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu: new grouped GEMM implementation using cuBLAS 13.1+, with SM100+ architecture checks, comprehensive error handling, and a GPU-side setup kernel
  • transformer_engine/jax/einsum.py: new einsum implementation using vmap+dense for MoE with per-expert quantization; validates NN layout and a single batch dimension
  • transformer_engine/jax/cpp_extensions/quantization.py: scale initialization changed from empty to ones; batcher refactored onto the general batcher implementation
  • transformer_engine/jax/cpp_extensions/gemm.py: grouped GEMM support added, with new primitives, workspace management, and batched quantization integration
  • tests/jax/test_einsum.py: comprehensive test suite for einsum with MoE operations, gradients, and multiple FP8 recipes
Sequence Diagram

sequenceDiagram
    participant User as JAX User Code
    participant Einsum as einsum()
    participant Dense as dense()
    participant GEMM as GemmPrimitive
    participant Quant as GroupedQuantize
    participant CUDA as cuBLAS/CUDA

    User->>Einsum: einsum("EBCM,EMH->EBCH", x, w, quantizer_sets)
    Einsum->>Einsum: Parse equation & validate NN layout
    Einsum->>Einsum: Stack quantizer_sets into pytree
    
    Einsum->>Dense: vmap(dense_with_quantizer) over batch dim E
    
    loop For each expert (vmapped)
        Dense->>Quant: grouped_quantize(x, quantizer_i)
        Quant->>CUDA: GroupedQuantizeFFI (batched)
        CUDA-->>Quant: quantized tensors + scales
        
        Dense->>Quant: grouped_quantize(w, quantizer_i)
        Quant->>CUDA: GroupedQuantizeFFI (batched)
        CUDA-->>Quant: quantized tensors + scales
        
        Dense->>GEMM: gemm(x_q, w_q, scales)
        GEMM->>CUDA: nvte_grouped_gemm (if batched)
        Note over CUDA: GPU-side setup kernel<br/>No D2H memcpy
        CUDA->>CUDA: cublasLtMatmul (grouped)
        CUDA-->>GEMM: output
        GEMM-->>Dense: result
    end
    
    Dense-->>Einsum: vmapped outputs
    Einsum-->>User: final result
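
The "Stack quantizer_sets into pytree" step uses the standard JAX idiom of combining per-expert objects leaf-wise so vmap can map over the leading axis; the dict fields below are illustrative placeholders, not the actual QuantizerSet structure:

    import jax
    import jax.numpy as jnp

    # Four per-expert "quantizer sets" with identical pytree structure.
    sets = [{"scale": jnp.ones(()), "amax": jnp.zeros(())} for _ in range(4)]

    # Stack corresponding leaves so each leaf gains a leading expert axis.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *sets)
    assert stacked["scale"].shape == (4,)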


@greptile-apps bot left a comment


24 files reviewed, 1 comment


Comment on lines +405 to +406
    cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
                    outputs->size_bytes() - used_output_size, stream);

style: potential pointer arithmetic issue with untyped data

The expression outputs->untyped_data() + used_output_size does arithmetic on the untyped pointer; if untyped_data() returns void*, this compiles only as a compiler extension that treats the pointer as byte-addressed (char*-like). The explicit char* cast in the suggestion below is more portable. Also verify that used_output_size is calculated in bytes, not elements.

Suggested change

    - cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
    -                 outputs->size_bytes() - used_output_size, stream);
    + size_t used_output_size = (sum_group_sizes * non_group_m) * n * output_dtype_bytes;
    + char* output_base = static_cast<char*>(outputs->untyped_data());
    + cudaMemsetAsync(output_base + used_output_size, 0, outputs->size_bytes() - used_output_size, stream);
