From e8e63b1606ff73a3398099bea89c1ac48375beaf Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Tue, 17 Mar 2026 01:23:43 +0000 Subject: [PATCH 1/9] Initial commit --- tests/pytorch/distributed/run_fsdp2_fp8_model.py | 16 ++++++++++++++-- tests/pytorch/distributed/run_fsdp2_model.py | 9 +++++++++ transformer_engine/pytorch/module/base.py | 1 + .../pytorch/module/grouped_linear.py | 4 ++++ transformer_engine/pytorch/module/linear.py | 2 ++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 5f73e476d..d386b0f5f 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -89,6 +89,9 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." ) + parser.add_argument( + "--fp8-autocast", action="store_true", default=False, help="Enable FP8 autocast." + ) parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" ) @@ -215,7 +218,10 @@ def _train(args): else: model = DDP(model, device_ids=[LOCAL_RANK]) - optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3) + if args.fp8_init: + optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True, use_decoupled_grad=True) + else: + optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3) input_path = Path("shared_input.pt") if input_path.exists(): @@ -253,7 +259,10 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + if args.fp8_autocast: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + output = model(input_data) + else: output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) @@ -286,6 +295,9 @@ def _train(args): torch.save(out_tensors, args.gradients_save_file) if args.memory_profile: + with open('memory_summary.txt', 'w') as f: + f.write(torch.cuda.memory_summary(device=None, abbreviated=False)) + snapshot = torch.cuda.memory._snapshot() import pickle with open('memory_snapshot.pickle', 'wb') as f: diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 3b9264279..cbd073ee4 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -22,17 +22,26 @@ import torch import torch.distributed as dist from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn, optim from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh from transformer_engine.pytorch import QuantizedTensor +from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext LOCAL_RANK = None +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) + +LOCAL_RANK = None + + def dist_print(msg): if LOCAL_RANK == 0: print(msg) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e74fd9d17..e3ef6c0e3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch.distributed.tensor import DTensor from torch.utils.cpp_extension import IS_HIP_EXTENSION +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fa2ae5f85..868dbe453 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -119,6 +119,10 @@ def forward( and not in_fp8_activation_recompute_phase() ) # No need to set the quantizer states if weight is already quantized + if weight_quantizers[0] is not None and not isinstance( + weights[0], QuantizedTensorStorage + ): + # No need to set the quantizer states if weight is already quantized if weight_quantizers[0] is not None and not isinstance( weights[0], QuantizedTensorStorage ): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index be147166a..859e8d5fa 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -254,6 +254,8 @@ def forward( if fp8 or debug: # Configure quantizer # No need to set the quantizer states if weight is already quantized + if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): + # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache if not columnwise_usage and keep_fp8_weight_transpose_cache: From c3e33e34f20921899de2d9b8368d62a7d0ee09ba Mon Sep 17 00:00:00 2001 From: sugovind Date: Sat, 31 Jan 2026 01:26:02 +0000 Subject: [PATCH 2/9] - Updated test functions to include new parameters for FP8 autocasting and refined test case generation for various configurations. - Cleaned up unused variables and improved code readability in the FSDPAGTensor class by removing unnecessary parameters. --- .../distributed/run_fsdp2_fp8_model.py | 44 ++++----- .../distributed/test_torch_fsdp2_fp8.py | 28 ++++-- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/tensor/fsdp2_allgather_tensor.py | 94 ++++++++++++------- 4 files changed, 107 insertions(+), 63 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index d386b0f5f..936107f61 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -172,14 +172,35 @@ def _train(args): if args.memory_profile: torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all') + + prof = None + if ( + args.profile + and torch.distributed.get_rank() in args.profile_ranks + ): + prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + wait=max(args.profile_step_start - 1, 0), + warmup=1 if args.profile_step_start > 0 else 0, + active=args.profile_step_end - args.profile_step_start, + repeat=1, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + prof.start() + if args.fp8_init: # Build the model with the specified context - with quantized_model_init(enabled=True): + with quantized_model_init(enabled=True, recipe=fp8_recipe): model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) else: model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) # Move the model to the correct device - if not args.memory_profile: + if not args.memory_profile and not args.profile: model.load_state_dict(torch.load('fsdp_model.pth')) model.to(device) @@ -232,25 +253,6 @@ def _train(args): print("Generated and saved shared input tensor.") out_tensors = [] - prof = None - if ( - args.profile - and torch.distributed.get_rank() in args.profile_ranks - ): - prof = torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=torch.profiler.schedule( - wait=max(args.profile_step_start - 1, 0), - warmup=1 if args.profile_step_start > 0 else 0, - active=args.profile_step_end - args.profile_step_start, - repeat=1, - ), - on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), - record_shapes=True, - profile_memory=True, - with_stack=True, - ) - prof.start() for iteration in range(args.iter): if LOCAL_RANK == 0: print(f"Starting iteration...{iteration}") diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index 3eaa449c2..d8de35c7a 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -39,7 +39,7 @@ def assertEqual( ) raise AssertionError(msg) -def _run_test(fp_init, recipe): +def _run_test(fp_init, fp8_autocast, recipe): test_dir = Path(__file__).parent.resolve() fsdp_script = test_dir / "run_fsdp2_fp8_model.py" @@ -47,7 +47,10 @@ def _run_test(fp_init, recipe): if fp_init: test_cmd += ["--fp8-init"] - test_cmd += ["--recipe", recipe] + if fp8_autocast: + test_cmd += ["--fp8-autocast"] + if fp8_autocast or fp_init: + test_cmd += ["--recipe", recipe] subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True) subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True) @@ -70,13 +73,24 @@ def cleanup_artifacts(): if os.path.exists(fname): os.remove(fname) +# Define test cases explicitly +test_cases = [] +# All FP8 enabled cases (all recipes) +for fp8_init in [True, False]: + for fp8_autocast in [True, False]: + if fp8_init or fp8_autocast: + for recipe in ["delayed", "current", "mxfp8"]: + test_cases.append((fp8_init, fp8_autocast, recipe)) +# FP8 disabled case (only once) +test_cases.append((False, False, "delayed")) + + @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") -@pytest.mark.parametrize("fp8_init", ([False])) -@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"])) +@pytest.mark.parametrize("fp8_init,fp8_autocast,recipe", test_cases) @pytest.mark.usefixtures("cleanup_artifacts") -def test_distributed(fp8_init, recipe): +def test_distributed(fp8_init, fp8_autocast, recipe): batch_size = 2048 input_size = 2048 @@ -101,7 +115,7 @@ def test_distributed(fp8_init, recipe): if recipe == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - _run_test(fp8_init, recipe) + _run_test(fp8_init, fp8_autocast, recipe) def test_dummy() -> None: @@ -110,4 +124,4 @@ def test_dummy() -> None: pytest returns exit code 5 if all tests are skipped. """ - pass + pass \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e3ef6c0e3..1ed16a02c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1336,9 +1336,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: self.keep_fp8_weight_transpose_cache = False param = FSDPAGTensor( param, - module=self, - fp8_meta_index=fp8_meta_index, - keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache + fp8_meta_index=fp8_meta_index, ) # Redo parameter wrap in case we broke it above diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index 763fd4419..492a259ca 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -1,7 +1,7 @@ #!/usr/bin/python3 # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # See LICENSE for license information. - +from __future__ import annotations from typing import Any, Optional, Tuple import torch import torch.nn as nn @@ -44,18 +44,12 @@ def __init__( self, tensor: torch.Tensor, *, - module: nn.Module, fp8_meta_index: str, - keep_fp8_weight_transpose_cache: bool, ): #The underlying tensor self._data = tensor - # Where quantizers are present - self._module = module # Which quantizer to use within module.quantizers["scaling_fwd"][idx] self._fp8_meta_index = fp8_meta_index - # Disable or enable transpose cache for fp8 weights - self._keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache @property def data(self) -> torch.Tensor: @@ -65,7 +59,6 @@ def __repr__(self): return ( f"FSDPAGTensor(" f"elem={self._data}, " - f"module={self._module.__class__.__name__}, " f"fp8_meta_index={self._fp8_meta_index})" ) @@ -75,18 +68,16 @@ def __tensor_flatten__(self): Return (names_of_inner_tensors, flatten_spec_metadata). """ # We only carry the one inner tensor. - # We store (module, fp8_meta_index, keep_fp8_weight_transpose_cache) as metadata to reconstruct. - return ["_data"], (self._module, self._fp8_meta_index, self._keep_fp8_weight_transpose_cache) + # We store fp8_meta_index as metadata to reconstruct. + return ["_data"], (self._fp8_meta_index) @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - module, fp8_meta_index, keep_fp8_weight_transpose_cache = flatten_spec + fp8_meta_index = flatten_spec return FSDPAGTensor( inner_tensors["_data"], - module=module, fp8_meta_index=fp8_meta_index, - keep_fp8_weight_transpose_cache=keep_fp8_weight_transpose_cache ) @classmethod @@ -99,16 +90,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): t = args[0] assert isinstance(t, cls), f"Unexpected detach input type: {type(t)}" detached = t._data.detach() - return cls(detached, module=t._module, fp8_meta_index=t._fp8_meta_index, keep_fp8_weight_transpose_cache=t._keep_fp8_weight_transpose_cache) + return cls(detached, fp8_meta_index=t._fp8_meta_index) # Unwrap only our subclass; capture shared metadata for rewrapping - meta: Optional[tuple[nn.Module, str, bool]] = None + meta: Optional[str] = None def unwrap(x): nonlocal meta if isinstance(x, cls): if meta is None: - meta = (x._module, x._fp8_meta_index, x._keep_fp8_weight_transpose_cache) + meta = x._fp8_meta_index return x._data return x @@ -123,31 +114,48 @@ def unwrap(x): def rewrap(x): if isinstance(x, torch.Tensor): - mod, idx, keep_transpose = meta - return cls(x, module=mod, fp8_meta_index=idx, keep_fp8_weight_transpose_cache=keep_transpose) + return cls(x, fp8_meta_index=meta) return x out = pytree.tree_map_only(torch.Tensor, rewrap, out) return out # Must return (list_of_tensors_to_all_gather, user_metadata) - def fsdp_pre_all_gather(self, mesh): + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor. + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor. + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered. + metadata: Tuple[Any]: Metadata needed for reconstructing the + tensor after all-gather. + """ + # pylint: disable=unused-argument # If metadata isn't initialized yet, we can't access the quantizers - if not self._module.fp8: - module_class_name = self._module.__class__.__name__ + if not module.fp8: + module_class_name = module.__class__.__name__ if "LayerNormMLP" in module_class_name: num_gemms = 2 else: # Linear, LayerNormLinear, etc. num_gemms = 1 - self._module.init_fp8_metadata(num_gemms=num_gemms) - if not self._module.fp8: - return (self._data,), (self._data.requires_grad,) + module.init_fp8_metadata(num_gemms=num_gemms) + if not module.fp8: + return (self._data,), (self._data.requires_grad, module) # Use the actual data base = self._data # Access the quantizer using fp8_meta_index - quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index] - if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache: + quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index] + if not isinstance(quantizer, MXFP8Quantizer): quantizer.set_usage(columnwise=False) if isinstance(quantizer, Float8CurrentScalingQuantizer): quantizer.with_amax_reduction = True @@ -157,8 +165,8 @@ def fsdp_pre_all_gather(self, mesh): rowwise_scale_inv = sharded_fp8_tensor._rowwise_scale_inv if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) columnwise_data = sharded_fp8_tensor._columnwise_data if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) columnwise_scale_inv = sharded_fp8_tensor._columnwise_scale_inv if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad,) - return (sharded_fp8_tensor._data,), (base.requires_grad,) + return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad, module) + return (sharded_fp8_tensor._data,), (base.requires_grad, module) def fsdp_post_all_gather( self, @@ -168,14 +176,26 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): - (requires_grad, ) = metadata - if not self._module.fp8: + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the tensor. + param_dtype (torch.dtype): high precision dtype of the tensor. + out (Optional[torch.Tensor], optional): Preallocated output tensor. Defaults to None. + + Returns: + Tuple[Tensor, Tuple[torch.Tensor, ...]]: Allgathered tensor and tuple of internal tensors. + """ + (requires_grad, module) = metadata + if not module.fp8: (data,) = all_gather_outputs return data, all_gather_outputs # Retrieve the same quantizer you used in pre_all_gather - quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index] + quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index] shape = None - if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache: + if not isinstance(quantizer, MXFP8Quantizer): quantizer.set_usage(columnwise=False) if isinstance(quantizer, MXFP8Quantizer): (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv,) = all_gather_outputs @@ -197,3 +217,13 @@ def fsdp_post_all_gather( out._scale_inv = 1 / quantizer.scale out._data = data return out, all_gather_outputs + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling - unwrap to inner tensor + + During checkpointing, save just the underlying high-precision tensor. + FSDPAGTensor is a transient wrapper for FSDP2 communication - when the + model is loaded and FSDP2 is re-initialized, parameters get wrapped again. + """ + # Delegate to the inner tensor's serialization + return self._data.__reduce_ex__(protocol) \ No newline at end of file From db36143bf48b79bd20e351c9cba743e85b1ce433 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Tue, 17 Mar 2026 04:59:58 +0000 Subject: [PATCH 3/9] Refactor quantizer state checks and optimize tensor initialization in FusedAdam. Added debug print for DTensor in MultiTensorApply. --- transformer_engine/pytorch/module/grouped_linear.py | 4 ---- transformer_engine/pytorch/module/linear.py | 2 -- transformer_engine/pytorch/optimizers/fused_adam.py | 4 ++-- transformer_engine/pytorch/optimizers/multi_tensor_apply.py | 1 + 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 868dbe453..fa2ae5f85 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -119,10 +119,6 @@ def forward( and not in_fp8_activation_recompute_phase() ) # No need to set the quantizer states if weight is already quantized - if weight_quantizers[0] is not None and not isinstance( - weights[0], QuantizedTensorStorage - ): - # No need to set the quantizer states if weight is already quantized if weight_quantizers[0] is not None and not isinstance( weights[0], QuantizedTensorStorage ): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 859e8d5fa..be147166a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -254,8 +254,6 @@ def forward( if fp8 or debug: # Configure quantizer # No need to set the quantizer states if weight is already quantized - if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache if not columnwise_usage and keep_fp8_weight_transpose_cache: diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 3ac2be28f..b608e1db5 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -376,9 +376,9 @@ def _initialize_state( """ dtype = self.name_to_dtype_map[state_name] if store_param_remainders: - data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) + data = torch.zeros_like(param, dtype=torch.int16) else: - data = torch.empty(param.shape, dtype=dtype, device=param.device) + data = torch.empty_like(param, dtype=dtype) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index c791f0c4e..3dc2bbe00 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -18,6 +18,7 @@ def __call__(self, op, noop_flag_buffer, tensor_lists, *args): for i, ts in enumerate(tensor_lists): for j, t in enumerate(ts): if isinstance(t, DTensor): + print(f"DTensor found: {t}") tensor_lists[i][j] = t._local_tensor.data if IS_HIP_EXTENSION else t._local_tensor return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) From 13b40073c6b7846e847739212eef726b3622a6fa Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Tue, 17 Mar 2026 19:54:32 +0000 Subject: [PATCH 4/9] Refactor assertion function in FP8 tests to use relative and absolute tolerances for tensor comparisons. Updated test logic to accommodate new tolerance parameters for improved accuracy in floating-point comparisons. --- .../distributed/test_torch_fsdp2_fp8.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index d8de35c7a..098a1e9a5 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -17,15 +17,20 @@ NUM_PROCS: int = torch.cuda.device_count() -def assertEqual( - l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool: - """Ensures two lists are exactly equal.""" +def assert_allclose( + l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None +) -> bool: + """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." for i, (t1, t2) in enumerate(zip(l1, l2)): - result = torch.allclose(t1, t2, atol=0, rtol=0) + tols = dict(atol=atol) + if rtol is not None: + tols["rtol"] = rtol + result = torch.allclose(t1, t2, **tols) if not result: diff = torch.abs(t1 - t2) - exceed_mask = diff > 0 + tol = atol + (rtol * torch.abs(t2)) + exceed_mask = diff > tol if exceed_mask.any(): indices = torch.nonzero(exceed_mask, as_tuple=True) max_diff = diff[exceed_mask].max() @@ -58,11 +63,16 @@ def _run_test(fp_init, fp8_autocast, recipe): # Load outputs output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu") output_dp = torch.load("all_iters_dp.pt", map_location="cpu") + atol = 0 + rtol = 0 + if fp_init: + atol = 1e-6 + rtol = 5e-5 for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)): print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...") - assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match + assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol) # expects exact match print(f"Tensor at index {idx} passed comparison.") From d91241f4d144141971bf25ad463ac3eeb15105fb Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Tue, 17 Mar 2026 21:10:09 +0000 Subject: [PATCH 5/9] Update test tolerances for FP8 configurations to account for potential differences in gradient calculations. Clean up unused debug print statements in MultiTensorApply and ensure proper newline at the end of the FSDPAGTensor serialization method. --- tests/pytorch/distributed/run_fsdp2_fp8_model.py | 2 +- tests/pytorch/distributed/run_fsdp2_model.py | 9 --------- tests/pytorch/distributed/test_torch_fsdp2_fp8.py | 8 +++++++- .../pytorch/optimizers/multi_tensor_apply.py | 1 - .../pytorch/tensor/fsdp2_allgather_tensor.py | 2 +- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 936107f61..1fc41dd51 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -262,7 +262,7 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() if args.fp8_autocast: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=True, recipe=fp8_recipe): output = model(input_data) else: output = model(input_data) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index cbd073ee4..3b9264279 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -22,26 +22,17 @@ import torch import torch.distributed as dist from torch.distributed.tensor import DTensor -from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn, optim from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh from transformer_engine.pytorch import QuantizedTensor -from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext LOCAL_RANK = None -def dist_print(msg): - if LOCAL_RANK == 0: - print(msg) - -LOCAL_RANK = None - - def dist_print(msg): if LOCAL_RANK == 0: print(msg) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index 098a1e9a5..f1f419cf3 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -65,7 +65,13 @@ def _run_test(fp_init, fp8_autocast, recipe): output_dp = torch.load("all_iters_dp.pt", map_location="cpu") atol = 0 rtol = 0 - if fp_init: + # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical: + # - fp8_init=True (FP8 weights, FP32 compute): AllGather(FP8)->Dequantize->GEMM vs Dequantize->GEMM + # differs in dequantization context/order and can yield O(1e-11) differences. + # - fp32 (no FP8): gradient reduction order (all-reduce vs reduce-scatter) differs, so float + # non-associativity produces last-bit differences in the reduced gradient and updated weights. + # When fp8_autocast=True, both paths use the same FP8 GEMM with no dequantization, so 0 tol is used. + if fp_init or (not fp_init and not fp8_autocast): atol = 1e-6 rtol = 5e-5 diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index 3dc2bbe00..c791f0c4e 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -18,7 +18,6 @@ def __call__(self, op, noop_flag_buffer, tensor_lists, *args): for i, ts in enumerate(tensor_lists): for j, t in enumerate(ts): if isinstance(t, DTensor): - print(f"DTensor found: {t}") tensor_lists[i][j] = t._local_tensor.data if IS_HIP_EXTENSION else t._local_tensor return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index 492a259ca..ab1341be6 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -226,4 +226,4 @@ def __reduce_ex__(self, protocol: int) -> tuple: model is loaded and FSDP2 is re-initialized, parameters get wrapped again. """ # Delegate to the inner tensor's serialization - return self._data.__reduce_ex__(protocol) \ No newline at end of file + return self._data.__reduce_ex__(protocol) From 2b8818d809a0eccc44288949fb8a9c5fabd48b92 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Tue, 17 Mar 2026 21:35:55 +0000 Subject: [PATCH 6/9] Ensure proper newline at the end of the test_torch_fsdp2_fp8.py file by adding a newline character after the pass statement in the test_dummy function. --- tests/pytorch/distributed/test_torch_fsdp2_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index f1f419cf3..79556d5fd 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -140,4 +140,4 @@ def test_dummy() -> None: pytest returns exit code 5 if all tests are skipped. """ - pass \ No newline at end of file + pass From 8964d5615e4a95082b9c8c681b2b231b8470bef8 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Wed, 18 Mar 2026 15:07:14 +0000 Subject: [PATCH 7/9] Refactor tolerance calculations. --- tests/pytorch/distributed/test_torch_fsdp2_fp8.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index 79556d5fd..ecbbb9dcd 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -22,14 +22,13 @@ def assert_allclose( ) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." + tols = dict(atol=atol) + tols["rtol"] = rtol if rtol is not None else 0 + tol = tols["atol"] + (tols["rtol"] * torch.abs(l2)) for i, (t1, t2) in enumerate(zip(l1, l2)): - tols = dict(atol=atol) - if rtol is not None: - tols["rtol"] = rtol result = torch.allclose(t1, t2, **tols) if not result: diff = torch.abs(t1 - t2) - tol = atol + (rtol * torch.abs(t2)) exceed_mask = diff > tol if exceed_mask.any(): indices = torch.nonzero(exceed_mask, as_tuple=True) From 54938d9334d2a13264936a426465e41655f58aa6 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Wed, 18 Mar 2026 16:19:09 +0000 Subject: [PATCH 8/9] Refactor model initialization and autocasting logic in FSDP2 FP8 tests for improved clarity and consistency. --- .../distributed/run_fsdp2_fp8_model.py | 12 ++--- .../distributed/test_torch_fsdp2_fp8.py | 45 ++++++++++++------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 1fc41dd51..740b1f3e2 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -193,11 +193,8 @@ def _train(args): ) prof.start() - if args.fp8_init: - # Build the model with the specified context - with quantized_model_init(enabled=True, recipe=fp8_recipe): - model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) - else: + # Build the model with the specified context + with quantized_model_init(enabled=args.fp8_init, recipe=fp8_recipe): model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) # Move the model to the correct device if not args.memory_profile and not args.profile: @@ -261,10 +258,7 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() - if args.fp8_autocast: - with te.autocast(enabled=True, recipe=fp8_recipe): - output = model(input_data) - else: + with te.autocast(enabled=args.fp8_autocast, recipe=fp8_recipe): output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index ecbbb9dcd..73c473239 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -29,18 +29,27 @@ def assert_allclose( result = torch.allclose(t1, t2, **tols) if not result: diff = torch.abs(t1 - t2) - exceed_mask = diff > tol - if exceed_mask.any(): - indices = torch.nonzero(exceed_mask, as_tuple=True) - max_diff = diff[exceed_mask].max() - max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] - max_location = [idx[max_idx].item() for idx in indices] + if diff.dim() == 0: + max_diff = diff + max_location = [] msg = ( - f"Outputs not close enough in tensor at idx={i}. " - f"Maximum difference at location {max_location} " - f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " - f"(diff {max_diff.item()})." + f"Outputs not close enough in scalar tensor at idx={i}. " + f"Difference: {max_diff.item()}." ) + else: + exceed_mask = diff > tol + + if exceed_mask.any(): + indices = torch.nonzero(exceed_mask, as_tuple=True) + max_diff = diff[exceed_mask].max() + max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] + max_location = [idx[max_idx].item() for idx in indices] + msg = ( + f"Outputs not close enough in tensor at idx={i}. " + f"Maximum difference at location {max_location} " + f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " + f"(diff {max_diff.item()})." + ) raise AssertionError(msg) def _run_test(fp_init, fp8_autocast, recipe): @@ -65,11 +74,17 @@ def _run_test(fp_init, fp8_autocast, recipe): atol = 0 rtol = 0 # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical: - # - fp8_init=True (FP8 weights, FP32 compute): AllGather(FP8)->Dequantize->GEMM vs Dequantize->GEMM - # differs in dequantization context/order and can yield O(1e-11) differences. - # - fp32 (no FP8): gradient reduction order (all-reduce vs reduce-scatter) differs, so float - # non-associativity produces last-bit differences in the reduced gradient and updated weights. - # When fp8_autocast=True, both paths use the same FP8 GEMM with no dequantization, so 0 tol is used. + # + # - fp8_init=True: After each optimizer step, FP8 weights are re-quantized from + # FP32 master weights. Hence we use a relaxed tolerance. + # + # - No FP8 (fp8_init=False, fp8_autocast=False): gradient reduction order differs + # (all-reduce vs reduce-scatter), so float non-associativity produces last-bit + # differences in the reduced gradients and updated weights. Hence we use a relaxed tolerance. + # + # When fp8_autocast=True and fp8_init=False, FP8 quantization happens after the + # FSDP2 AllGather reconstructs the full weight, so both paths compute identical + # scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used. if fp_init or (not fp_init and not fp8_autocast): atol = 1e-6 rtol = 5e-5 From c1949d3c91e2d73cba09a437d8ba0d634e998ced Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Thu, 19 Mar 2026 19:13:04 +0000 Subject: [PATCH 9/9] Fix FusedAdam DTensor state initialization for FSDP2 Manually ported fix from upstream commit 139c863 The full commit was not cherry-picked due to unrelated changes across many files. Addressed PR comments --- .../distributed/run_fsdp2_fp8_model.py | 10 ++--- .../distributed/test_torch_fsdp2_fp8.py | 40 +++++++++---------- .../pytorch/optimizers/fused_adam.py | 30 ++++++++++++-- .../pytorch/tensor/fsdp2_allgather_tensor.py | 12 +++--- 4 files changed, 59 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 740b1f3e2..5b25e2f4e 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -87,10 +87,10 @@ def _parse_args(argv=None, namespace=None): parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") parser.add_argument( - "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + "--quantized-init", action="store_true", default=False, help="Initialize primary weights in FP8 via quantized_model_init." ) parser.add_argument( - "--fp8-autocast", action="store_true", default=False, help="Enable FP8 autocast." + "--autocast", action="store_true", default=False, help="Enable te.autocast for FP8 compute." ) parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" @@ -194,7 +194,7 @@ def _train(args): prof.start() # Build the model with the specified context - with quantized_model_init(enabled=args.fp8_init, recipe=fp8_recipe): + with quantized_model_init(enabled=args.quantized_init, recipe=fp8_recipe): model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) # Move the model to the correct device if not args.memory_profile and not args.profile: @@ -236,7 +236,7 @@ def _train(args): else: model = DDP(model, device_ids=[LOCAL_RANK]) - if args.fp8_init: + if args.quantized_init: optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True, use_decoupled_grad=True) else: optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3) @@ -258,7 +258,7 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() - with te.autocast(enabled=args.fp8_autocast, recipe=fp8_recipe): + with te.autocast(enabled=args.autocast, recipe=fp8_recipe): output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index 73c473239..dc01be362 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -52,17 +52,17 @@ def assert_allclose( ) raise AssertionError(msg) -def _run_test(fp_init, fp8_autocast, recipe): +def _run_test(quantized_init, autocast, recipe): test_dir = Path(__file__).parent.resolve() fsdp_script = test_dir / "run_fsdp2_fp8_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)] - if fp_init: - test_cmd += ["--fp8-init"] - if fp8_autocast: - test_cmd += ["--fp8-autocast"] - if fp8_autocast or fp_init: + if quantized_init: + test_cmd += ["--quantized-init"] + if autocast: + test_cmd += ["--autocast"] + if autocast or quantized_init: test_cmd += ["--recipe", recipe] subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True) @@ -75,24 +75,24 @@ def _run_test(fp_init, fp8_autocast, recipe): rtol = 0 # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical: # - # - fp8_init=True: After each optimizer step, FP8 weights are re-quantized from - # FP32 master weights. Hence we use a relaxed tolerance. + # - quantized_init=True: After each optimizer step, FP8 weights are re-quantized + # from FP32 master weights. Hence we use a relaxed tolerance. # - # - No FP8 (fp8_init=False, fp8_autocast=False): gradient reduction order differs + # - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs # (all-reduce vs reduce-scatter), so float non-associativity produces last-bit # differences in the reduced gradients and updated weights. Hence we use a relaxed tolerance. # - # When fp8_autocast=True and fp8_init=False, FP8 quantization happens after the + # When autocast=True and quantized_init=False, FP8 quantization happens after the # FSDP2 AllGather reconstructs the full weight, so both paths compute identical # scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used. - if fp_init or (not fp_init and not fp8_autocast): + if quantized_init or (not quantized_init and not autocast): atol = 1e-6 rtol = 5e-5 for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)): print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...") - assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol) # expects exact match + assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol) print(f"Tensor at index {idx} passed comparison.") @@ -106,11 +106,11 @@ def cleanup_artifacts(): # Define test cases explicitly test_cases = [] # All FP8 enabled cases (all recipes) -for fp8_init in [True, False]: - for fp8_autocast in [True, False]: - if fp8_init or fp8_autocast: +for quantized_init in [True, False]: + for autocast in [True, False]: + if quantized_init or autocast: for recipe in ["delayed", "current", "mxfp8"]: - test_cases.append((fp8_init, fp8_autocast, recipe)) + test_cases.append((quantized_init, autocast, recipe)) # FP8 disabled case (only once) test_cases.append((False, False, "delayed")) @@ -118,9 +118,9 @@ def cleanup_artifacts(): @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") -@pytest.mark.parametrize("fp8_init,fp8_autocast,recipe", test_cases) +@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases) @pytest.mark.usefixtures("cleanup_artifacts") -def test_distributed(fp8_init, fp8_autocast, recipe): +def test_distributed(quantized_init, autocast, recipe): batch_size = 2048 input_size = 2048 @@ -140,12 +140,12 @@ def test_distributed(fp8_init, fp8_autocast, recipe): if torch.cuda.device_count() < 4: pytest.skip("FSDP2 test requires at least 4 GPUs") - if fp8_init and not fp8_available: + if quantized_init and not fp8_available: pytest.skip(reason_for_no_fp8) if recipe == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - _run_test(fp8_init, fp8_autocast, recipe) + _run_test(quantized_init, autocast, recipe) def test_dummy() -> None: diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b608e1db5..3cc2814de 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -13,8 +13,11 @@ import warnings import torch +from torch.distributed._tensor import DTensor +from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.utils import is_fp8_fnuz @@ -375,10 +378,21 @@ def _initialize_state( store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] + # (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9) + # Extract local tensor from DTensor (e.g. from FSDP2) to avoid + # QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in + # torch.empty_like, and to ensure optimizer states are plain tensors. + local_param = param._local_tensor if isinstance(param, DTensor) else param + # ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor. + local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param + # Handle QuantizedTensor by dequantizing first + param_for_empty = ( + local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param + ) if store_param_remainders: - data = torch.zeros_like(param, dtype=torch.int16) + data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: - data = torch.empty_like(param, dtype=dtype) + data = torch.empty_like(param_for_empty, dtype=dtype) if zero_buffer: data.zero_() @@ -419,7 +433,17 @@ def initialize_state(self, param, store_param_remainders): store_param_remainders=store_param_remainders, ) if not store_param_remainders: - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + # (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9) + #Extract local tensor from DTensor and dequantize QuantizedTensor + # to get a plain float32 copy for the master weight. + local_param = param._local_tensor if isinstance(param, DTensor) else param + # ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor. + local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param + if isinstance(local_param, QuantizedTensor): + master = local_param.dequantize(dtype=torch.float32).clone().detach() + else: + master = local_param.clone().detach().float() + self.set_scaled_state(param, "master_param", master) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index ab1341be6..f9e63a8fe 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -1,6 +1,7 @@ #!/usr/bin/python3 # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # See LICENSE for license information. + from __future__ import annotations from typing import Any, Optional, Tuple import torch @@ -122,8 +123,8 @@ def rewrap(x): # Must return (list_of_tensors_to_all_gather, user_metadata) def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): - """Functions FSDP2 calls before all-gather of the - weights for both forward and backward passes. + """ + Functions FSDP2 calls before all-gather of the weights for both forward and backward passes. Args: mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 to shard the weights. @@ -176,8 +177,8 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): - """Functions FSDP2 calls after all-gather of the - weights for both forward and backward passes. + """ + Functions FSDP2 calls after all-gather of the weights for both forward and backward passes. Args: all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank are all-gathered and received here as a tuple. @@ -219,7 +220,8 @@ def fsdp_post_all_gather( return out, all_gather_outputs def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling - unwrap to inner tensor + """ + Custom pickling - unwrap to inner tensor During checkpointing, save just the underlying high-precision tensor. FSDPAGTensor is a transient wrapper for FSDP2 communication - when the