From 3ce31ff1731866438d748893893fc0f9baa1142c Mon Sep 17 00:00:00 2001 From: jafraustro Date: Thu, 3 Jul 2025 11:07:50 -0700 Subject: [PATCH 1/2] Enhance README and examples for Tensor Parallelism - Added installation instructions and example running commands to README.md. - Update files to have a better organization Signed-off-by: jafraustro --- distributed/tensor_parallelism/README.md | 22 ++++++- .../tensor_parallelism/fsdp_tp_example.py | 61 +++++++++---------- .../sequence_parallel_example.py | 36 +++++------ .../tensor_parallel_example.py | 48 +++++++-------- 4 files changed, 88 insertions(+), 79 deletions(-) diff --git a/distributed/tensor_parallelism/README.md b/distributed/tensor_parallelism/README.md index b49d1672e8..83efb5dc78 100644 --- a/distributed/tensor_parallelism/README.md +++ b/distributed/tensor_parallelism/README.md @@ -10,7 +10,25 @@ PyTorch native Tensor Parallel APIs, which include: More details about the PyTorch native Tensor Parallel APIs, please see PyTorch docs: https://pytorch.org/docs/stable/distributed.tensor.parallel.html -``` +## Installation + +```bash pip install -r requirements.txt -python example.py ``` + +## Running Examples + +You can run the examples using `torchrun` to launch distributed training: + +```bash +# Simple Tensor Parallel example +torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_mlp.py + +# Tensor Parallel with Sequence Parallel +torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_with_sequence_parallel.py + +# FSDP + Tensor Parallel with Llama2 model +torchrun --nnodes=1 --nproc_per_node=4 fsdp_with_tp_example.py +``` + +For more details, check the `run_examples.sh` script. diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index dbab48c1b8..305ee2cce7 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -1,34 +1,3 @@ -import sys -import os -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F - -from log_utils import rank_log, get_logger, verify_min_gpu_count - -# ---- GPU check ------------ -_min_gpu_count = 4 - -if not verify_min_gpu_count(min_gpus=_min_gpu_count): - print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit() -# --------------------------- - -from llama2_model import Transformer, ModelArgs - -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed._tensor import Shard, Replicate -from torch.distributed.tensor.parallel import ( - parallelize_module, - ColwiseParallel, - RowwiseParallel, - PrepareModuleInput, - SequenceParallel -) - - """ This is the script to test 2D Parallel which combines Tensor/Sequence parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a example @@ -60,6 +29,36 @@ https://pytorch.org/tutorials/intermediate/TP_tutorial.html """ +import sys +import os +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from log_utils import rank_log, get_logger, verify_min_gpu_count + +# ---- GPU check ------------ +_min_gpu_count = 4 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from llama2_model import Transformer, ModelArgs + +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed._tensor import Shard, Replicate +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, + PrepareModuleInput, + SequenceParallel +) + tp_size = 2 logger = get_logger() diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 3324d28d4a..d2eb00e517 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,3 +1,19 @@ +""" +This is the script to test Sequence Parallel(SP) on a toy model in a +Megetron-LM SPMD style. We show an E2E working flow from forward, +backward and optimization. + +We use the example of two `nn.Linear` layers with an element-wise `nn.RELU` +in between to show an example of sequence parallel, which was proposed in paper: + +https://arxiv.org/pdf/2205.05198.pdf. + +Like tensor parallel, we parallelize the first linear layer by column +and also parallelize the second linear layer by row. But the input in each rank +now is different so that we need one all-gather for input and one reduce-scatter +in the end of the second linear layer. +""" + import os import sys import torch @@ -22,28 +38,8 @@ sys.exit() # --------------------------- - from torch.distributed._tensor.device_mesh import init_device_mesh - - -""" -This is the script to test Sequence Parallel(SP) on a toy model in a -Megetron-LM SPMD style. We show an E2E working flow from forward, -backward and optimization. - -We use the example of two `nn.Linear` layers with an element-wise `nn.RELU` -in between to show an example of sequence parallel, which was proposed in paper: - -https://arxiv.org/pdf/2205.05198.pdf. - -Like tensor parallel, we parallelize the first linear layer by column -and also parallelize the second linear layer by row. But the input in each rank -now is different so that we need one all-gather for input and one reduce-scatter -in the end of the second linear layer. -""" - - class ToyModel(nn.Module): """MLP based model""" diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 0b9c884507..877278beb0 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,29 +1,3 @@ -import os -import sys -import torch -import torch.nn as nn - -from torch.distributed.tensor.parallel import ( - parallelize_module, - ColwiseParallel, - RowwiseParallel, -) - -from log_utils import rank_log, get_logger, verify_min_gpu_count - -# ---- GPU check ------------ -_min_gpu_count = 2 - -if not verify_min_gpu_count(min_gpus=_min_gpu_count): - print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit() -# --------------------------- - -from torch.distributed._tensor.device_mesh import init_device_mesh - - - - """ This is the script to test Tensor Parallel(TP) on a toy model in a Megetron-LM SPMD style. We show an E2E working flow from forward, @@ -55,6 +29,28 @@ Parallelism APIs in this example to show users how to use them. """ +import os +import sys +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) +from log_utils import rank_log, get_logger, verify_min_gpu_count + +# ---- GPU check ------------ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from torch.distributed._tensor.device_mesh import init_device_mesh + class ToyModel(nn.Module): """MLP based model""" From cef0b4c69de32edfb876921c49ff976ed25713c1 Mon Sep 17 00:00:00 2001 From: jafraustro Date: Thu, 3 Jul 2025 14:27:43 -0700 Subject: [PATCH 2/2] Refactor tensor parallelism examples to use the accelerator API Signed-off-by: jafraustro --- .../tensor_parallelism/fsdp_tp_example.py | 31 ++++++++++--------- distributed/tensor_parallelism/log_utils.py | 6 ++-- .../tensor_parallelism/requirements.txt | 7 +---- .../sequence_parallel_example.py | 24 +++++++------- .../tensor_parallel_example.py | 27 ++++++++-------- 5 files changed, 47 insertions(+), 48 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 305ee2cce7..099db14aca 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -63,25 +63,26 @@ logger = get_logger() # understand world topology -_rank = int(os.environ["RANK"]) -_world_size = int(os.environ["WORLD_SIZE"]) +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) -print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") +print(f"Starting PyTorch 2D (FSDP + TP) example on rank {rank}.") assert ( - _world_size % tp_size == 0 -), f"World size {_world_size} needs to be divisible by TP size {tp_size}" + world_size % tp_size == 0 +), f"World size {world_size} needs to be divisible by TP size {tp_size}" # create a sharding plan based on the given world_size. -dp_size = _world_size // tp_size +dp_size = world_size // tp_size # Create a device mesh with 2 dimensions. # First dim is the data parallel dimension # Second dim is the tensor parallel dimension. -device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) +acc = torch.accelerator.current_accelerator() +device_mesh = init_device_mesh(acc.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")) -rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") +rank_log(rank, logger, f"Device Mesh created: {device_mesh=}") tp_mesh = device_mesh["tp"] dp_mesh = device_mesh["dp"] @@ -94,7 +95,7 @@ # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000) -model = Transformer.from_model_args(simple_llama2_config).to("cuda") +model = Transformer.from_model_args(simple_llama2_config).to(rank) # init model weights model.init_weights() @@ -152,28 +153,28 @@ # Init FSDP using the dp device mesh sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True) -rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n") +rank_log(rank, logger, f"Model after parallelization {sharded_model=}\n") # Create an optimizer for the parallelized and sharded model. lr = 3e-3 -rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}") +rank_log(rank, logger, f"Creating AdamW optimizer with learning rate {lr}") optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True) # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. -rank_log(_rank, logger, "\nStarting 2D training...") +rank_log(rank, logger, "\nStarting 2D training...") num_iterations = 10 batch_size = 2 for i in range(num_iterations): # seeding with dp_rank to ensure identical inputs for TP groups torch.manual_seed(i + dp_rank) - inp = torch.randint(32000, (8, 256), device="cuda") + inp = torch.randint(32000, (8, 256), device=acc.type) output = sharded_model(inp) output.sum().backward() optimizer.step() - rank_log(_rank, logger, f"2D iter {i} complete") + rank_log(rank, logger, f"2D iter {i} complete") -rank_log(_rank, logger, "2D training successfully completed!") +rank_log(rank, logger, "2D training successfully completed!") diff --git a/distributed/tensor_parallelism/log_utils.py b/distributed/tensor_parallelism/log_utils.py index f16d46526d..c3cad1f093 100644 --- a/distributed/tensor_parallelism/log_utils.py +++ b/distributed/tensor_parallelism/log_utils.py @@ -17,6 +17,6 @@ def rank_log(_rank, logger, msg): def verify_min_gpu_count(min_gpus: int = 2) -> bool: """ verification that we have at least 2 gpus to run dist examples """ - has_cuda = torch.cuda.is_available() - gpu_count = torch.cuda.device_count() - return has_cuda and gpu_count >= min_gpus + has_acc = torch.accelerator.is_available() + gpu_count = torch.accelerator.device_count() + return has_acc and gpu_count >= min_gpus diff --git a/distributed/tensor_parallelism/requirements.txt b/distributed/tensor_parallelism/requirements.txt index 80fad36bf2..ccfd41377b 100644 --- a/distributed/tensor_parallelism/requirements.txt +++ b/distributed/tensor_parallelism/requirements.txt @@ -1,6 +1 @@ -# Python dependencies required for running the example - ---pre ---extra-index-url https://download.pytorch.org/whl/nightly/cu118 ---extra-index-url https://download.pytorch.org/whl/nightly/cu121 -torch >= 2.3.0.dev0; sys_platform == "linux" +torch>=2.8 diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index d2eb00e517..dc69dba6f5 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -60,18 +60,20 @@ def forward(self, x): logger = get_logger() # create a device mesh based on the given world_size. -device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) -) +world_size = int(os.environ["WORLD_SIZE"]) + +model = ToyModel() -_rank = device_mesh.get_rank() +device = torch.accelerator.current_accelerator() +device_mesh = init_device_mesh(device_type=device.type, mesh_shape=(world_size,)) +rank = device_mesh.get_rank() -print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.") +print(f"Starting PyTorch Sequence Parallel example on rank {rank}.") -rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") +rank_log(rank, logger, f"Device Mesh created: {device_mesh=}") # create model and move it to GPU. Init_device_mesh has already assigned gpu ids... -model = ToyModel().to("cuda") +model = ToyModel().to(rank) # Custom parallelization plan for the model sp_model = parallelize_module( @@ -92,14 +94,14 @@ def forward(self, x): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log(_rank, logger, "Sequence Parallel training starting...") +rank_log(rank, logger, "Sequence Parallel training starting...") for i in range(num_iters): # For SP, input can be different across all ranks. - inp = torch.rand(20, 10, device="cuda") + inp = torch.rand(20, 10, device=f'{device}:{rank}') output = sp_model(inp) output.sum().backward() optimizer.step() - rank_log(_rank, logger, f"Sequence Parallel iter {i} completed") + rank_log(rank, logger, f"Sequence Parallel iter {i} completed") -rank_log(_rank, logger, "Sequence Parallel training completed!") +rank_log(rank, logger, "Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 877278beb0..601dcd34c4 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -71,21 +71,22 @@ def forward(self, x): logger = get_logger() # create a device mesh based on the given world_size. -_world_size = int(os.environ["WORLD_SIZE"]) +world_size = int(os.environ["WORLD_SIZE"]) -device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) -_rank = device_mesh.get_rank() +device = torch.accelerator.current_accelerator() +device_mesh = init_device_mesh(device_type=device.type, mesh_shape=(world_size,)) +rank = device_mesh.get_rank() -print(f"Starting PyTorch TP example on rank {_rank}.") +print(f"Starting PyTorch TP example on rank {rank}.") assert ( - _world_size % 2 == 0 -), f"TP examples require even number of GPUs, but got {_world_size} gpus" + world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {world_size} gpus" -rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") +rank_log(rank, logger, f"Device Mesh created: {device_mesh=}") -# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. -tp_model = ToyModel().to("cuda") +# create model and move it to GPU - init_device_mesh has already mapped GPU ids. +tp_model = ToyModel().to(rank) # Custom parallelization plan for the model @@ -106,16 +107,16 @@ def forward(self, x): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log(_rank, logger, "Tensor Parallel training starting...") +rank_log(rank, logger, "Tensor Parallel training starting...") for i in range(num_iters): # For TP, input needs to be same across all TP ranks. # Setting the random seed is to mimic the behavior of dataloader. torch.manual_seed(i) - inp = torch.rand(20, 10, device="cuda") + inp = torch.rand(20, 10, device=f'{device}:{rank}') output = tp_model(inp) output.sum().backward() optimizer.step() - rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") + rank_log(rank, logger, f"Tensor Parallel iter {i} completed") -rank_log(_rank, logger, "Tensor Parallel training completed!") +rank_log(rank, logger, "Tensor Parallel training completed!")