-
Notifications
You must be signed in to change notification settings - Fork 24
NVFP4 cast/transpose without TMA #472
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
b8a4024
0519b4b
8f4b04d
e60ff21
8bbb162
f573b40
eaaae94
8f94cf6
bac7993
8ae38e8
05a977a
0385852
46d382d
15416f1
bac5096
da24223
c03b7bb
c453dba
4a843ba
316dffb
8a47bc5
5c747bd
db56b8f
b318bda
62eea94
6d459ec
6eb2707
8cec975
c20e0e9
ccda439
4b0fd34
84934c2
e79134a
586bd09
4896edf
aa18e9a
c918a19
5bd7388
95d0c9f
6cd6038
55a8c84
10d88bf
b4caf6f
511db61
36cf73a
a85f68f
f4f5ec9
a607feb
ca2e444
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,7 +91,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, | |
| dummy_workspace_tensor, stream); | ||
| break; | ||
| } | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| case NVTE_NVFP4_1D_SCALING: { | ||
| NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); | ||
|
|
||
|
|
@@ -108,6 +107,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, | |
| (cols % 32 == 0) && output_tensor->has_data(); | ||
|
|
||
| // Launch NVFP4 quantize kernel | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| if (use_optimized_kernel) { | ||
| if (quant_config_cpp.nvfp4_2d_quantization) { | ||
| nvfp4::quantize_transpose</*use_2d_quantization=*/true>( | ||
|
|
@@ -117,10 +117,22 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, | |
| *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); | ||
| } | ||
| } else { | ||
| #endif | ||
| auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax | ||
| : output_tensor->columnwise_amax; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Fix for upstream bug: if amax was not explicitly set, fall back to the | ||
| // scale field which holds the same value when set via set_scale(). | ||
| NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it a bug fix for upstream? If not, why do we need this specific treatment for global amax?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I believe this is a bug in upstream.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe put a comment in then.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Thanks! Also, check if upstream already had an fix. If not, I think it's okay to drop the rocm specific guard. What do you think @ipanfilo?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this is fixed in upstream yet. I added a comment in a607feb.
||
| "NVFP4 quantization requires global_amax (output_tensor->amax) " | ||
| "or scale to be set. Call output.set_scale(amax_value) before quantizing."); | ||
| const SimpleTensor& effective_amax = | ||
| (global_amax.dptr != nullptr) ? global_amax : output_tensor->scale; | ||
| quantize_transpose_vector_blockwise_fp4( | ||
| /*input=*/input_tensor->data, /*global_amax=*/effective_amax, | ||
| #else | ||
| /*input=*/input_tensor->data, /*global_amax=*/global_amax, | ||
| #endif | ||
| /*scale_inv=*/output_tensor->scale_inv, | ||
| /*scale_inv_t=*/output_tensor->columnwise_scale_inv, | ||
| /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, | ||
|
|
@@ -131,9 +143,12 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, | |
| /*rng_state=*/quant_config_cpp.rng_state, | ||
| /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, | ||
| /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| } | ||
| #endif | ||
| break; | ||
| } | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| case NVTE_BLOCK_SCALING_2D: { | ||
| // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. | ||
| NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); | ||
|
|
@@ -238,7 +253,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens | |
| stream); | ||
| break; | ||
| } | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| case NVTE_NVFP4_1D_SCALING: { | ||
| NVTE_CHECK((!IS_DBIAS && !IS_DACT), | ||
| "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); | ||
|
|
@@ -256,6 +270,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens | |
| (cols % 32 == 0) && output_tensor->has_data(); | ||
|
|
||
| // Launch NVFP4 quantize kernel | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| if (use_optimized_kernel) { | ||
| if (quant_config_cpp.nvfp4_2d_quantization) { | ||
| nvfp4::quantize_transpose</*use_2d_quantization=*/true>( | ||
|
|
@@ -265,10 +280,22 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens | |
| *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); | ||
| } | ||
| } else { | ||
| #endif | ||
| auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax | ||
| : output_tensor->columnwise_amax; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Fix for upstream bug: if amax was not explicitly set, fall back to the | ||
| // scale field which holds the same value when set via set_scale(). | ||
| NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, | ||
| "NVFP4 quantization requires global_amax (output_tensor->amax) " | ||
| "or scale to be set. Call output.set_scale(amax_value) before quantizing."); | ||
| const SimpleTensor& effective_amax = | ||
| (global_amax.dptr != nullptr) ? global_amax : output_tensor->scale; | ||
| quantize_transpose_vector_blockwise_fp4( | ||
| /*input=*/grad_tensor->data, /*global_amax=*/global_amax, | ||
| /*input=*/input_tensor->data, /*global_amax=*/effective_amax, | ||
| #else | ||
| /*input=*/input_tensor->data, /*global_amax=*/global_amax, | ||
| #endif | ||
| /*scale_inv=*/output_tensor->scale_inv, | ||
| /*scale_inv_t=*/output_tensor->columnwise_scale_inv, | ||
| /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, | ||
|
|
@@ -279,9 +306,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens | |
| /*rng_state=*/quant_config_cpp.rng_state, | ||
| /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, | ||
| /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| } | ||
| #endif | ||
| break; | ||
| } | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| case NVTE_BLOCK_SCALING_2D: { | ||
| // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. | ||
| NVTE_CHECK((!IS_DBIAS && !IS_DACT), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.