CPU Optimizations for FP8 #2559
base: main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Summary
This PR introduces CPU-side optimizations for FP8 tensor operations to reduce Python/C++ boundary overhead. Key optimizations:
Issues found:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear.forward
participant QT as QuantizedTensor
participant Quantizer as C++ Quantizer
participant PyAPI as Python C API
User->>Linear: forward(input, weight, bias)
Note over Linear: Cache requires_grad values<br/>(optimization: avoid repeated lookups)
Linear->>QT: tensor.requires_grad
QT-->>Linear: cached _requires_grad
Linear->>QT: tensor.dtype
QT-->>Linear: cached _dtype
Linear->>Quantizer: create_tensor()
Quantizer->>PyAPI: PyDict_New() + PyObject_Call()
Note over Quantizer,PyAPI: Direct C API calls bypass<br/>pybind11 keyword arg overhead
PyAPI-->>Quantizer: Float8Tensor instance
Quantizer-->>Linear: TensorWrapper, py::object
Linear-->>User: output tensor
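For context, a minimal sketch of the direct C API construction pattern the diagram refers to; the helper, class object, and keyword name are illustrative rather than the PR's code:
#include <Python.h>
// Sketch: pack keyword arguments with the Python C API and call the class object
// directly, bypassing pybind11's keyword-argument machinery.
static PyObject* make_instance(PyObject* cls, PyObject* data) {
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "data", data);  // the dict takes its own reference
  PyObject* args = PyTuple_New(0);             // empty positional-argument tuple
  PyObject* result = PyObject_Call(cls, args, kwargs);
  Py_DECREF(args);                             // PyObject_Call does not steal these
  Py_DECREF(kwargs);
  return result;                               // new reference, or nullptr on error
}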
Additional Comments (3)
- transformer_engine/pytorch/csrc/util.cpp, lines 18-20 (link), logic: Critical logical error: || should be &&. The condition is always true, since a value cannot simultaneously be both scaling modes, so the function always returns nullopt for valid inputs. (See the sketch after this list.)
- transformer_engine/pytorch/quantized_tensor.py, lines 373-393 (link), style: commented-out code for the requires_grad caching optimization; consider removing the dead code entirely. Is this code planned to be implemented later, or should it be removed? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- transformer_engine/pytorch/module/linear.py, line 484 (link), logic: this condition should use OR, not AND. The original logic checked whether ANY tensor requires gradients for FP8 handling, but the new code only activates when ALL three require gradients, including bias, which may be None.
Should the FP8 condition check whether any tensor requires gradients (OR logic) rather than all tensors (AND logic)?
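To illustrate the ||/&& point from the first comment, a minimal sketch with hypothetical stand-in names (not the util.cpp code itself):
#include <optional>

enum class ScalingMode { Delayed, Current };  // hypothetical stand-ins

std::optional<int> to_mode_id(ScalingMode mode) {
  // Buggy form: (mode != Delayed || mode != Current) is always true, because no
  // value can equal both modes at once, so every input would hit the early return.
  if (mode != ScalingMode::Delayed && mode != ScalingMode::Current) {
    return std::nullopt;  // corrected: reject only values matching neither mode
  }
  return static_cast<int>(mode);
}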
10 files reviewed, 3 comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
- Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
- Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache the device count (see the sketch after this list)
- Stores GEMM support check results in local variables to avoid redundant function calls
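A minimal sketch of the static-local caching idiom behind the num_devices change; the query function is a stand-in for transformer_engine::cuda::num_devices():
int query_device_count() { return 8; }  // stand-in for transformer_engine::cuda::num_devices()

int cached_device_count() {
  // Evaluated once on the first call (thread-safe initialization since C++11);
  // later calls reuse the cached value instead of re-querying the driver.
  static const int num_devices = query_device_count();
  return num_devices;
}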
Critical Issues Found:
- Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
- Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
Confidence Score: 0/5
- This PR cannot be merged due to compilation error and critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant Quantizer as Quantizer/QuantizedTensor
participant GEMM as GEMM Operations
participant CPP as C++ Extensions
Note over Linear,CPP: Performance Optimization Flow
User->>Linear: forward(input, weight, bias)
Note over Linear: Cache requires_grad checks
Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
Linear->>Quantizer: Check if quantized tensor
alt QuantizedTensor
Note over Quantizer: Use cached dtype property
Quantizer->>Quantizer: return self._dtype
Note over Quantizer: Use cached shape/is_cuda
Quantizer->>Quantizer: return self._data.shape
else Regular Tensor
Quantizer->>Linear: Standard attribute access
end
Linear->>CPP: get_tensor_device(tensor)
Note over CPP: Reordered attribute checks
CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
CPP-->>Linear: device_index
Linear->>GEMM: Configure GEMM parameters
Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
Note over CPP: Static num_devices cached
CPP-->>GEMM: support_flag
GEMM->>GEMM: Store in local variable
GEMM->>GEMM: Execute optimized GEMM
GEMM-->>Linear: output
Note over Linear: FP8 State Management
alt FP8 enabled and requires_grad check
Linear->>Linear: Update FP8 tensors<br/>based on cached flags
end
Linear-->>User: output
Additional Comments (2)
/te-ci L1 pytorch
/te-ci L1 pytorch
8 files reviewed, 8 comments
| ret.A_scale_inv = A.scale_inv.dptr; | ||
| ret.lda = is_A_transposed ? k : m; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move this declaration before line 125 (Configure A matrix) | |
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); | |
| Py_DECREF(kwargs); |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.
This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | |
| // Remove this line - device parameter is not accepted | |
| // PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
| def requires_grad(self) -> bool: | ||
| # Attribute access of custom tensors goes through an | ||
| # expensive Pyobject lookup. Since requires_grad is set during | ||
| # initialization and may be updated, we cache it in a member variable. | ||
| return self._requires_grad |
Critical bug: The requires_grad property getter accesses self._requires_grad at line 378, but this member variable is never initialized in the __new__ method (lines 353-364). This will cause an AttributeError when the property is first accessed.
The _requires_grad attribute must be initialized in __new__:
| def requires_grad(self) -> bool: | |
| # Attribute access of custom tensors goes through an | |
| # expensive Pyobject lookup. Since requires_grad is set during | |
| # initialization and may be updated, we cache it in a member variable. | |
| return self._requires_grad | |
| @property | |
| def requires_grad(self) -> bool: | |
| # Attribute access of custom tensors goes through an | |
| # expensive Pyobject lookup. Since requires_grad is set during | |
| # initialization and may be updated, we cache it in a member variable. | |
| # Fallback to parent if not cached yet | |
| if not hasattr(self, '_requires_grad'): | |
| self._requires_grad = torch.Tensor.requires_grad.fget(self) | |
| return self._requires_grad |
| instance._dtype = dtype | ||
| return instance |
Missing initialization: The _requires_grad member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the requires_grad property will fail.
| instance._dtype = dtype | |
| return instance | |
| instance._dtype = dtype | |
| instance._requires_grad = requires_grad | |
| return instance |
| PyTypeObject *NVFP4TensorPythonClass = nullptr; | ||
| PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; | ||
| PyTypeObject *NVFP4QuantizerClass = nullptr; | ||
| bool is_extension_initialized = false; |
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:
| bool is_extension_initialized = false; | |
| bool is_extension_initialized = false; | |
| std::once_flag extension_init_flag; |
And in init_extension():
void init_extension() {
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
}
| static int num_devices = transformer_engine::cuda::num_devices(); | ||
| static std::vector<int> cache(num_devices, -1); | ||
| static std::vector<std::once_flag> flags(num_devices); |
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| int device_id = transformer_engine::cuda::current_device(); | |
| NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count"); |
4 files reviewed, 4 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration |
| py::object scale_inv_py = py::cast(scale_inv); | ||
| // Initialize scale-inverse tensor | ||
| if (!scale_inv) { | ||
| scale_inv = at::reciprocal(scale); | ||
| } |
Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.
The fix is to move line 146 after the initialization block:
| py::object scale_inv_py = py::cast(scale_inv); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object scale_inv_py = py::cast(*scale_inv); |
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); |
| instance._dtype = dtype | ||
| return instance |
Critical bug: _dtype is cached in __new__ but _requires_grad is not initialized. The requires_grad property (line 378) tries to return self._requires_grad, which will raise AttributeError when accessed before the setter is called. This breaks the optimization for caching requires_grad.
| instance._dtype = dtype | |
| return instance | |
| instance._dtype = dtype | |
| instance._requires_grad = requires_grad | |
| return instance |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
13 files reviewed, 13 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Reuse is_nvte_non_tn_fp8_gemm_supported from line 132 |
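A sketch of the declare-once pattern the comment asks for; nvte_is_non_tn_fp8_gemm_supported() is the existing TE query, while the wrapper function and flags are illustrative and the surrounding GEMM setup is elided:
// Sketch: query the capability once and reuse the result for both matrix blocks,
// instead of declaring is_nvte_non_tn_fp8_gemm_supported inside each block.
int nvte_is_non_tn_fp8_gemm_supported();  // existing TE query, declared elsewhere

void configure_gemm(bool is_A_transposed, bool is_B_transposed) {
  const int non_tn_supported = nvte_is_non_tn_fp8_gemm_supported();
  if (!non_tn_supported && !is_A_transposed) {
    // ... adjust A-matrix layout for TN-only FP8 GEMM ...
  }
  if (!non_tn_supported && is_B_transposed) {
    // ... adjust B-matrix layout for TN-only FP8 GEMM ...
  }
}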
| } | ||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | ||
|
|
||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); |
[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Move scale_inv_py creation after initialization (line 150) |
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| Py_DECREF(args); |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
[P0] Memory leak: PyTuple_New(0) is not decremented (same issue as line 166).
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs); | |
| Py_DECREF(args); |
|
|
||
| ctx.owns_input = saved_inputmat is not inp | ||
| if ctx.fp8 and requires_grad(inp, weight, bias): | ||
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): |
[P0] Logic error: The original requires_grad(inp, weight, bias) returns True if ANY tensor requires grad. This change requires ALL THREE to be True with and, breaking FP8 state management when bias is None or doesn't require grad.
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): | |
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): |
| PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
8 files reviewed, 8 comments
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); |
The error check occurs AFTER Py_DECREF(kwargs) and Py_DECREF(args) (lines 168-169), which means if result is nullptr, the program will terminate via NVTE_CHECK without proper cleanup. While this is acceptable for fatal errors, it's inconsistent with the cleanup-before-check pattern and could mask the actual Python exception message.
Suggested improvement: Check for nullptr before decrementing reference counts, and use PyErr_Print() or PyErr_Fetch() to preserve the Python exception information:
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | |
| PyObject* result = | |
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| if (result == nullptr) { | |
| PyErr_Print(); // Print the actual Python error | |
| Py_DECREF(kwargs); | |
| Py_DECREF(args); | |
| NVTE_CHECK(false, "Failed to create Float8TensorStorage instance"); | |
| } | |
| Py_DECREF(kwargs); | |
| Py_DECREF(args); | |
| out_py = py::reinterpret_steal<py::object>(result); |
This same pattern applies to all other PyObject_Call invocations in this file (lines 187, 401, 706, 723, 1047, 1068, 1357, 1378).
| @requires_grad.setter | ||
| def requires_grad(self, value: bool) -> None: | ||
| """Set requires_grad property so that autograd engine is aware of the change""" | ||
| # Update the cached value | ||
| self._requires_grad = value | ||
| # Call parent class to ensure autograd engine is aware of the change | ||
| torch.Tensor.requires_grad.fset(self, value) |
The requires_grad property setter updates the cached _requires_grad value and then calls torch.Tensor.requires_grad.fset(self, value) to ensure the autograd engine is aware. However, there's a potential issue: if the parent class's setter fails or raises an exception, _requires_grad will already be updated, causing the cached value to be out of sync with the actual tensor state.
Suggested improvement: Call the parent setter first, then update the cache only if successful:
| @requires_grad.setter | |
| def requires_grad(self, value: bool) -> None: | |
| """Set requires_grad property so that autograd engine is aware of the change""" | |
| # Update the cached value | |
| self._requires_grad = value | |
| # Call parent class to ensure autograd engine is aware of the change | |
| torch.Tensor.requires_grad.fset(self, value) | |
| @requires_grad.setter | |
| def requires_grad(self, value: bool) -> None: | |
| """Set requires_grad property so that autograd engine is aware of the change""" | |
| # Call parent class first to ensure autograd engine is aware of the change | |
| torch.Tensor.requires_grad.fset(self, value) | |
| # Update the cached value only after successful update | |
| self._requires_grad = value |
The same pattern should be applied to requires_grad_() method on lines 402-409.
| @dtype.setter | ||
| def dtype(self, value: torch.dtype) -> None: | ||
| """Set dtype property""" | ||
| # Update the cached value | ||
| self._dtype = value | ||
| warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") |
The dtype property has a setter that allows changing the dtype after creation, with a warning. However, this is problematic because:
- The warning says "Dtype of QuantizedTensor has been changed" but doesn't indicate this could break things
- QuantizedTensor subclasses (Float8Tensor, MXFP8Tensor, etc.) have quantized data in specific formats (FP8, MXFP4) that don't change when you update _dtype
- Changing dtype without re-quantizing the underlying data creates a mismatch between the cached dtype and the actual data representation
This setter creates a dangerous API surface. If dtype must be mutable, it should either:
- Raise an error instead of just warning
- Or trigger re-quantization of the data
Current impact: Users could accidentally corrupt tensor semantics by changing dtype, leading to incorrect dequantization.
| ) | ||
| return out, all_gather_outputs | ||
|
|
||
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return self._data.shape if self._data is not None else self._transpose.shape | ||
|
|
||
| @property | ||
| def is_cuda(self): | ||
| """Return whether the tensor is on a CUDA device.""" |
The new shape and is_cuda property implementations have a logic issue: they return properties from self._data if it exists, otherwise from self._transpose. However, according to the codebase logic in cublaslt_gemm.cu and quantizer.cpp, on Hopper (non-Blackwell) architectures, FP8 tensors may only have _transpose data for certain GEMM layouts, while on Blackwell they may only have _data.
The properties assume that if _data is None, then _transpose must exist, but there's no fallback handling if both are somehow None (which shouldn't happen but could in edge cases).
Suggested improvement: Add a safety check or document the invariant that at least one must always exist:
| ) | |
| return out, all_gather_outputs | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| return self._data.shape if self._data is not None else self._transpose.shape | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| if self._data is not None: | |
| return self._data.shape | |
| if self._transpose is not None: | |
| return self._transpose.shape | |
| raise RuntimeError("Float8Tensor has neither _data nor _transpose") | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| if self._data is not None: | |
| return self._data.is_cuda | |
| if self._transpose is not None: | |
| return self._transpose.is_cuda | |
| raise RuntimeError("Float8Tensor has neither _data nor _transpose") |
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return ( | ||
| self._rowwise_data.shape | ||
| if self._rowwise_data is not None | ||
| else self._columnwise_data.shape | ||
| ) |
Same issue as Float8Tensor: the shape and is_cuda properties assume that if _rowwise_data is None, then _columnwise_data must exist. However, there's no safety check for the edge case where both could be None.
Suggested improvement: Add safety checks:
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| return ( | |
| self._rowwise_data.shape | |
| if self._rowwise_data is not None | |
| else self._columnwise_data.shape | |
| ) | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| if self._rowwise_data is not None: | |
| return self._rowwise_data.shape | |
| if self._columnwise_data is not None: | |
| return self._columnwise_data.shape | |
| raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data") | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| if self._rowwise_data is not None: | |
| return self._rowwise_data.is_cuda | |
| if self._columnwise_data is not None: | |
| return self._columnwise_data.is_cuda | |
| raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data") |
3 files reviewed, 3 comments
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); |
Missing error checking for Python C API calls could lead to crashes. The code doesn't check return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), or PyObject_Call() before proceeding. If any of these fail (e.g., due to memory allocation failure), the subsequent operations will access invalid pointers.
Recommendation:
Add error checks after each C API call:
- Check if PyDict_New() and PyTuple_New() return NULL
- Check if PyDict_SetItemString() returns -1
- The existing NVTE_CHECK(result != nullptr, ...) is good, but should come before the Py_DECREF calls to avoid decrementing invalid references
This pattern is repeated throughout this file at lines 176-193, 389-404, 410-426, 688-706, 710-727, 1035-1050, 1054-1070, 1343-1361, and 1365-1383. All instances need the same error checking.
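A sketch of the checked-call pattern the recommendation describes; the wrapper function, class object, and keyword are illustrative, while NVTE_CHECK is TE's existing assertion macro:
// Sketch: verify each Python C API call before relying on its result, and check
// the constructed object before releasing the argument containers.
PyObject* make_checked(PyObject* cls, PyObject* data) {
  PyObject* kwargs = PyDict_New();
  NVTE_CHECK(kwargs != nullptr, "Failed to allocate kwargs dict");
  PyObject* args = PyTuple_New(0);
  NVTE_CHECK(args != nullptr, "Failed to allocate args tuple");
  NVTE_CHECK(PyDict_SetItemString(kwargs, "data", data) == 0,  // returns -1 on failure
             "Failed to set 'data' keyword argument");
  PyObject* result = PyObject_Call(cls, args, kwargs);
  NVTE_CHECK(result != nullptr, "Failed to construct instance");  // check before DECREF
  Py_DECREF(args);
  Py_DECREF(kwargs);
  return result;  // new reference owned by the caller
}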
| @property | ||
| def requires_grad(self) -> bool: | ||
| """ | ||
| Return whether or not the tensor requires gradient. | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since requires_grad is set during | ||
| initialization and may be updated, we cache it in a member variable. | ||
| """ | ||
| return self._requires_grad | ||
|
|
||
| @requires_grad.setter | ||
| def requires_grad(self, value: bool) -> None: | ||
| """Set requires_grad property so that autograd engine is aware of the change""" | ||
| # Update the cached value | ||
| self._requires_grad = value | ||
| # Call parent class to ensure autograd engine is aware of the change | ||
| torch.Tensor.requires_grad.fset(self, value) |
Potential cache inconsistency if PyTorch modifies requires_grad through internal mechanisms. While the setter and requires_grad_() method properly sync the cache with the parent tensor, there's a risk that PyTorch's autograd engine could modify requires_grad through paths that bypass these methods (e.g., during graph construction or when tensors are involved in certain operations).
Current behavior: The cached value could become stale if PyTorch modifies requires_grad internally.
Mitigation: The implementation correctly updates the parent tensor in the setter and requires_grad_(), which should handle most cases. However, consider documenting this limitation or adding validation in critical paths to ensure cache consistency.
Note: This is a minor concern as most requires_grad modifications go through the standard setter, but worth being aware of for debugging purposes.
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
Error checking order issue: NVTE_CHECK for result != nullptr happens after Py_DECREF(kwargs) and Py_DECREF(args). While decrefing these is safe even when result is NULL, checking for errors before cleanup is better practice and makes the error path clearer.
If PyObject_Call fails and returns NULL, Python's exception state is set. The NVTE_CHECK should happen immediately after the call, and cleanup should be part of an error handling path. Consider moving the check before Py_DECREF or using proper error handling with early returns.
This pattern affects all PyObject_Call instances in this file.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
2 files reviewed, 2 comments
| static int num_devices = transformer_engine::cuda::num_devices(); | ||
| static std::vector<int> cache(num_devices, -1); | ||
| static std::vector<std::once_flag> flags(num_devices); |
Critical bug: Potential out-of-bounds access when device_id >= num_devices
Making num_devices static causes it to be evaluated only once at first call. This creates a serious bug if the device ID is not within the range [0, num_devices):
- Problem: CUDA device IDs are not necessarily contiguous. For example, with CUDA_VISIBLE_DEVICES=0,2 you have 2 visible devices but device IDs 0 and 2.
- Bug scenario: If num_devices is initialized to 2, but later code calls current_device() which returns 2, then cache[device_id] at line 975 accesses cache[2], which is out of bounds (cache size is 2, valid indices are 0-1).
- Impact: This will cause undefined behavior, likely a segmentation fault or memory corruption.
Fix: Either:
- Use current_device() as the cache key and bounds-check it
- Map actual device IDs to cache indices
- Remove the static keyword and re-evaluate num_devices on each call (sacrifices some of the optimization but ensures correctness)
Additional Comments (1)
The original line was: _quantization_list: list[Optional[str]] = [None]. Now it's: _quantization_list: list[Optional[str]] = []. This means tests will no longer run with quantization=None (the non-quantized case). Question: Is this intentional? If this is a CPU optimization PR, removing coverage for non-quantized tests seems unrelated to the stated goals. Please clarify if this change is intended or accidental.
12 files reviewed, 12 comments
|
|
||
| int nvte_is_non_tn_fp8_gemm_supported() { | ||
| int num_devices = transformer_engine::cuda::num_devices(); | ||
| static int num_devices = transformer_engine::cuda::num_devices(); |
Critical: Potential out-of-bounds access with static num_devices
Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.
This will cause out-of-bounds access on lines 968 and 975:
std::call_once(flags[device_id], ...); // OOB if device_id >= num_devices
return cache[device_id];               // OOB if device_id >= num_devices
Impact: Undefined behavior, potential crashes, memory corruption.
Fix: Query device count each time, or add bounds checking:
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;
int device_id = transformer_engine::cuda::current_device();
{
std::lock_guard<std::mutex> lock(resize_mutex);
if (device_id >= cache.size()) {
cache.resize(device_id + 1, -1);
flags.resize(device_id + 1);
}
}
std::call_once(flags[device_id], [&]() { ... });
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
Critical: Memory leak when exceptions occur
The direct C API usage introduces memory leak vulnerabilities. If any py::cast() call throws an exception or if PyDict_SetItemString() fails after PyDict_New() and PyTuple_New(), the allocated kwargs and args will leak.
Problems:
- No NULL checks after PyDict_New() and PyTuple_New()
- Exception safety: py::cast() can throw, but cleanup code won't run
- Multiple py::cast() calls create temporary objects that could throw
Example failure scenario:
PyObject* kwargs = PyDict_New(); // Allocated
PyObject* args = PyTuple_New(0); // Allocated
// ... several PyDict_SetItemString calls ...
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); // Throws!
// Cleanup code never runs -> kwargs and args leaked
Fix: Use RAII for proper cleanup:
struct PyObjectDeleter {
void operator()(PyObject* obj) { Py_XDECREF(obj); }
};
using PyObjectPtr = std::unique_ptr<PyObject, PyObjectDeleter>;
PyObjectPtr kwargs(PyDict_New());
PyObjectPtr args(PyTuple_New(0));
NVTE_CHECK(kwargs && args, "Failed to allocate Python objects");
// Now safe even if exceptions occur
This issue affects all quantizer types (Float8, Float8CurrentScaling, Float8Block, MXFP8, NVFP4).
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); | ||
| PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
Critical: Memory leak when exceptions occur (non-internal path)
Same memory leak vulnerability as the internal path above. The kwargs and args will leak if any exception occurs between allocation and cleanup.
Additionally, this block has more py::cast() calls (8 vs 5), increasing the risk of exceptions:
- py::cast(shape_int64).ptr()
- py::cast(GetATenDType(dtype)).ptr()
- py::cast(this->dtype).ptr()
- py::cast(device).ptr()
Each of these can throw, and none are wrapped in exception-safe RAII wrappers.
Fix: Apply the same RAII solution as suggested in the previous comment.
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): | ||
| return tensor.device.index |
Performance regression for regular tensors
The reordering of attribute checks optimizes for quantized tensors but creates a performance regression for regular torch.Tensor objects.
Before: Regular tensors check device immediately (line 1) and return
After: Regular tensors check 4 non-existent attributes via hasattr() before checking device
Impact:
- Quantized tensors: ~4 fewer attribute checks ✓
- Regular tensors: ~4 additional attribute checks ✗
Concern: If regular tensors are passed to get_tensor_device() frequently (which seems likely given the function name), this could offset the gains from other optimizations in this PR.
Recommendation: Profile both paths or add an early isinstance check:
def get_tensor_device(tensor: torch.Tensor) -> int:
    # Fast path for regular tensors
    if type(tensor).__name__ == 'Tensor':
        return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
    # Quantized tensor path
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    # ... etc
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* args = PyTuple_New(0); | ||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(args); | ||
| Py_DECREF(kwargs); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
Critical: Memory leak in Float8CurrentScalingQuantizer (internal path)
Same memory leak vulnerability as Float8Quantizer. The kwargs and args will leak if any exception occurs between allocation and the Py_DECREF cleanup code. See comment on lines 157-171 for detailed explanation and fix.
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); | ||
| ret = py::reinterpret_steal<py::object>(result); | ||
| } else { | ||
| py::handle Float8BlockwiseQTensorClass( | ||
| reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass)); | ||
| ret = Float8BlockwiseQTensorClass( | ||
| "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, | ||
| "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, | ||
| "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, | ||
| "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), | ||
| "data_format"_a = data_format); | ||
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr()); | ||
| PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), args, kwargs); | ||
| Py_DECREF(args); | ||
| Py_DECREF(kwargs); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); |
Critical: Memory leak in Float8BlockQuantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
| if (internal) { | ||
| py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass)); | ||
| out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, | ||
| "columnwise_data"_a = columnwise_data_py, | ||
| "rowwise_scale_inv"_a = rowwise_scale_inv_py, | ||
| "columnwise_scale_inv"_a = columnwise_scale_inv_py, | ||
| "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); | ||
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(args); | ||
| Py_DECREF(kwargs); | ||
|
|
Critical: Memory leak in MXFP8Quantizer (internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.
| out_py = py::reinterpret_steal<py::object>(result); | ||
| } else { | ||
| py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass)); | ||
| out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), | ||
| "rowwise_data"_a = rowwise_data_py, | ||
| "columnwise_data"_a = columnwise_data_py, | ||
| "rowwise_scale_inv"_a = rowwise_scale_inv_py, | ||
| "columnwise_scale_inv"_a = columnwise_scale_inv_py, | ||
| "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); | ||
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); | ||
| PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* args = PyTuple_New(0); | ||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(args); | ||
| Py_DECREF(kwargs); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); |
Critical: Memory leak in MXFP8Quantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
| if (internal) { | ||
| py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass)); | ||
| out_py = NVFP4TensorClass( | ||
| "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, | ||
| "rowwise_scale_inv"_a = rowwise_scale_inv_py, | ||
| "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, | ||
| "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, | ||
| "quantizer"_a = this->quantizer); | ||
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* args = PyTuple_New(0); | ||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), args, kwargs); | ||
| Py_DECREF(args); | ||
| Py_DECREF(kwargs); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); |
Critical: Memory leak in NVFP4Quantizer (internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.
} else {
  py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
  out_py = NVFP4TensorClass(
      "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
      "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
      "rowwise_scale_inv"_a = rowwise_scale_inv_py,
      "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
      "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
      "quantizer"_a = this->quantizer);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
  PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
  PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr());
  PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr());
  PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

  PyObject* args = PyTuple_New(0);
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass), args, kwargs);

  Py_DECREF(args);
  Py_DECREF(kwargs);

  NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance");
  out_py = py::reinterpret_steal<py::object>(result);
Critical: Memory leak in NVFP4Quantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
4 files reviewed, 4 comments
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
CRITICAL: Memory safety bug - Use-after-free with temporary py::object
The expression py::cast(this->dtype).ptr() creates a temporary py::object that is destroyed immediately after .ptr() returns, but before PyDict_SetItemString can increment its reference count. This results in passing a dangling pointer to PyDict_SetItemString, which is undefined behavior and will likely cause crashes or memory corruption.
The same issue occurs on multiple lines in this file:
- Line 161: py::cast(this->dtype).ptr()
- Line 178: py::cast(shape_int64).ptr()
- Line 179: py::cast(GetATenDType(dtype)).ptr()
- Line 182: py::cast(this->dtype).ptr()
- Line 185: py::cast(device).ptr()
And similar patterns throughout the file for other quantizer types.
Fix: Store the py::object in a local variable before calling .ptr():
py::object fp8_dtype_py = py::cast(this->dtype);
PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_py.ptr());
Or use py::cast(this->dtype).inc_ref().ptr() to explicitly increment the reference count before extracting the pointer.
ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
CRITICAL: Logic error - Changed from OR to AND breaks FP8 state management
The original code used the requires_grad() helper function which returns True if ANY of the tensors require grad (OR logic). The new code requires ALL THREE tensors to require grad (AND logic).
Original behavior:
if ctx.fp8 and requires_grad(inp, weight, bias):  # True if ANY require grad
New behavior:
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):  # Correct!
But wait, this line actually looks correct with OR logic. Let me re-check... Actually, this specific line IS correct. The issue mentioned in previous threads was already addressed in this PR. My apologies for the confusion on this specific line - it's actually implemented correctly.
@property
def requires_grad(self) -> bool:
    """
    Return whether or not the tensor requires gradient.
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since requires_grad is set during
    initialization and may be updated, we cache it in a member variable.
    """
    return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value and call parent class method to ensure autograd engine is aware
    self.requires_grad_(value)

def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor:
    """Cache requires_grad property and call parent class method"""
    # pylint: disable=missing-function-docstring
    # Update the cached value
    self._requires_grad = requires_grad
    # Call parent class method to ensure autograd engine is aware
    super().requires_grad_(requires_grad)
    return self
Potential cache staleness issue with requires_grad property
The cached _requires_grad attribute could become stale if the underlying tensor's requires_grad is modified through mechanisms other than the provided setter. For example:
- PyTorch's autograd engine might directly modify the tensor's requires_grad flag
- Operations like tensor.requires_grad_() from the parent class might bypass the cached value
- Deserialization or other internal PyTorch operations might set requires_grad without going through the custom setter
While the setter correctly updates both the cache and calls super().requires_grad_(value), the getter always returns the cached value, which could diverge from the actual tensor state.
Recommendation: Consider either:
- Synchronizing the cache by checking super().requires_grad periodically
- Documenting this limitation clearly
- Or accepting this as a known trade-off for performance (which may be acceptable if requires_grad is rarely modified after initialization)
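To make the failure mode concrete, here is a small self-contained toy (not the real QuantizedTensor; the class name and construction are made up for illustration) showing how a cached flag can silently diverge once the C-level flag is changed without going through the overriding property:

import torch

class CachedFlagTensor(torch.Tensor):
    """Toy subclass that caches requires_grad the way the PR does."""

    @staticmethod
    def __new__(cls, data):
        inst = torch.Tensor._make_subclass(cls, data)
        inst._requires_grad = data.requires_grad  # cached once at construction
        return inst

    @property
    def requires_grad(self):
        # Always answers from the cache, never from the real autograd flag
        return self._requires_grad

t = CachedFlagTensor(torch.zeros(2, 2))
torch.Tensor.requires_grad_(t, True)          # sets the C-level flag, bypassing the property
print(torch.Tensor.requires_grad.__get__(t))  # True  (actual autograd flag)
print(t.requires_grad)                        # False (stale cached value)

Whether such a bypass can actually occur for these tensors in practice is exactly the trade-off question above; the snippet only demonstrates the mechanism.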
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Missing error handling for Python C API calls
The code doesn't check for failures in Python C API calls:
- PyDict_New() can return NULL on memory allocation failure
- PyTuple_New(0) can return NULL on memory allocation failure
- PyDict_SetItemString() can fail and return -1
If PyDict_New() or PyTuple_New() fails, the subsequent calls will crash when dereferencing NULL pointers. If PyDict_SetItemString() fails, the dictionary will be incomplete but the error is silently ignored.
Recommendation: Add error checks:
PyObject* kwargs = PyDict_New();
if (!kwargs) throw std::runtime_error("Failed to create kwargs dict");
PyObject* args = PyTuple_New(0);
if (!args) {
  Py_DECREF(kwargs);
  throw std::runtime_error("Failed to create args tuple");
}
// Check PyDict_SetItemString return values or use NVTE_CHECK
This pattern repeats throughout all the create_tensor methods in this file.
10 files reviewed, 10 comments
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Potential reference counting issue with temporary py::object
The pattern PyDict_SetItemString(kwargs, "key", py::cast(value).ptr()) creates temporary py::object instances that may lead to subtle reference counting issues. The temporary object from py::cast() will be destroyed at the end of the statement, potentially before PyDict_SetItemString completes in edge cases.
Safer pattern:
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
py::object data_obj = data_py;
py::object scale_inv_obj = scale_inv_py;
py::object dtype_obj = py::cast(this->dtype);
py::object transpose_obj = transpose_py;
py::object quantizer_obj = this->quantizer;
PyDict_SetItemString(kwargs, "data", data_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", dtype_obj.ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_obj.ptr());
PyDict_SetItemString(kwargs, "quantizer", quantizer_obj.ptr());This ensures all py::object instances remain alive while their pointers are being used. The same issue exists in multiple locations throughout this file (lines 176-185, 389-397, 409-417, and many more).
instance._requires_grad = requires_grad
instance._dtype = dtype
return instance
Missing cached attribute initialization in alternate construction paths
The cached attributes _requires_grad and _dtype are initialized here in __new__, but there are several code paths where QuantizedTensor objects can be created that may bypass this initialization:
- Pickling/unpickling: When tensors are pickled and unpickled, __reduce_ex__ and __setstate__ may not properly initialize these cached attributes
- Tensor operations: Operations like .clone(), .detach(), torch.cat(), etc. may create new tensor instances
- C++ constructors: When tensors are created from C++ code via quantizer
Impact: Accessing .dtype or .requires_grad on such tensors will raise AttributeError: '_dtype' not found.
Recommendation: Add defensive checks in the property getters:
@property
def dtype(self) -> torch.dtype:
    if not hasattr(self, '_dtype'):
        self._dtype = super().dtype
    return self._dtype

@property
def requires_grad(self) -> bool:
    if not hasattr(self, '_requires_grad'):
        self._requires_grad = super().requires_grad
    return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value and call parent class method to ensure autograd engine is aware
    self.requires_grad_(value)
Incorrect setter implementation causes infinite recursion
The requires_grad setter calls self.requires_grad_(value) which then calls super().requires_grad_(requires_grad). However, setting requires_grad through assignment (e.g., tensor.requires_grad = True) will invoke this setter, which doesn't update self._requires_grad before calling the parent method.
The correct flow should be:
- Update cached value self._requires_grad
- Call parent requires_grad_() to notify autograd
But line 399 in requires_grad_() updates the cache AFTER calling super, which means the setter doesn't actually update the cache.
Fix:
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    self._requires_grad = value
    super().requires_grad_(value)
int nvte_is_non_tn_fp8_gemm_supported() {
  int num_devices = transformer_engine::cuda::num_devices();
  static int num_devices = transformer_engine::cuda::num_devices();
Potential race condition in static variable initialization
Making num_devices static improves performance by caching the device count, but this introduces a subtle race condition. While the cache and flags vectors are protected by std::call_once, num_devices itself is initialized on first function call without synchronization.
Scenario:
- Thread A enters function first time, starts calling transformer_engine::cuda::num_devices()
- Thread B enters function while A is still in num_devices(), sees uninitialized static
- Both threads may initialize num_devices, causing undefined behavior
However, this is likely safe in practice because:
- The function is typically called during module initialization (single-threaded)
- The value doesn't change after initialization
- C++11 guarantees thread-safe initialization of function-local statics ("magic statics"), and mainstream compilers implement this
But for correctness, consider wrapping in the existing std::call_once:
static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();
Or initialize it within the std::call_once block if device_id doesn't need num_devices.
)
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
self.dtype = tensor.dtype
Cached dtype not initialized in all construction paths
This line sets self.dtype = tensor.dtype to initialize the cached _dtype attribute, but this only happens in the _set_data method when copying from another Float8Tensor AND certain conditions are met (size/stride/dtype mismatch).
Missing initialization in:
- Initial tensor construction (first time Float8Tensor is created)
- When sizes/strides match (lines 1002-1006 don't call this setter)
- _make_in_reduce_ex unpickling path (line 942)
- Other construction paths
This will cause AttributeError when accessing .dtype property on tensors created through other paths.
Fix: Ensure _dtype is initialized in Float8Tensor.__new__ or add the defensive check mentioned in the QuantizedTensor comment.
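For the first option, a minimal self-contained sketch (a toy stand-in, not the real Float8Tensor signature) of seeding the caches inside __new__ might look like this:

import torch

class ToyQuantizedTensor(torch.Tensor):
    """Toy stand-in: seed the cached attributes on every path through __new__."""

    @staticmethod
    def __new__(cls, data, *, dtype=torch.float32, requires_grad=False):
        inst = torch.Tensor._make_subclass(cls, data, requires_grad)
        inst._dtype = dtype                  # cache seeded exactly once, here
        inst._requires_grad = requires_grad
        return inst

    @property
    def dtype(self):
        return self._dtype  # never hits an uninitialized attribute

t = ToyQuantizedTensor(torch.empty(4, 4, dtype=torch.uint8), dtype=torch.bfloat16)
print(t.dtype)  # torch.bfloat16: the emulated dtype, independent of the uint8 storage

Construction paths that never go through Python __new__ (for example tensors assembled on the C++ side) would still need the defensive-getter fallback suggested earlier.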
void init_float8_extension() {
  if (Float8TensorPythonClass) return;
  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
  Float8QuantizerClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
  Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
  Float8TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
  auto fp8_base_module =
      py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
  Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
  NVTE_CHECK(Float8TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
}
Removed null-check guards without proper synchronization guarantee
The original code had null checks (if (Float8TensorPythonClass) return;) to prevent redundant initialization. These were removed in favor of std::call_once, but the individual init_* functions no longer have guards.
Issue: If init_extension() is called multiple times from different threads, std::call_once ensures the lambda runs once. However, if init_float8_extension() is called directly (not through init_extension()), it will re-import modules and reassign global pointers without protection.
Verification needed: Confirm that these functions are NEVER called directly, only through init_extension(). If they can be called directly, the null checks should be restored:
void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // Keep this guard
  auto fp8_module = py::module_::import(...);
  ...
}

dtype = inp.dtype
# All checks after this have already been performed once, thus skip
if self.activation_dtype == inp.dtype:
if self.activation_dtype == dtype:
Optimization: Cache inp.dtype early to avoid redundant attribute access
Good optimization! By caching inp.dtype early (line 932), the code avoids a second attribute lookup that would have occurred on line 936 in the original code. This is a clean CPU optimization with no functional change.
Note: This is correctly placed BEFORE the early return check, ensuring the cached value is available for the comparison on line 933.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): | ||
| return tensor.device.index | ||
| return torch.cuda.current_device() |
Reordering device attribute checks changes behavior for regular tensors
The original code checked tensor.device FIRST (line 78-79 in old code), then checked quantized tensor attributes. The new code reverses this order, checking quantized attributes first.
Impact:
- Performance: For regular torch.Tensor objects, this INCREASES overhead by attempting to check for _rowwise_data, _columnwise_data, _data, and _transpose attributes (all of which will fail via hasattr) before finally checking device.
- Correctness: For QuantizedTensors that happen to have a .device attribute that differs from their internal data device, this could change which device is returned.
Concern: This optimization seems to hurt performance for the common case (regular tensors) while only helping the less common case (quantized tensors).
Question: What is the relative frequency of this function being called with regular tensors vs quantized tensors? If regular tensors are more common, this change may degrade overall performance.
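One way to address the performance half of this concern, sketched below under the assumption that plain torch.Tensor inputs can be recognized up front (the helper name and the type check are illustrative, not the PR's actual code):

import torch

def get_tensor_device_index(tensor) -> int:
    """Illustrative helper: keep the cheap path first for plain torch.Tensor."""
    # Common case: an ordinary tensor needs no hasattr probing at all.
    if type(tensor) is torch.Tensor:
        index = tensor.device.index
        return index if index is not None else torch.cuda.current_device()
    # Quantized tensors: fall back to whichever internal buffer is populated.
    for name in ("_rowwise_data", "_columnwise_data", "_data", "_transpose"):
        data = getattr(tensor, name, None)
        if data is not None:
            return data.device.index
    if hasattr(tensor, "device"):
        return tensor.device.index
    return torch.cuda.current_device()

Whether the hasattr-first ordering is actually a net loss depends on the call mix raised in the question above.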
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._data is not None:
        return self._data.shape
    else:
        transpose_shape = self._transpose.shape
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
Missing error handling when both _data and _transpose are None
The shape property assumes at least one of _data or _transpose is not None, but there's no validation:
@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    else:
        transpose_shape = self._transpose.shape  # AttributeError if _transpose is None!
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
Similarly for is_cuda:
@property
def is_cuda(self):
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
    # AttributeError if both are None!
When can this happen?
- During tensor construction before data is initialized
- After certain operations that clear internal data
- In error/edge cases
Fix: Add defensive check:
@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    elif self._transpose is not None:
        transpose_shape = self._transpose.shape
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
    else:
        # Fallback to base class or raise informative error
        return super().shape

@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda
Missing error handling when both rowwise_data and columnwise_data are None
Same issue as Float8Tensor - these properties assume at least one data tensor exists:
@property
def shape(self):
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape  # AttributeError if both are None!
    )

@property
def is_cuda(self):
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda  # AttributeError if both are None!
    )
Fix: Add validation:
@property
def shape(self):
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    elif self._columnwise_data is not None:
        return self._columnwise_data.shape
    else:
        return super().shape  # or raise informative error
|
/te-ci L1 pytorch |
11 files reviewed, 2 comments
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
style: Unlike the shape property above which has a defensive check (raise RuntimeError("Both data and transpose are None")), is_cuda will raise AttributeError if both _data and _transpose are None. Consider adding the same defensive check for consistency:
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    elif self._transpose is not None:
        return self._transpose.is_cuda
    else:
        raise RuntimeError("Both data and transpose are None")
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
style: Same consistency issue as Float8Tensor - is_cuda will raise AttributeError if both _rowwise_data and _columnwise_data are None, while there's no defensive check like shape should have. Consider adding error handling for this edge case.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
10 files reviewed, 2 comments
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
style: Missing defensive check for is_cuda property. The shape property above correctly raises RuntimeError when both _data and _transpose are None, but is_cuda will raise AttributeError in the same scenario.
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    elif self._transpose is not None:
        return self._transpose.is_cuda
    raise RuntimeError("Both data and transpose are None")
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda
    )
style: Both shape and is_cuda properties lack defensive checks for when both _rowwise_data and _columnwise_data are None. Consider adding error handling for consistency with Float8Tensor:
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda
    )

@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    elif self._columnwise_data is not None:
        return self._columnwise_data.shape
    raise RuntimeError("Both rowwise_data and columnwise_data are None")

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._rowwise_data is not None:
        return self._rowwise_data.is_cuda
    elif self._columnwise_data is not None:
        return self._columnwise_data.is_cuda
    raise RuntimeError("Both rowwise_data and columnwise_data are None")
…nd transpose are None Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The dtype setter unconditionally emits a warning when dtype is set. However, this setter is called in normal operations like Float8Tensor._set_data (line 1004 of float8_tensor.py) and MXFP8Tensor._set_data when copying tensor attributes. This will cause spurious warnings during legitimate tensor operations.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")

@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The warning is emitted unconditionally, but this setter is called in Float8Tensor._set_data, MXFP8Tensor._set_data, and NVFP4Tensor._set_data during normal copy operations when any tensor attribute differs (size, stride, device, etc.), not just dtype. This will cause spurious warnings.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")

@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The dtype setter unconditionally emits a warning, but this setter is called in normal _set_data operations (e.g., float8_tensor.py:1004, mxfp8_tensor.py:789, nvfp4_tensor.py:693) during legitimate tensor copy operations. This will cause spurious warnings every time tensor data is copied.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")

@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: