diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 5f73e476d..5b25e2f4e 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -87,7 +87,10 @@ def _parse_args(argv=None, namespace=None): parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") parser.add_argument( - "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + "--quantized-init", action="store_true", default=False, help="Initialize primary weights in FP8 via quantized_model_init." + ) + parser.add_argument( + "--autocast", action="store_true", default=False, help="Enable te.autocast for FP8 compute." ) parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" @@ -169,14 +172,32 @@ def _train(args): if args.memory_profile: torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all') - if args.fp8_init: - # Build the model with the specified context - with quantized_model_init(enabled=True): - model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) - else: + + prof = None + if ( + args.profile + and torch.distributed.get_rank() in args.profile_ranks + ): + prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + wait=max(args.profile_step_start - 1, 0), + warmup=1 if args.profile_step_start > 0 else 0, + active=args.profile_step_end - args.profile_step_start, + repeat=1, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + prof.start() + + # Build the model with the specified context + with 
quantized_model_init(enabled=args.quantized_init, recipe=fp8_recipe): model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2) # Move the model to the correct device - if not args.memory_profile: + if not args.memory_profile and not args.profile: model.load_state_dict(torch.load('fsdp_model.pth')) model.to(device) @@ -215,7 +236,10 @@ def _train(args): else: model = DDP(model, device_ids=[LOCAL_RANK]) - optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3) + if args.quantized_init: + optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True, use_decoupled_grad=True) + else: + optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3) input_path = Path("shared_input.pt") if input_path.exists(): @@ -226,25 +250,6 @@ def _train(args): print("Generated and saved shared input tensor.") out_tensors = [] - prof = None - if ( - args.profile - and torch.distributed.get_rank() in args.profile_ranks - ): - prof = torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=torch.profiler.schedule( - wait=max(args.profile_step_start - 1, 0), - warmup=1 if args.profile_step_start > 0 else 0, - active=args.profile_step_end - args.profile_step_start, - repeat=1, - ), - on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), - record_shapes=True, - profile_memory=True, - with_stack=True, - ) - prof.start() for iteration in range(args.iter): if LOCAL_RANK == 0: print(f"Starting iteration...{iteration}") @@ -253,7 +258,7 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=args.autocast, recipe=fp8_recipe): output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) @@ -286,6 +291,9 @@ def _train(args): torch.save(out_tensors, 
args.gradients_save_file) if args.memory_profile: + with open('memory_summary.txt', 'w') as f: + f.write(torch.cuda.memory_summary(device=None, abbreviated=False)) + snapshot = torch.cuda.memory._snapshot() import pickle with open('memory_snapshot.pickle', 'wb') as f: diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py index 3eaa449c2..dc01be362 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py +++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py @@ -17,37 +17,53 @@ NUM_PROCS: int = torch.cuda.device_count() -def assertEqual( - l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool: - """Ensures two lists are exactly equal.""" +def assert_allclose( + l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None +) -> bool: + """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." + tols = dict(atol=atol) + tols["rtol"] = rtol if rtol is not None else 0 + # per-element tolerance (atol + rtol * |t2|) is computed per tensor inside the loop for i, (t1, t2) in enumerate(zip(l1, l2)): - result = torch.allclose(t1, t2, atol=0, rtol=0) + result = torch.allclose(t1, t2, **tols) if not result: diff = torch.abs(t1 - t2) - exceed_mask = diff > 0 - if exceed_mask.any(): - indices = torch.nonzero(exceed_mask, as_tuple=True) - max_diff = diff[exceed_mask].max() - max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] - max_location = [idx[max_idx].item() for idx in indices] + if diff.dim() == 0: + max_diff = diff + max_location = [] msg = ( - f"Outputs not close enough in tensor at idx={i}. " - f"Maximum difference at location {max_location} " - f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " - f"(diff {max_diff.item()})." + f"Outputs not close enough in scalar tensor at idx={i}. " + f"Difference: {max_diff.item()}." 
) + else: + exceed_mask = diff > tols["atol"] + tols["rtol"] * torch.abs(t2) + + if exceed_mask.any(): + indices = torch.nonzero(exceed_mask, as_tuple=True) + max_diff = diff[exceed_mask].max() + max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] + max_location = [idx[max_idx].item() for idx in indices] + msg = ( + f"Outputs not close enough in tensor at idx={i}. " + f"Maximum difference at location {max_location} " + f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " + f"(diff {max_diff.item()})." + ) raise AssertionError(msg) -def _run_test(fp_init, recipe): +def _run_test(quantized_init, autocast, recipe): test_dir = Path(__file__).parent.resolve() fsdp_script = test_dir / "run_fsdp2_fp8_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)] - if fp_init: - test_cmd += ["--fp8-init"] - test_cmd += ["--recipe", recipe] + if quantized_init: + test_cmd += ["--quantized-init"] + if autocast: + test_cmd += ["--autocast"] + if autocast or quantized_init: + test_cmd += ["--recipe", recipe] subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True) subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True) @@ -55,11 +71,28 @@ def _run_test(fp_init, recipe): # Load outputs output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu") output_dp = torch.load("all_iters_dp.pt", map_location="cpu") + atol = 0 + rtol = 0 + # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical: + # + # - quantized_init=True: After each optimizer step, FP8 weights are re-quantized + # from FP32 master weights. Hence we use a relaxed tolerance. + # + # - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs + # (all-reduce vs reduce-scatter), so float non-associativity produces last-bit + # differences in the reduced gradients and updated weights. 
Hence we use a relaxed tolerance. + # + # When autocast=True and quantized_init=False, FP8 quantization happens after the + # FSDP2 AllGather reconstructs the full weight, so both paths compute identical + # scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used. + if quantized_init or (not quantized_init and not autocast): + atol = 1e-6 + rtol = 5e-5 for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)): print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...") - assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match + assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol) print(f"Tensor at index {idx} passed comparison.") @@ -70,13 +103,24 @@ def cleanup_artifacts(): if os.path.exists(fname): os.remove(fname) +# Define test cases explicitly +test_cases = [] +# All FP8 enabled cases (all recipes) +for quantized_init in [True, False]: + for autocast in [True, False]: + if quantized_init or autocast: + for recipe in ["delayed", "current", "mxfp8"]: + test_cases.append((quantized_init, autocast, recipe)) +# FP8 disabled case (only once) +test_cases.append((False, False, "delayed")) + + @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") -@pytest.mark.parametrize("fp8_init", ([False])) -@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"])) +@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases) @pytest.mark.usefixtures("cleanup_artifacts") -def test_distributed(fp8_init, recipe): +def test_distributed(quantized_init, autocast, recipe): batch_size = 2048 input_size = 2048 @@ -96,12 +140,12 @@ def test_distributed(fp8_init, recipe): if torch.cuda.device_count() < 4: pytest.skip("FSDP2 test requires at least 4 GPUs") - if 
fp8_init and not fp8_available: + if quantized_init and not fp8_available: pytest.skip(reason_for_no_fp8) if recipe == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - _run_test(fp8_init, recipe) + _run_test(quantized_init, autocast, recipe) def test_dummy() -> None: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e74fd9d17..1ed16a02c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch.distributed.tensor import DTensor from torch.utils.cpp_extension import IS_HIP_EXTENSION +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe @@ -1335,9 +1336,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: self.keep_fp8_weight_transpose_cache = False param = FSDPAGTensor( param, - module=self, - fp8_meta_index=fp8_meta_index, - keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache + fp8_meta_index=fp8_meta_index, ) # Redo parameter wrap in case we broke it above diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 3ac2be28f..3cc2814de 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -13,8 +13,11 @@ import warnings import torch +from torch.distributed._tensor import DTensor +from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.utils import is_fp8_fnuz @@ -375,10 
+378,21 @@ def _initialize_state( store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] + # (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9) + # Extract local tensor from DTensor (e.g. from FSDP2) to avoid + # QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in + # torch.empty_like, and to ensure optimizer states are plain tensors. + local_param = param._local_tensor if isinstance(param, DTensor) else param + # ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor. + local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param + # Handle QuantizedTensor by dequantizing first + param_for_empty = ( + local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param + ) if store_param_remainders: - data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) + data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: - data = torch.empty(param.shape, dtype=dtype, device=param.device) + data = torch.empty_like(param_for_empty, dtype=dtype) if zero_buffer: data.zero_() @@ -419,7 +433,17 @@ def initialize_state(self, param, store_param_remainders): store_param_remainders=store_param_remainders, ) if not store_param_remainders: - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + # (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9) + #Extract local tensor from DTensor and dequantize QuantizedTensor + # to get a plain float32 copy for the master weight. + local_param = param._local_tensor if isinstance(param, DTensor) else param + # ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor. 
+ local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param + if isinstance(local_param, QuantizedTensor): + master = local_param.dequantize(dtype=torch.float32).clone().detach() + else: + master = local_param.clone().detach().float() + self.set_scaled_state(param, "master_param", master) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index 763fd4419..f9e63a8fe 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -2,6 +2,7 @@ # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # See LICENSE for license information. +from __future__ import annotations from typing import Any, Optional, Tuple import torch import torch.nn as nn @@ -44,18 +45,12 @@ def __init__( self, tensor: torch.Tensor, *, - module: nn.Module, fp8_meta_index: str, - keep_fp8_weight_transpose_cache: bool, ): #The underlying tensor self._data = tensor - # Where quantizers are present - self._module = module # Which quantizer to use within module.quantizers["scaling_fwd"][idx] self._fp8_meta_index = fp8_meta_index - # Disable or enable transpose cache for fp8 weights - self._keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache @property def data(self) -> torch.Tensor: @@ -65,7 +60,6 @@ def __repr__(self): return ( f"FSDPAGTensor(" f"elem={self._data}, " - f"module={self._module.__class__.__name__}, " f"fp8_meta_index={self._fp8_meta_index})" ) @@ -75,18 +69,16 @@ def __tensor_flatten__(self): Return (names_of_inner_tensors, flatten_spec_metadata). """ # We only carry the one inner tensor. - # We store (module, fp8_meta_index, keep_fp8_weight_transpose_cache) as metadata to reconstruct. 
- return ["_data"], (self._module, self._fp8_meta_index, self._keep_fp8_weight_transpose_cache) + # We store fp8_meta_index as metadata to reconstruct. + return ["_data"], (self._fp8_meta_index) @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - module, fp8_meta_index, keep_fp8_weight_transpose_cache = flatten_spec + fp8_meta_index = flatten_spec return FSDPAGTensor( inner_tensors["_data"], - module=module, fp8_meta_index=fp8_meta_index, - keep_fp8_weight_transpose_cache=keep_fp8_weight_transpose_cache ) @classmethod @@ -99,16 +91,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): t = args[0] assert isinstance(t, cls), f"Unexpected detach input type: {type(t)}" detached = t._data.detach() - return cls(detached, module=t._module, fp8_meta_index=t._fp8_meta_index, keep_fp8_weight_transpose_cache=t._keep_fp8_weight_transpose_cache) + return cls(detached, fp8_meta_index=t._fp8_meta_index) # Unwrap only our subclass; capture shared metadata for rewrapping - meta: Optional[tuple[nn.Module, str, bool]] = None + meta: Optional[str] = None def unwrap(x): nonlocal meta if isinstance(x, cls): if meta is None: - meta = (x._module, x._fp8_meta_index, x._keep_fp8_weight_transpose_cache) + meta = x._fp8_meta_index return x._data return x @@ -123,31 +115,48 @@ def unwrap(x): def rewrap(x): if isinstance(x, torch.Tensor): - mod, idx, keep_transpose = meta - return cls(x, module=mod, fp8_meta_index=idx, keep_fp8_weight_transpose_cache=keep_transpose) + return cls(x, fp8_meta_index=meta) return x out = pytree.tree_map_only(torch.Tensor, rewrap, out) return out # Must return (list_of_tensors_to_all_gather, user_metadata) - def fsdp_pre_all_gather(self, mesh): + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """ + Functions FSDP2 calls before all-gather of the weights for both forward and backward passes. 
+ Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor. + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor. + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered. + metadata: Tuple[Any]: Metadata needed for reconstructing the + tensor after all-gather. + """ + # pylint: disable=unused-argument # If metadata isn't initialized yet, we can't access the quantizers - if not self._module.fp8: - module_class_name = self._module.__class__.__name__ + if not module.fp8: + module_class_name = module.__class__.__name__ if "LayerNormMLP" in module_class_name: num_gemms = 2 else: # Linear, LayerNormLinear, etc. num_gemms = 1 - self._module.init_fp8_metadata(num_gemms=num_gemms) - if not self._module.fp8: - return (self._data,), (self._data.requires_grad,) + module.init_fp8_metadata(num_gemms=num_gemms) + if not module.fp8: + return (self._data,), (self._data.requires_grad, module) # Use the actual data base = self._data # Access the quantizer using fp8_meta_index - quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index] - if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache: + quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index] + if not isinstance(quantizer, MXFP8Quantizer): quantizer.set_usage(columnwise=False) if isinstance(quantizer, Float8CurrentScalingQuantizer): quantizer.with_amax_reduction = True @@ -157,8 +166,8 @@ def fsdp_pre_all_gather(self, mesh): rowwise_scale_inv = sharded_fp8_tensor._rowwise_scale_inv if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) columnwise_data = 
sharded_fp8_tensor._columnwise_data if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) columnwise_scale_inv = sharded_fp8_tensor._columnwise_scale_inv if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad,) - return (sharded_fp8_tensor._data,), (base.requires_grad,) + return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad, module) + return (sharded_fp8_tensor._data,), (base.requires_grad, module) def fsdp_post_all_gather( self, @@ -168,14 +177,26 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): - (requires_grad, ) = metadata - if not self._module.fp8: + """ + Functions FSDP2 calls after all-gather of the weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the tensor. + param_dtype (torch.dtype): high precision dtype of the tensor. + out (Optional[torch.Tensor], optional): Preallocated output tensor. Defaults to None. + + Returns: + Tuple[Tensor, Tuple[torch.Tensor, ...]]: Allgathered tensor and tuple of internal tensors. 
+ """ + (requires_grad, module) = metadata + if not module.fp8: (data,) = all_gather_outputs return data, all_gather_outputs # Retrieve the same quantizer you used in pre_all_gather - quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index] + quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index] shape = None - if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache: + if not isinstance(quantizer, MXFP8Quantizer): quantizer.set_usage(columnwise=False) if isinstance(quantizer, MXFP8Quantizer): (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv,) = all_gather_outputs @@ -197,3 +218,14 @@ def fsdp_post_all_gather( out._scale_inv = 1 / quantizer.scale out._data = data return out, all_gather_outputs + + def __reduce_ex__(self, protocol: int) -> tuple: + """ + Custom pickling - unwrap to inner tensor + + During checkpointing, save just the underlying high-precision tensor. + FSDPAGTensor is a transient wrapper for FSDP2 communication - when the + model is loaded and FSDP2 is re-initialized, parameters get wrapped again. + """ + # Delegate to the inner tensor's serialization + return self._data.__reduce_ex__(protocol)