Commit

Updated Distributed Training and Added Inference Module (#26)
* Displays Best Model Stats (Epoch Num, Train Loss, Val Loss) for all methods at the end of the program

* Added DistributedSampler for PyTorch DDP

* Made SSIM Regularization Faster

* Added Learning Rate Scheduler to Model Training

* Display Current Learning Rate along with Epoch Number and Losses

* Rounded Off Loss Values to 10 Decimal Places

* Corrected initialization of optimizers in trainer objects

* Fixed Parameters for LR Scheduler

* Experimenting with SGD Optimizer

* Updated Optimizers, LR Schedulers for Each Method

* Adam for AutoEncoder, SGD with Momentum for LSTM (see the setup sketch after the changed-files summary below)

* Added Inference Module
iSiddharth20 authored Jan 11, 2024
1 parent 9dae2cd commit 797933e
Showing 6 changed files with 260 additions and 46 deletions.
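
The optimizer and learning-rate-scheduler changes listed above live in the trainer objects, which are outside this diff. A minimal sketch of the setup the commit message describes (the hyperparameter values and the ReduceLROnPlateau choice are assumptions for illustration, not taken from this commit):

import torch.optim as optim
from autoencoder_model import Grey2RGBAutoEncoder
from lstm_model import ConvLSTM

autoencoder = Grey2RGBAutoEncoder()
lstm = ConvLSTM(input_dim=1, hidden_dims=[1, 1, 1], kernel_size=(3, 3), num_layers=3, alpha=0.6)

# Adam for the AutoEncoder, SGD with Momentum for the LSTM (per the commit message);
# learning rates, momentum, factor, and patience below are placeholder values
optimizer_ae = optim.Adam(autoencoder.parameters(), lr=1e-3)
optimizer_lstm = optim.SGD(lstm.parameters(), lr=1e-2, momentum=0.9)
scheduler_ae = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ae, mode='min', factor=0.1, patience=5)
scheduler_lstm = optim.lr_scheduler.ReduceLROnPlateau(optimizer_lstm, mode='min', factor=0.1, patience=5)

In the training loop, each scheduler would then be stepped with that epoch's validation loss, e.g. scheduler_ae.step(val_loss).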
148 changes: 148 additions & 0 deletions Code/GenerateResults.py
@@ -0,0 +1,148 @@
'''
Generate Results from Trained Models
'''

# Import Necessary Libraries
import platform
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import Process
from PIL import Image
from torchvision import transforms
import glob
import shutil

# Import Model Definitions
from autoencoder_model import Grey2RGBAutoEncoder
from lstm_model import ConvLSTM

# Define Universal Variables
image_width = 1280
image_height = 720

# Define Backend for Distributed Computing
def get_backend():
    system_type = platform.system()
    if system_type == "Linux":
        return "nccl"
    else:
        return "gloo"

# Function to initialize the process group for distributed computing
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend=get_backend(), rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

# Function to clean up the process group after computation
def cleanup():
    dist.destroy_process_group()

# Function to load a trained model and prepare it for distributed inference
def load_model(model, model_path, device):
    # Remap any saved tensors onto this process's GPU
    map_location = lambda storage, loc: storage.cuda(device)
    state_dict = torch.load(model_path, map_location=map_location)
    # Strip the 'module.' prefix that DDP adds to parameter names when saving
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    # Move the model to the device and wrap it with DDP after its state_dict has been loaded
    model = model.to(device)
    model = DDP(model, device_ids=[device])
    return model

# Define the function to save images
def save_images(img_seq, img_dir, global_start_idx):
    to_pil = transforms.ToPILImage()
    for i, image_tensor in enumerate(img_seq):
        global_idx = global_start_idx + i # Calculate the global index
        image = to_pil(image_tensor.cpu())
        image.save(f'{img_dir}/image_{global_idx:04d}.tif')

# Re-number the per-GPU intermediate results into one ordered output sequence
def reorder_and_save_images(img_exp_dir, output_dir):
    image_paths = glob.glob(os.path.join(img_exp_dir, 'image_*.tif'))
    sorted_image_paths = sorted(image_paths, key=lambda x: int(os.path.basename(x).split('_')[1].split('.')[0]))
    for i, img_path in enumerate(sorted_image_paths):
        img = Image.open(img_path)
        img.save(os.path.join(output_dir, f'enhanced_sequence_{i:04d}.tif'))

# Define the Transformation
transform = transforms.Compose([
    transforms.Resize((image_height, image_width)),
    transforms.Grayscale(), # Convert the images to grayscale
    transforms.ToTensor(),
])

# The main function that will be executed by each process
def enhance(rank, world_size, img_inp_dir, img_exp_dir, lstm_path, autoencoder_path):
    setup(rank, world_size)
    # Load the trained models onto this rank's GPU
    lstm_model = ConvLSTM(input_dim=1, hidden_dims=[1, 1, 1], kernel_size=(3, 3), num_layers=3, alpha=0.6)
    lstm = load_model(lstm_model, lstm_path, rank)
    lstm.eval()
    autoencoder_model = Grey2RGBAutoEncoder()
    autoencoder = load_model(autoencoder_model, autoencoder_path, rank)
    autoencoder.eval()
    # Shard the input frames evenly across GPUs; sort first, since os.listdir order is arbitrary
    image_files = sorted(os.listdir(img_inp_dir))
    per_gpu = (len(image_files) + world_size - 1) // world_size
    start_idx = rank * per_gpu
    end_idx = min(start_idx + per_gpu, len(image_files))
    global_start_idx = start_idx
    local_images = [Image.open(os.path.join(img_inp_dir, image_files[i])) for i in range(start_idx, end_idx)]
    local_tensors = torch.stack([transform(image) for image in local_images]).unsqueeze(0).to(rank)
    with torch.no_grad():
        local_output_sequence, _ = lstm(local_tensors)
    local_output_sequence = local_output_sequence.squeeze(0)
    # Interleave the input and output images: input frame, predicted frame, input frame, ...
    interleaved_sequence = torch.stack([t for pair in zip(local_tensors.squeeze(0), local_output_sequence) for t in pair])
    with torch.no_grad():
        local_output_enhanced = torch.stack([autoencoder(t.unsqueeze(0)) for t in interleaved_sequence]).squeeze(1)
    save_images(local_output_enhanced, img_exp_dir, global_start_idx)
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Input Sequence Directory (All Methods)
    img_sequence_inp_dir = r'../Dataset/Inference/InputSequence'
    # Intermediate Results will be Stored in this Directory which will later be re-ordered (All Methods)
    temp_dir = r'../Dataset/Inference/OutputSequence/Temp'
    os.makedirs(temp_dir, exist_ok=True)

    '''Working Directories for (Method-1)'''
    autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth'
    lstm_path = r'../Models/Method1/model_lstm_m1.pth'
    img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method1/'
    os.makedirs(img_sequence_out_dir, exist_ok=True)

    '''Working Directories for (Method-2)'''
    # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth'
    # lstm_path = r'../Models/Method1/model_lstm_m1.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method2/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    '''Working Directories for (Method-3)'''
    # autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth'
    # lstm_path = r'../Models/Method3/model_lstm_m3.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method3/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    '''Working Directories for (Method-4)'''
    # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth'
    # lstm_path = r'../Models/Method3/model_lstm_m3.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method4/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    # Launch one inference process per GPU
    processes = []
    for rank in range(world_size):
        p = Process(target=enhance, args=(rank, world_size, img_sequence_inp_dir, temp_dir, lstm_path, autoencoder_path))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Reorder images once processing by all GPUs is complete
    reorder_and_save_images(temp_dir, img_sequence_out_dir)
    # Delete all Intermediate Results
    shutil.rmtree(temp_dir)
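
A note on the interleaving expression in enhance: zip pairs each input frame with the frame the LSTM predicts to follow it, and the flattened stack alternates the two streams before the autoencoder colorizes every frame. A toy check of the pattern (shapes chosen only for illustration):

import torch

inputs = torch.zeros(3, 1, 4, 4)    # three grayscale input frames
predicted = torch.ones(3, 1, 4, 4)  # three frames predicted by the LSTM
interleaved = torch.stack([t for pair in zip(inputs, predicted) for t in pair])
print(interleaved.shape)        # torch.Size([6, 1, 4, 4])
print(interleaved[:, 0, 0, 0])  # tensor([0., 1., 0., 1., 0., 1.]) alternates input, predicted, input, ...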

10 changes: 4 additions & 6 deletions Code/autoencoder_model.py
@@ -16,22 +16,20 @@ class Grey2RGBAutoEncoder(nn.Module):
     def __init__(self):
         super(Grey2RGBAutoEncoder, self).__init__()
         # Define the Encoder
-        self.encoder = self._make_layers([1, 3, 6, 12, 24])
+        self.encoder = self._make_layers([1, 4, 8, 16, 32])
         # Define the Decoder
-        self.decoder = self._make_layers([24, 12, 6, 3], decoder=True)
+        self.decoder = self._make_layers([32, 16, 8, 4, 3], decoder=True)

     # Helper function to create the encoder or decoder layers.
     def _make_layers(self, channels, decoder=False):
         layers = []
         for i in range(len(channels) - 1):
             if decoder:
                 layers += [nn.ConvTranspose2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
                            nn.BatchNorm2d(channels[i+1]),
-                           nn.LeakyReLU(inplace=True)]
+                           nn.ReLU(inplace=True)]
             else:
                 layers += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
                            nn.BatchNorm2d(channels[i+1]),
-                           nn.LeakyReLU(inplace=True)]
+                           nn.ReLU(inplace=True)]
         if decoder:
             layers[-1] = nn.Sigmoid()
         return nn.Sequential(*layers)
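
A quick sanity check on the widened channels (encoder 1 -> 4 -> 8 -> 16 -> 32, decoder 32 -> 16 -> 8 -> 4 -> 3): every layer uses kernel_size=3, stride=1, padding=1, so spatial dimensions are preserved and only channel depth changes. A minimal sketch using the two submodules shown in this hunk (the class's forward method is outside the hunk, so the submodules are chained directly):

import torch
from autoencoder_model import Grey2RGBAutoEncoder

model = Grey2RGBAutoEncoder()
x = torch.randn(1, 1, 720, 1280)     # one grayscale frame at the project's 1280x720 resolution
y = model.decoder(model.encoder(x))  # encoder: 1 -> 32 channels, decoder: 32 -> 3 channels
print(y.shape)                       # torch.Size([1, 3, 720, 1280])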
14 changes: 9 additions & 5 deletions Code/data.py
@@ -11,7 +11,7 @@
 from torch.utils.data import DataLoader, Dataset, random_split
 import torchvision.transforms as transforms
 import torch
-import os
+from torch.utils.data.distributed import DistributedSampler

 # Allow loading of truncated images
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -66,8 +66,10 @@ def get_autoencoder_batches(self, val_split, batch_size):
         # Split the dataset into training and validation sets
         train_dataset, val_dataset = random_split(self, [train_size, val_size])
         # Create dataloaders for the training and validation sets
-        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
-        val_loader = DataLoader(val_dataset, batch_size, shuffle=True)
+        train_sampler = DistributedSampler(train_dataset)
+        val_sampler = DistributedSampler(val_dataset)
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler)
         # Return the training and validation dataloaders
         return train_loader, val_loader

@@ -92,8 +94,10 @@ def get_lstm_batches(self, val_split, sequence_length, batch_size):
             train_dataset.append((sequence_input_train, sequence_target_train))
             val_dataset.append((sequence_input_val, sequence_target_val))
         # Create the data loaders for training and validation datasets
-        train_loader = DataLoader(train_dataset, batch_size, shuffle=False)
-        val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
+        train_sampler = DistributedSampler(train_dataset)
+        val_sampler = DistributedSampler(val_dataset)
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler)
         return train_loader, val_loader

     def transform_sequence(self, filenames, lstm=False):
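
One caveat with DistributedSampler: the training loop should call set_epoch on each sampler at the start of every epoch, otherwise all epochs replay the same per-rank ordering. A minimal sketch of the loop side (num_epochs and train_loader are assumed names; the trainer code is outside this diff):

for epoch in range(num_epochs):
    train_loader.sampler.set_epoch(epoch)  # re-seed shuffling so each rank sees a new ordering
    for inputs, targets in train_loader:
        ...  # forward pass, loss, backward pass, optimizer step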
15 changes: 6 additions & 9 deletions Code/losses.py
@@ -15,7 +15,7 @@
 Class for Composite Loss with Maximum Entropy Principle Regularization Term
 '''
 class LossMEP(nn.Module):
-    def __init__(self, alpha=0.5):
+    def __init__(self, alpha=0.1):
         super(LossMEP, self).__init__()
         self.alpha = alpha # Weighting factor for total variation loss

@@ -28,9 +28,9 @@ def forward(self, output, target):
                   torch.sum(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:]))
         tv_loss /= batch_size * height * width # Normalize by total size
         # Composite loss
-        loss = mse_loss + self.alpha * tv_loss
+        combined_loss = (1 - self.alpha) * mse_loss + self.alpha * tv_loss
         # Return the composite loss
-        return loss
+        return combined_loss

 '''
 Class for Mean Squared Error (MSE) Loss
@@ -45,24 +45,21 @@ def forward(self, output, target):
 - In PyTorch, loss is minimized; by using 1 - SSIM, minimizing the loss maximizes SSIM
 '''
 class SSIMLoss(nn.Module):
-    def __init__(self, alpha=0.5):
+    def __init__(self, alpha=0.1):
         super(SSIMLoss, self).__init__()
         self.alpha = alpha
         self.ssim_module = SSIM(data_range=1, size_average=True, channel=1)

     def forward(self, seq1, seq2):
         N, T = seq1.shape[:2]
         ssim_values = []
-        mse_values = []
         for i in range(N):
             for t in range(T):
                 seq1_slice = seq1[i, t:t+1, ...]
                 seq2_slice = seq2[i, t:t+1, ...]
                 ssim_val = self.ssim_module(seq1_slice, seq2_slice)
-                mse_val = F.mse_loss(seq1_slice, seq2_slice)
                 ssim_values.append(ssim_val) # Compute SSIM for each frame in the sequence
-                mse_values.append(mse_val) # Compute MSE for each frame in the sequence
         avg_ssim = torch.stack(ssim_values).mean() # Average SSIM across all frames
-        avg_mse = torch.stack(mse_values).mean() # Average MSE across all frames
-        combined_loss = (1 - self.alpha) * avg_mse + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1
+        mse_loss = F.mse_loss(seq1, seq2) # Compute MSE once over the full sequences
+        combined_loss = (1 - self.alpha) * mse_loss + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1
         return combined_loss
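
This hunk is the "Made SSIM Regularization Faster" item from the commit message: the per-frame MSE loop is replaced by a single F.mse_loss call over the full tensors. Under the default 'mean' reduction the two are numerically equivalent, since every frame slice contributes the same number of elements. A quick check with toy shapes (N=2 sequences of T=5 single-channel frames, sizes assumed for illustration):

import torch
import torch.nn.functional as F

seq1, seq2 = torch.rand(2, 5, 1, 8, 8), torch.rand(2, 5, 1, 8, 8)
per_frame = torch.stack([F.mse_loss(seq1[i, t:t+1], seq2[i, t:t+1])
                         for i in range(2) for t in range(5)]).mean()  # old per-frame loop
single_call = F.mse_loss(seq1, seq2)                                   # new single call
print(torch.allclose(per_frame, single_call))  # True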