From 290f3a7e93a5bbc50a77fc53f85b255cbdc01a92 Mon Sep 17 00:00:00 2001 From: iSiddharth20 Date: Mon, 1 Jan 2024 18:15:07 -0800 Subject: [PATCH] Added Platform Check for init_process_group backend parameter initialization --- Code/main.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/Code/main.py b/Code/main.py index e09b6aa..883d85b 100644 --- a/Code/main.py +++ b/Code/main.py @@ -16,16 +16,24 @@ import torch import torch.multiprocessing as mp import torch.distributed as dist +import platform # Define Working Directories grayscale_dir = '../Dataset/Greyscale' rgb_dir = '../Dataset/RGB' # Define Universal Parameters -image_height = 3000 -image_width = 4500 +image_height = 4000 +image_width = 6000 batch_size = 2 +def get_backend(): + system_type = platform.system() + if system_type == "Linux": + return "nccl" + else: + return "gloo" + def main_worker(rank, world_size): # Set environment variables os.environ['MASTER_ADDR'] = 'localhost' @@ -34,7 +42,7 @@ def main_worker(rank, world_size): torch.manual_seed(0) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True - dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank) + dist.init_process_group(backend=get_backend(), init_method="env://", world_size=world_size, rank=rank) main(rank) # Call the existing main function. def main(rank):