[JAX] Support for batched einsum and grouped GEMM without D2H memcpy #2604
base: main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single-element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
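As a clarifying sketch of the per-matrix alpha/beta semantics described in the commit above: each group i computes D[i] = alpha[i] * (A[i] @ B[i]) + beta[i] * C[i]. The plain-JAX reference below is illustrative only (it is not the nvte_grouped_gemm API, and the helper name and shapes are made up).

```python
import jax.numpy as jnp

def grouped_gemm_reference(a_list, b_list, c_list, alpha, beta):
    """Reference semantics: D[i] = alpha[i] * (A[i] @ B[i]) + beta[i] * C[i]."""
    # alpha and beta must have exactly num_tensors elements, one per group.
    assert len(alpha) == len(beta) == len(a_list)
    return [al * (a @ b) + be * c
            for a, b, c, al, be in zip(a_list, b_list, c_list, alpha, beta)]

# Example: two groups with different shapes and per-matrix scaling factors.
a_list = [jnp.ones((3, 4)), jnp.ones((5, 4))]
b_list = [jnp.ones((4, 6)), jnp.ones((4, 2))]
c_list = [jnp.zeros((3, 6)), jnp.ones((5, 2))]
out = grouped_gemm_reference(a_list, b_list, c_list, alpha=[1.0, 0.5], beta=[0.0, 1.0])
```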
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This reverts commit bc6cf66.
… single-stream for multi tensor quantize)
Greptile Summary

This PR adds support for batched einsum operations and grouped GEMM without device-to-host memory copies, enabling efficient Mixture-of-Experts (MoE) implementations with per-expert FP8 quantization in JAX.

Key Changes
Testing

Comprehensive test coverage includes:
Temporary Workarounds
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User as JAX User Code
participant Einsum as einsum()
participant Dense as dense()
participant GEMM as GemmPrimitive
participant Quant as GroupedQuantize
participant CUDA as cuBLAS/CUDA
User->>Einsum: einsum("EBCM,EMH->EBCH", x, w, quantizer_sets)
Einsum->>Einsum: Parse equation & validate NN layout
Einsum->>Einsum: Stack quantizer_sets into pytree
Einsum->>Dense: vmap(dense_with_quantizer) over batch dim E
loop For each expert (vmapped)
Dense->>Quant: grouped_quantize(x, quantizer_i)
Quant->>CUDA: GroupedQuantizeFFI (batched)
CUDA-->>Quant: quantized tensors + scales
Dense->>Quant: grouped_quantize(w, quantizer_i)
Quant->>CUDA: GroupedQuantizeFFI (batched)
CUDA-->>Quant: quantized tensors + scales
Dense->>GEMM: gemm(x_q, w_q, scales)
GEMM->>CUDA: nvte_grouped_gemm (if batched)
Note over CUDA: GPU-side setup kernel<br/>No D2H memcpy
CUDA->>CUDA: cublasLtMatmul (grouped)
CUDA-->>GEMM: output
GEMM-->>Dense: result
end
Dense-->>Einsum: vmapped outputs
Einsum-->>User: final result
```
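For reference, the equation in the diagram ("EBCM,EMH->EBCH") corresponds to per-expert GEMMs vmapped over the expert dimension E. The plain-JAX sketch below reproduces only that reference math; it does not use the quantizers or the fused grouped-GEMM path this PR adds, and the shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes: E experts, B batch, C chunk, M/H feature dims.
E, B, C, M, H = 4, 2, 8, 16, 32
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (E, B, C, M), dtype=jnp.bfloat16)  # per-expert activations
w = jax.random.normal(key_w, (E, M, H), dtype=jnp.bfloat16)     # per-expert weights

# Per-expert dense: (B, C, M) @ (M, H) -> (B, C, H), vmapped over the expert dim.
per_expert_dense = jax.vmap(lambda xi, wi: jnp.einsum("bcm,mh->bch", xi, wi))
ref = per_expert_dense(x, w)
assert ref.shape == (E, B, C, H)
```

In the path shown in the diagram, each per-expert dense call is instead quantized via grouped_quantize and dispatched to nvte_grouped_gemm, with group setup done in a GPU-side kernel so no device-to-host memcpy is needed.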
24 files reviewed, 1 comment
```cpp
cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
                outputs->size_bytes() - used_output_size, stream);
```
style: potential pointer arithmetic issue with untyped data
The pointer arithmetic `outputs->untyped_data() + used_output_size` relies on byte-addressed (`char*`) arithmetic on the pointer; the explicit cast in the suggestion below makes that intent unambiguous. Also verify that `used_output_size` is calculated in bytes, not elements.
Suggested change:

```diff
- cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
-                 outputs->size_bytes() - used_output_size, stream);
+ size_t used_output_size = (sum_group_sizes * non_group_m) * n * output_dtype_bytes;
+ char* output_base = static_cast<char*>(outputs->untyped_data());
+ cudaMemsetAsync(output_base + used_output_size, 0, outputs->size_bytes() - used_output_size, stream);
```
Description
Depends on #2502
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: