Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions distributed/tensor_parallelism/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,25 @@ PyTorch native Tensor Parallel APIs, which include:
For more details about the PyTorch native Tensor Parallel APIs, please see the PyTorch docs:
https://pytorch.org/docs/stable/distributed.tensor.parallel.html

## Installation

```bash
pip install -r requirements.txt
python example.py
```

## Running Examples

You can run the examples using `torchrun` to launch distributed training:

```bash
# Simple Tensor Parallel example
torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_example.py

# Tensor Parallel with Sequence Parallel
torchrun --nnodes=1 --nproc_per_node=4 sequence_parallel_example.py

# FSDP + Tensor Parallel with Llama2 model
torchrun --nnodes=1 --nproc_per_node=4 fsdp_tp_example.py
```

For more details, check the `run_examples.sh` script.
92 changes: 46 additions & 46 deletions distributed/tensor_parallelism/fsdp_tp_example.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,3 @@
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
PrepareModuleInput,
SequenceParallel
)


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example
Expand Down Expand Up @@ -60,29 +29,60 @@
https://pytorch.org/tutorials/intermediate/TP_tutorial.html
"""

import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
PrepareModuleInput,
SequenceParallel
)

tp_size = 2
logger = get_logger()

# understand world topology
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])


print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
print(f"Starting PyTorch 2D (FSDP + TP) example on rank {rank}.")
assert (
_world_size % tp_size == 0
), f"World size {_world_size} needs to be divisible by TP size {tp_size}"
world_size % tp_size == 0
), f"World size {world_size} needs to be divisible by TP size {tp_size}"


# create a sharding plan based on the given world_size.
dp_size = _world_size // tp_size
dp_size = world_size // tp_size

# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension
# Second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
acc = torch.accelerator.current_accelerator()
device_mesh = init_device_mesh(acc.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
rank_log(rank, logger, f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]

Expand All @@ -95,7 +95,7 @@
# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)

model = Transformer.from_model_args(simple_llama2_config).to("cuda")
model = Transformer.from_model_args(simple_llama2_config).to(rank)

# init model weights
model.init_weights()
Expand Down Expand Up @@ -153,28 +153,28 @@
# Init FSDP using the dp device mesh
sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)

rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")
rank_log(rank, logger, f"Model after parallelization {sharded_model=}\n")

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
rank_log(rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

# Training loop:
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
rank_log(_rank, logger, "\nStarting 2D training...")
rank_log(rank, logger, "\nStarting 2D training...")
num_iterations = 10
batch_size = 2

for i in range(num_iterations):
# seeding with dp_rank to ensure identical inputs for TP groups
torch.manual_seed(i + dp_rank)
inp = torch.randint(32000, (8, 256), device="cuda")
inp = torch.randint(32000, (8, 256), device=acc.type)

output = sharded_model(inp)
output.sum().backward()
optimizer.step()
rank_log(_rank, logger, f"2D iter {i} complete")
rank_log(rank, logger, f"2D iter {i} complete")

rank_log(_rank, logger, "2D training successfully completed!")
rank_log(rank, logger, "2D training successfully completed!")
6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def rank_log(_rank, logger, msg):

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """Check that enough accelerator devices are visible to run a distributed example.

    Args:
        min_gpus: Minimum number of devices required. Defaults to 2.

    Returns:
        True only when an accelerator backend is available and it reports at
        least ``min_gpus`` devices; False otherwise.
    """
    # Prefer the device-agnostic torch.accelerator namespace so the check is not
    # CUDA-specific; fall back to torch.cuda on builds that do not expose
    # torch.accelerator. Both namespaces are called through the same two functions.
    backend = getattr(torch, "accelerator", torch.cuda)
    # Short-circuit: device_count() is only queried when a backend is available.
    return backend.is_available() and backend.device_count() >= min_gpus
7 changes: 1 addition & 6 deletions distributed/tensor_parallelism/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
# Python dependencies required for running the example

--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cu118
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch >= 2.3.0.dev0; sys_platform == "linux"
torch>=2.8
60 changes: 29 additions & 31 deletions distributed/tensor_parallelism/sequence_parallel_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
This is the script to test Sequence Parallel (SP) on a toy model in a
Megatron-LM SPMD style. We show an end-to-end workflow covering forward,
backward and optimization.

We use the example of two `nn.Linear` layers with an element-wise `nn.ReLU`
in between to show an example of sequence parallel, which was proposed in the paper:

https://arxiv.org/pdf/2205.05198.pdf.

Like tensor parallel, we parallelize the first linear layer by column
and the second linear layer by row. But the input on each rank
is now different, so we need one all-gather for the input and one
reduce-scatter at the end of the second linear layer.
"""

import os
import sys
import torch
Expand All @@ -22,28 +38,8 @@
sys.exit()
# ---------------------------


from torch.distributed._tensor.device_mesh import init_device_mesh



"""
This is the script to test Sequence Parallel (SP) on a toy model in a
Megatron-LM SPMD style. We show an end-to-end workflow covering forward,
backward and optimization.

We use the example of two `nn.Linear` layers with an element-wise `nn.ReLU`
in between to show an example of sequence parallel, which was proposed in the paper:

https://arxiv.org/pdf/2205.05198.pdf.

Like tensor parallel, we parallelize the first linear layer by column
and also parallelize the second linear layer by row. But the input in each rank
now is different so that we need one all-gather for input and one reduce-scatter
in the end of the second linear layer.
"""


class ToyModel(nn.Module):
"""MLP based model"""

Expand All @@ -64,18 +60,20 @@ def forward(self, x):
logger = get_logger()

# create a device mesh based on the given world_size.
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)
world_size = int(os.environ["WORLD_SIZE"])

model = ToyModel()

_rank = device_mesh.get_rank()
device = torch.accelerator.current_accelerator()
device_mesh = init_device_mesh(device_type=device.type, mesh_shape=(world_size,))
rank = device_mesh.get_rank()

print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")
print(f"Starting PyTorch Sequence Parallel example on rank {rank}.")

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
rank_log(rank, logger, f"Device Mesh created: {device_mesh=}")

# create model and move it to GPU. init_device_mesh has already assigned gpu ids...
model = ToyModel().to("cuda")
model = ToyModel().to(rank)

# Custom parallelization plan for the model
sp_model = parallelize_module(
Expand All @@ -96,14 +94,14 @@ def forward(self, x):
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
num_iters = 10
rank_log(_rank, logger, "Sequence Parallel training starting...")
rank_log(rank, logger, "Sequence Parallel training starting...")

for i in range(num_iters):
# For SP, input can be different across all ranks.
inp = torch.rand(20, 10, device="cuda")
inp = torch.rand(20, 10, device=f'{device}:{rank}')
output = sp_model(inp)
output.sum().backward()
optimizer.step()
rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
rank_log(rank, logger, f"Sequence Parallel iter {i} completed")

rank_log(_rank, logger, "Sequence Parallel training completed!")
rank_log(rank, logger, "Sequence Parallel training completed!")
Loading
Loading