diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index a44a58ba0f..fb0d5ba1f5 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -173,3 +173,6 @@ rank_log(_rank, logger, f"2D iter {i} complete") rank_log(_rank, logger, "2D training successfully completed!") + +if dist.is_initialized(): + dist.destroy_process_group() diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 988973af4b..73320f5bcc 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( @@ -107,3 +108,6 @@ def forward(self, x): rank_log(_rank, logger, f"Sequence Parallel iter {i} completed") rank_log(_rank, logger, "Sequence Parallel training completed!") + +if dist.is_initialized(): + dist.destroy_process_group() diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index c42a952ea8..6a4b4ea531 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -122,3 +122,6 @@ def forward(self, x): rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") rank_log(_rank, logger, "Tensor Parallel training completed!") + +if dist.is_initialized(): + dist.destroy_process_group()