Skip to content

Commit

Permalink
Added Platform Check for init_process_group backend parameter initial…
Browse files Browse the repository at this point in the history
…ization
  • Loading branch information
iSiddharth20 committed Jan 2, 2024
1 parent 70742a1 commit 290f3a7
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions Code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import platform

# Define Working Directories
grayscale_dir = '../Dataset/Greyscale'
rgb_dir = '../Dataset/RGB'

# Define Universal Parameters
image_height = 3000
image_width = 4500
image_height = 4000
image_width = 6000
batch_size = 2

def get_backend():
system_type = platform.system()
if system_type == "Linux":
return "nccl"
else:
return "gloo"

def main_worker(rank, world_size):
# Set environment variables
os.environ['MASTER_ADDR'] = 'localhost'
Expand All @@ -34,7 +42,7 @@ def main_worker(rank, world_size):
torch.manual_seed(0)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
dist.init_process_group(backend=get_backend(), init_method="env://", world_size=world_size, rank=rank)
main(rank) # Call the existing main function.

def main(rank):
Expand Down

0 comments on commit 290f3a7

Please sign in to comment.