-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
46 lines (40 loc) · 1.36 KB
/
utils.py
File metadata and controls
46 lines (40 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import random
import numpy as np
import torch.distributed as dist
import os
# ----------------------------------------------------------------------
# DDP utils
# ----------------------------------------------------------------------
def setup_ddp():
    """Set up the torch.distributed process group from torchrun's env vars.

    Returns:
        tuple: ``(device, local_rank)`` — the CUDA device assigned to this
        process and its local rank on the node.
    """
    # Create the process group only once; skip if already initialized.
    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")
    rank_on_node = int(os.environ.get("LOCAL_RANK", 0))
    # Pin this process to its assigned GPU before any CUDA work happens.
    torch.cuda.set_device(rank_on_node)
    return torch.device(f"cuda:{rank_on_node}"), rank_on_node
def is_main_process() -> bool:
    """Return True on rank 0, or whenever torch.distributed is unavailable
    or not yet initialized (i.e. single-process runs)."""
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def barrier():
    """Synchronize all ranks; a no-op outside an initialized process group."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
def set_seed(seed: int = 0) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility.

    Args:
        seed (int): Seed value applied to every RNG. Defaults to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.manual_seed only covers CPU (and current-device) generators;
    # seed all CUDA devices too so DDP GPU workers are reproducible.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def num_to_groups(num: int, divisor: int) -> list:
    """Split ``num`` into chunks of at most ``divisor`` each.

    Args:
        num (int): The total to split.
        divisor (int): The maximum size of each chunk.

    Returns:
        list: Chunk sizes — all equal to ``divisor`` except a possibly
        smaller final chunk holding the leftover.
    """
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_chunks)]
    if leftover > 0:
        sizes.append(leftover)
    return sizes