Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions tests/pytorch/distributed/run_fsdp2_fp8_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
"--quantized-init", action="store_true", default=False, help="Initialize primary weights in FP8 via quantized_model_init."
)
parser.add_argument(
"--autocast", action="store_true", default=False, help="Enable te.autocast for FP8 compute."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
Expand Down Expand Up @@ -169,14 +172,32 @@ def _train(args):

if args.memory_profile:
torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')
if args.fp8_init:
# Build the model with the specified context
with quantized_model_init(enabled=True):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
else:

prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()

# Build the model with the specified context
with quantized_model_init(enabled=args.quantized_init, recipe=fp8_recipe):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
# Move the model to the correct device
if not args.memory_profile:
if not args.memory_profile and not args.profile:
model.load_state_dict(torch.load('fsdp_model.pth'))
model.to(device)

Expand Down Expand Up @@ -215,7 +236,10 @@ def _train(args):
else:
model = DDP(model, device_ids=[LOCAL_RANK])

optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)
if args.quantized_init:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True, use_decoupled_grad=True)
else:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)

input_path = Path("shared_input.pt")
if input_path.exists():
Expand All @@ -226,25 +250,6 @@ def _train(args):
print("Generated and saved shared input tensor.")

out_tensors = []
prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
for iteration in range(args.iter):
if LOCAL_RANK == 0:
print(f"Starting iteration...{iteration}")
Expand All @@ -253,7 +258,7 @@ def _train(args):

# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does with te.fp8_autocast(enabled=args.fp8_autocast,.. ) do the same?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does do the same but since with TEv2.10, te.fp8_autocast is replaced with te.autocast, I've made the change to be consistent.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So will 'with te.autocast(enabled=args.fp8_autocast, recipe=...)' do the same as if/else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should. I'll make the changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

with te.autocast(enabled=args.autocast, recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
Expand Down Expand Up @@ -286,6 +291,9 @@ def _train(args):
torch.save(out_tensors, args.gradients_save_file)

if args.memory_profile:
with open('memory_summary.txt', 'w') as f:
f.write(torch.cuda.memory_summary(device=None, abbreviated=False))

snapshot = torch.cuda.memory._snapshot()
import pickle
with open('memory_snapshot.pickle', 'wb') as f:
Expand Down
92 changes: 68 additions & 24 deletions tests/pytorch/distributed/test_torch_fsdp2_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,82 @@

NUM_PROCS: int = torch.cuda.device_count()

def assertEqual(
l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
"""Ensures two lists are exactly equal."""
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
tols = dict(atol=atol)
tols["rtol"] = rtol if rtol is not None else 0
tol = tols["atol"] + (tols["rtol"] * torch.abs(l2))
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=0, rtol=0)
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
exceed_mask = diff > 0
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
if diff.dim() == 0:
max_diff = diff
max_location = []
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
f"Outputs not close enough in scalar tensor at idx={i}. "
f"Difference: {max_diff.item()}."
)
else:
exceed_mask = diff > tol

if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)

def _run_test(fp_init, recipe):
def _run_test(quantized_init, autocast, recipe):
test_dir = Path(__file__).parent.resolve()
fsdp_script = test_dir / "run_fsdp2_fp8_model.py"

test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)]

if fp_init:
test_cmd += ["--fp8-init"]
test_cmd += ["--recipe", recipe]
if quantized_init:
test_cmd += ["--quantized-init"]
if autocast:
test_cmd += ["--autocast"]
if autocast or quantized_init:
test_cmd += ["--recipe", recipe]

subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)

# Load outputs
output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
output_dp = torch.load("all_iters_dp.pt", map_location="cpu")
atol = 0
rtol = 0
# Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical:
#
# - quantized_init=True: After each optimizer step, FP8 weights are re-quantized
# from FP32 master weights. Hence we use a relaxed tolerance.
#
# - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs
# (all-reduce vs reduce-scatter), so float non-associativity produces last-bit
# differences in the reduced gradients and updated weights. Hence we use a relaxed tolerance.
#
# When autocast=True and quantized_init=False, FP8 quantization happens after the
# FSDP2 AllGather reconstructs the full weight, so both paths compute identical
# scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used.
if quantized_init or (not quantized_init and not autocast):
atol = 1e-6
rtol = 5e-5

for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):

print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match
assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol)
print(f"Tensor at index {idx} passed comparison.")


Expand All @@ -70,13 +103,24 @@ def cleanup_artifacts():
if os.path.exists(fname):
os.remove(fname)

# Define test cases explicitly
test_cases = []
# All FP8 enabled cases (all recipes)
for quantized_init in [True, False]:
for autocast in [True, False]:
if quantized_init or autocast:
for recipe in ["delayed", "current", "mxfp8"]:
test_cases.append((quantized_init, autocast, recipe))
# FP8 disabled case (only once)
test_cases.append((False, False, "delayed"))


@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("fp8_init", ([False]))
@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"]))
@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases)
@pytest.mark.usefixtures("cleanup_artifacts")
def test_distributed(fp8_init, recipe):
def test_distributed(quantized_init, autocast, recipe):

batch_size = 2048
input_size = 2048
Expand All @@ -96,12 +140,12 @@ def test_distributed(fp8_init, recipe):
if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")

if fp8_init and not fp8_available:
if quantized_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

_run_test(fp8_init, recipe)
_run_test(quantized_init, autocast, recipe)


def test_dummy() -> None:
Expand Down
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from torch.distributed.tensor import DTensor

import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
Expand Down Expand Up @@ -1335,9 +1336,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
self.keep_fp8_weight_transpose_cache = False
param = FSDPAGTensor(
param,
module=self,
fp8_meta_index=fp8_meta_index,
keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache
fp8_meta_index=fp8_meta_index,
)

# Redo parameter wrap in case we broke it above
Expand Down
30 changes: 27 additions & 3 deletions transformer_engine/pytorch/optimizers/fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
import warnings

import torch
from torch.distributed._tensor import DTensor
from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from .multi_tensor_apply import multi_tensor_applier
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import is_fp8_fnuz
Expand Down Expand Up @@ -375,10 +378,21 @@ def _initialize_state(
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
# (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9)
# Extract local tensor from DTensor (e.g. from FSDP2) to avoid
# QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in
# torch.empty_like, and to ensure optimizer states are plain tensors.
local_param = param._local_tensor if isinstance(param, DTensor) else param
# ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor.
local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param
# Handle QuantizedTensor by dequantizing first
param_for_empty = (
local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param
)
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
if zero_buffer:
data.zero_()

Expand Down Expand Up @@ -419,7 +433,17 @@ def initialize_state(self, param, store_param_remainders):
store_param_remainders=store_param_remainders,
)
if not store_param_remainders:
self.set_scaled_state(param, "master_param", param.clone().detach().float())
# (Upstream fix https://github.com/NVIDIA/TransformerEngine/commit/139c863f92420271bbae2cbce49d9b170b7d03f9)
# Extract local tensor from DTensor and dequantize QuantizedTensor
# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
# ROCm fix: FSDPAGTensor is a wrapper around a plain tensor, so we need to extract the underlying tensor.
local_param = local_param._data if isinstance(local_param, FSDPAGTensor) else local_param
if isinstance(local_param, QuantizedTensor):
master = local_param.dequantize(dtype=torch.float32).clone().detach()
else:
master = local_param.clone().detach().float()
self.set_scaled_state(param, "master_param", master)

def state_dict(self):
"""Override the state_dict() of pytorch. Before returning the state_dict, cast all
Expand Down
Loading