Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions distributed/rpc/ddp_rpc/README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
Distributed DataParallel + Distributed RPC Framework Example

The example shows how to combine Distributed DataParallel with the Distributed
RPC Framework. There are two trainer nodes, 1 master node and 1 parameter
server in the example.
This example demonstrates how to combine Distributed DataParallel (DDP) with the Distributed RPC Framework. It requires two trainer nodes (each with a GPU), one master node, and one parameter server.

The master node creates an embedding table on the parameter server and drives
the training loop on the trainers. The model consists of a dense part
(nn.Linear) replicated on the trainers via Distributed DataParallel and a
sparse part (nn.EmbeddingBag) which resides on the parameter server. Each
trainer performs an embedding lookup on the parameter server (using the
Distributed RPC Framework) and then executes its local nn.Linear module.
During the backward pass, the gradients for the dense part are aggregated via
allreduce by DDP and the distributed backward pass updates the parameters for
the embedding table on the parameter server.
The master node initializes an embedding table on the parameter server and orchestrates the training loop across the trainers. The model is composed of a dense component (`nn.Linear`), which is replicated on the trainers using DDP, and a sparse component (`nn.EmbeddingBag`), which resides on the parameter server.

Each trainer performs embedding lookups on the parameter server via RPC, then processes the results through its local `nn.Linear` module. During the backward pass, DDP aggregates gradients for the dense part using allreduce, while the distributed backward pass updates the embedding table parameters on the parameter server.


```
Expand Down
30 changes: 24 additions & 6 deletions distributed/rpc/ddp_rpc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
""" verification that we have at least 2 gpus to run dist examples """
has_gpu = torch.accelerator.is_available()
gpu_count = torch.accelerator.device_count()
return has_gpu and gpu_count >= min_gpus

class HybridModel(torch.nn.Module):
r"""
Expand All @@ -24,15 +29,15 @@ class HybridModel(torch.nn.Module):
This remote model can get a Remote Reference to the embedding table on the parameter server.
"""

def __init__(self, remote_emb_module, device):
def __init__(self, remote_emb_module, rank):
super(HybridModel, self).__init__()
self.remote_emb_module = remote_emb_module
self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
self.device = device
self.fc = DDP(torch.nn.Linear(16, 8).to(rank))
self.rank = rank

def forward(self, indices, offsets):
emb_lookup = self.remote_emb_module.forward(indices, offsets)
return self.fc(emb_lookup.cuda(self.device))
return self.fc(emb_lookup.to(self.rank))


def _run_trainer(remote_emb_module, rank):
Expand Down Expand Up @@ -83,7 +88,7 @@ def get_next_batch(rank):
batch_size += 1

offsets_tensor = torch.LongTensor(offsets)
target = torch.LongTensor(batch_size).random_(8).cuda(rank)
target = torch.LongTensor(batch_size).random_(8).to(rank)
yield indices, offsets_tensor, target

# Train for 100 epochs
Expand Down Expand Up @@ -145,9 +150,13 @@ def run_worker(rank, world_size):
for fut in futs:
fut.wait()
elif rank <= 1:
# Resolve the accelerator type and the matching collective backend for DDP.
acc = torch.accelerator.current_accelerator()
device = torch.device(acc)
backend = torch.distributed.get_default_backend_for_device(device)
# Bind this trainer process to its own accelerator. `device_index` is a
# context manager and has no effect when called as a bare statement, so the
# persistent setter is used instead.
torch.accelerator.set_device_index(rank)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does setting torch.accelerator.device_index(rank) work for "cpu"? Why isn't it guarded by torch.accelerator.is_available()?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the CPU execution option, since DDP requires 2 GPUs for this example.

# Initialize process group for Distributed DataParallel on trainers.
dist.init_process_group(
backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
)

# Initialize RPC.
Expand All @@ -172,9 +181,18 @@ def run_worker(rank, world_size):

# block until all rpcs finish
rpc.shutdown()

# Clean up process group for trainers to avoid resource leaks
if rank <= 1:
dist.destroy_process_group()


if __name__ == "__main__":
    # Imported here because this entry point is the only user of sys.
    import sys

    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    _min_gpu_count = 2
    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
        # sys.exit() replaces the bare exit(), which is an interactive helper
        # injected by the site module. The exit status stays 0, so an
        # under-provisioned host skips the example rather than failing.
        sys.exit()
    # Launch one process per role; run_worker decides the role from its rank.
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
    print("Distributed RPC example completed successfully.")
3 changes: 2 additions & 1 deletion distributed/rpc/ddp_rpc/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
torch>=1.6.0
torch>=2.7.0
numpy
2 changes: 2 additions & 0 deletions distributed/rpc/parameter_server/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch>=2.7.1
numpy
36 changes: 21 additions & 15 deletions distributed/rpc/parameter_server/rpc_parameter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ def __init__(self, num_gpus=0):
super(Net, self).__init__()
print(f"Using {num_gpus} GPUs to train")
self.num_gpus = num_gpus
device = torch.device(
"cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
if torch.accelerator.is_available() and self.num_gpus > 0:
acc = torch.accelerator.current_accelerator()
device = torch.device(f'{acc}:0')
else:
device = torch.device("cpu")
print(f"Putting first 2 convs on {str(device)}")
# Put conv layers on the first cuda device
# Put conv layers on the first accelerator device
self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
# Put rest of the network on the 2nd cuda device, if there is one
if "cuda" in str(device) and num_gpus > 1:
device = torch.device("cuda:1")
# Put the rest of the network on the 2nd accelerator device, but only when a
# second device was actually requested; with fewer than two, the layers stay
# on the device selected above.
if torch.accelerator.is_available() and self.num_gpus > 1:
    acc = torch.accelerator.current_accelerator()
    device = torch.device(f'{acc}:1')

print(f"Putting rest of layers on {str(device)}")
self.dropout1 = nn.Dropout2d(0.25).to(device)
Expand Down Expand Up @@ -72,21 +76,22 @@ def call_method(method, rref, *args, **kwargs):
# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
# back.


def remote_method(method, rref, *args, **kwargs):
    """Synchronously run ``call_method`` on the worker that owns ``rref``.

    ``method`` and ``rref`` are placed ahead of the caller's positional
    arguments; keyword arguments are forwarded unchanged. Returns whatever
    ``rpc.rpc_sync`` returns for that call.
    """
    owner = rref.owner()
    call_args = [method, rref, *args]
    return rpc.rpc_sync(owner, call_method, args=call_args, kwargs=kwargs)


# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
def __init__(self, num_gpus=0):
super().__init__()
model = Net(num_gpus=num_gpus)
self.model = model
self.input_device = torch.device(
"cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

if torch.accelerator.is_available() and num_gpus > 0:
acc = torch.accelerator.current_accelerator()
self.input_device = torch.device(f'{acc}:0')
else:
self.input_device = torch.device("cpu")

def forward(self, inp):
inp = inp.to(self.input_device)
out = self.model(inp)
Expand All @@ -113,11 +118,9 @@ def get_param_rrefs(self):
param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
return param_rrefs


param_server = None
global_lock = Lock()


def get_parameter_server(num_gpus=0):
global param_server
# Ensure that we get only one handle to the ParameterServer.
Expand Down Expand Up @@ -197,8 +200,11 @@ def get_accuracy(test_loader, model):
model.eval()
correct_sum = 0
# Use GPU to evaluate if possible
device = torch.device("cuda:0" if model.num_gpus > 0
and torch.cuda.is_available() else "cpu")
if torch.accelerator.is_available() and model.num_gpus > 0:
acc = torch.accelerator.current_accelerator()
device = torch.device(f'{acc}:0')
else:
device = torch.device("cpu")
with torch.no_grad():
for i, (data, target) in enumerate(test_loader):
out = model(data)
Expand Down
3 changes: 2 additions & 1 deletion distributed/rpc/rnn/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
torch
torch>=2.7.1
numpy
10 changes: 6 additions & 4 deletions distributed/rpc/rnn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ def __init__(self, ntoken, ninp, dropout):
super(EmbeddingTable, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
if torch.accelerator.is_available():
device = torch.accelerator.current_accelerator()
self.encoder = self.encoder.to(device)
nn.init.uniform_(self.encoder.weight, -0.1, 0.1)

def forward(self, input):
if torch.cuda.is_available():
input = input.cuda()
if torch.accelerator.is_available():
device = torch.accelerator.current_accelerator()
input = input.to(device)
return self.drop(self.encoder(input)).cpu()


Expand Down
10 changes: 10 additions & 0 deletions run_distributed_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ function distributed_minGPT-ddp() {
uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
}

# Run the DDP + RPC example's main.py through uv; report via `error` on failure.
function distributed_rpc_ddp_rpc() {
  if ! uv run main.py; then
    error "ddp_rpc example failed"
  fi
}

# Run the RPC RNN example's main.py through uv; report via `error` on failure.
function distributed_rpc_rnn() {
  if ! uv run main.py; then
    error "rpc_rnn example failed"
  fi
}

# Invoke `run` once per distributed example directory, in a fixed order.
function run_all() {
  local example_dir
  for example_dir in \
    distributed/tensor_parallelism \
    distributed/ddp \
    distributed/minGPT-ddp \
    distributed/rpc/ddp_rpc \
    distributed/rpc/rnn
  do
    run "$example_dir"
  done
}

# by default, run all examples
Expand Down
2 changes: 1 addition & 1 deletion utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function run() {
if start $EXAMPLE; then
# drop trailing slash (occurs due to auto completion in bash interactive mode)
# replace slashes with underscores: this allows to call nested examples
EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@')
EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@g')
$EXAMPLE_FN
fi
stop $EXAMPLE
Expand Down