diff --git a/Code/GenerateResults.py b/Code/GenerateResults.py new file mode 100644 index 0000000..aaf0e23 --- /dev/null +++ b/Code/GenerateResults.py @@ -0,0 +1,148 @@ +''' +Generate Results from Trained Models +''' + +# Import Necessary Libraries +import platform +import os +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.multiprocessing import Process +from PIL import Image +from torchvision import transforms +import glob +import shutil + +# Import Model Definations +from autoencoder_model import Grey2RGBAutoEncoder +from lstm_model import ConvLSTM + +# Define Universal Variables +image_width = 1280 +image_height = 720 + +# Define Backend for Distributed Computing +def get_backend(): + system_type = platform.system() + if system_type == "Linux": + return "nccl" + else: + return "gloo" + +# Function to initialize the process group for distributed computing +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group(backend=get_backend(), rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +# Function to clean up the process group after computation +def cleanup(): + dist.destroy_process_group() + +# The function to load your models +def load_model(model, model_path, device): + map_location = lambda storage, loc: storage.cuda(device) + state_dict = torch.load(model_path, map_location=map_location) + new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + model.load_state_dict(new_state_dict) + # Move the model to the device and wrap the model with DDP after its state_dict has been loaded + model = model.to(device) + model = DDP(model, device_ids=[device]) + return model + +# Define the function to save images +def save_images(img_seq, img_dir, global_start_idx): + to_pil = transforms.ToPILImage() + for i, image_tensor in enumerate(img_seq): + global_idx = global_start_idx + i # Calculate the global index + image = to_pil(image_tensor.cpu()) + image.save(f'{img_dir}/image_{global_idx:04d}.tif') + +def reorder_and_save_images(img_exp_dir, output_dir): + image_paths = glob.glob(os.path.join(img_exp_dir, 'image_*.tif')) + sorted_image_paths = sorted(image_paths, key=lambda x: int(os.path.basename(x).split('_')[1].split('.')[0])) + for i, img_path in enumerate(sorted_image_paths): + img = Image.open(img_path) + img.save(os.path.join(output_dir, f'enhanced_sequence_{i:04d}.tif')) + +# Define the Transformation +transform = transforms.Compose([ + transforms.Resize((image_height, image_width)), + transforms.Grayscale(), # Convert the images to grayscale + transforms.ToTensor(), +]) + +# The main function that will be executed by each process +def enhance(rank, world_size, img_inp_dir, img_exp_dir, lstm_path, autoencoder_path): + setup(rank, world_size) + lstm_model = ConvLSTM(input_dim=1, hidden_dims=[1, 1, 1], kernel_size=(3, 3), num_layers=3, alpha=0.6) + lstm = load_model(lstm_model, lstm_path, rank) + lstm.eval() + autoencoder_model = Grey2RGBAutoEncoder() + autoencoder = load_model(autoencoder_model, autoencoder_path, rank) + autoencoder.eval() + image_files = os.listdir(img_inp_dir) + per_gpu = (len(image_files) + world_size - 1) // world_size + start_idx = rank * per_gpu + end_idx = min(start_idx + per_gpu, len(image_files)) + global_start_idx = start_idx + local_images = [Image.open(os.path.join(img_inp_dir, image_files[i])) for i in range(start_idx, end_idx)] + local_tensors = torch.stack([transform(image) for image in local_images]).unsqueeze(0).to(rank) + with torch.no_grad(): + local_output_sequence, _ = lstm(local_tensors) + local_output_sequence = local_output_sequence.squeeze(0) + # Interleave the input and output images + interleaved_sequence = torch.stack([t for pair in zip(local_tensors.squeeze(0), local_output_sequence) for t in pair]) + with torch.no_grad(): + local_output_enhanced = torch.stack([autoencoder(t.unsqueeze(0)) for t in interleaved_sequence]).squeeze(1) + save_images(local_output_enhanced, img_exp_dir, global_start_idx) + cleanup() + + +if __name__ == "__main__": + world_size = torch.cuda.device_count() + # Input Sequence Directory (All Methods) + img_sequence_inp_dir = r'../Dataset/Inference/InputSequence' + # Intermediate Results will be Stored in this Directory which later wll be re-ordered (All Methods) + temp_dir = r'../Dataset/Inference/OutputSequence/Temp' + os.makedirs(temp_dir, exist_ok=True) + + '''Working Directories for (Method-1)''' + autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth' + lstm_path = r'../Models/Method1/model_lstm_m1.pth' + img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method1/' + os.makedirs(img_sequence_out_dir, exist_ok=True) + + '''Working Directories for (Method-2)''' + # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth' + # lstm_path = r'../Models/Method1/model_lstm_m1.pth' + # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method2/' + # os.makedirs(img_sequence_out_dir, exist_ok=True) + + '''Working Directories for (Method-3)''' + # autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth' + # lstm_path = r'../Models/Method3/model_lstm_m3.pth' + # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method3/' + # os.makedirs(img_sequence_out_dir, exist_ok=True) + + '''Working Directories for (Method-4)''' + # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth' + # lstm_path = r'../Models/Method3/model_lstm_m3.pth' + # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method4/' + # os.makedirs(img_sequence_out_dir, exist_ok=True) + + processes = [] + for rank in range(world_size): + p = Process(target=enhance, args=(rank, world_size, img_sequence_inp_dir, temp_dir, lstm_path, autoencoder_path)) + p.start() + processes.append(p) + for p in processes: + p.join() + + # Reorder images once processing by all GPUs is complete + reorder_and_save_images(temp_dir, img_sequence_out_dir) + # Delete all Intermediate Results + shutil.rmtree(temp_dir) + diff --git a/Code/autoencoder_model.py b/Code/autoencoder_model.py index 0164c4e..e8a9245 100644 --- a/Code/autoencoder_model.py +++ b/Code/autoencoder_model.py @@ -16,9 +16,9 @@ class Grey2RGBAutoEncoder(nn.Module): def __init__(self): super(Grey2RGBAutoEncoder, self).__init__() # Define the Encoder - self.encoder = self._make_layers([1, 3, 6, 12, 24]) + self.encoder = self._make_layers([1, 4, 8, 16, 32]) # Define the Decoder - self.decoder = self._make_layers([24, 12, 6, 3], decoder=True) + self.decoder = self._make_layers([32, 16, 8, 4, 3], decoder=True) # Helper function to create the encoder or decoder layers. def _make_layers(self, channels, decoder=False): @@ -26,12 +26,10 @@ def _make_layers(self, channels, decoder=False): for i in range(len(channels) - 1): if decoder: layers += [nn.ConvTranspose2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(channels[i+1]), - nn.LeakyReLU(inplace=True)] + nn.ReLU(inplace=True)] else: layers += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(channels[i+1]), - nn.LeakyReLU(inplace=True)] + nn.ReLU(inplace=True)] if decoder: layers[-1] = nn.Sigmoid() return nn.Sequential(*layers) diff --git a/Code/data.py b/Code/data.py index 3de0291..4548244 100644 --- a/Code/data.py +++ b/Code/data.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader, Dataset, random_split import torchvision.transforms as transforms import torch -import os +from torch.utils.data.distributed import DistributedSampler # Allow loading of truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -66,8 +66,10 @@ def get_autoencoder_batches(self, val_split, batch_size): # Split the dataset into training and validation sets train_dataset, val_dataset = random_split(self, [train_size, val_size]) # Create dataloaders for the training and validation sets - train_loader = DataLoader(train_dataset, batch_size, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size, shuffle=True) + train_sampler = DistributedSampler(train_dataset) + val_sampler = DistributedSampler(val_dataset) + train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler) + val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler) # Return the training and validation dataloaders return train_loader, val_loader @@ -92,8 +94,10 @@ def get_lstm_batches(self, val_split, sequence_length, batch_size): train_dataset.append((sequence_input_train, sequence_target_train)) val_dataset.append((sequence_input_val, sequence_target_val)) # Create the data loaders for training and validation datasets - train_loader = DataLoader(train_dataset, batch_size, shuffle=False) - val_loader = DataLoader(val_dataset, batch_size, shuffle=False) + train_sampler = DistributedSampler(train_dataset) + val_sampler = DistributedSampler(val_dataset) + train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler) + val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler) return train_loader, val_loader def transform_sequence(self, filenames, lstm=False): diff --git a/Code/losses.py b/Code/losses.py index 2839305..f3b1ac0 100644 --- a/Code/losses.py +++ b/Code/losses.py @@ -15,7 +15,7 @@ Class for Composite Loss with Maximum Entropy Principle Regularization Term ''' class LossMEP(nn.Module): - def __init__(self, alpha=0.5): + def __init__(self, alpha=0.1): super(LossMEP, self).__init__() self.alpha = alpha # Weighting factor for total variation loss @@ -28,9 +28,9 @@ def forward(self, output, target): torch.sum(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) tv_loss /= batch_size * height * width # Normalize by total size # Composite loss - loss = mse_loss + self.alpha * tv_loss + combined_loss = (1 - self.alpha) * mse_loss + self.alpha * tv_loss # Return the composite loss - return loss + return combined_loss ''' Class for Mean Squared Error (MSE) Loss @@ -45,7 +45,7 @@ def forward(self, output, target): - In PyTorch, loss is minimized, by doing 1 - SSIM, minimizing the loss function will lead to maximization of SSIM ''' class SSIMLoss(nn.Module): - def __init__(self, alpha=0.5): + def __init__(self, alpha=0.1): super(SSIMLoss, self).__init__() self.alpha = alpha self.ssim_module = SSIM(data_range=1, size_average=True, channel=1) @@ -53,16 +53,13 @@ def __init__(self, alpha=0.5): def forward(self, seq1, seq2): N, T = seq1.shape[:2] ssim_values = [] - mse_values = [] for i in range(N): for t in range(T): seq1_slice = seq1[i, t:t+1, ...] seq2_slice = seq2[i, t:t+1, ...] ssim_val = self.ssim_module(seq1_slice, seq2_slice) - mse_val = F.mse_loss(seq1_slice, seq2_slice) ssim_values.append(ssim_val) # Compute SSIM for each frame in the sequence - mse_values.append(mse_val) # Compute MSE for each frame in the sequence avg_ssim = torch.stack(ssim_values).mean() # Average SSIM across all frames - avg_mse = torch.stack(mse_values).mean() # Average MSE across all frames - combined_loss = (1 - self.alpha) * avg_mse + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1 + mse_loss = F.mse_loss(seq1, seq2) + combined_loss = (1 - self.alpha) * mse_loss + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1 return combined_loss diff --git a/Code/main.py b/Code/main.py index 0e21014..cee2ed2 100644 --- a/Code/main.py +++ b/Code/main.py @@ -17,6 +17,7 @@ import torch.multiprocessing as mp import torch.distributed as dist import platform +import time # Define Working Directories autoencoder_grayscale_dir = '../Dataset/AutoEncoder/Grayscale' @@ -59,8 +60,8 @@ def main(rank): # Import Loss Functions try: loss_mse = LossMSE() # Mean Squared Error Loss - loss_mep = LossMEP(alpha=0.2) # Maximum Entropy Loss - loss_ssim = SSIMLoss(alpha=0.2) # Structural Similarity Index Measure Loss + loss_mep = LossMEP(alpha=0.1) # Maximum Entropy Loss + loss_ssim = SSIMLoss(alpha=0.1) # Structural Similarity Index Measure Loss if rank == 0: print('Importing Loss Functions Complete.') except Exception as e: @@ -70,7 +71,7 @@ def main(rank): print('-'*20) # Makes Output Readable # Initialize AutoEncoder Model and Import Dataloader (Training, Validation) - data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.25, batch_size=16) + data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.25, batch_size=32) if rank == 0: print('AutoEncoder Model Data Imported.') model_autoencoder = Grey2RGBAutoEncoder() @@ -79,7 +80,7 @@ def main(rank): print('-'*20) # Makes Output Readable # Initialize LSTM Model and Import Dataloader (Training, Validation) - data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.2, sequence_length=30, batch_size=6) + data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.2, sequence_length=30, batch_size=12) if rank == 0: print('LSTM Model Data Imported.') model_lstm = ConvLSTM(input_dim=1, hidden_dims=[1,1,1], kernel_size=(3, 3), num_layers=3, alpha=0.5) @@ -93,17 +94,23 @@ def main(rank): # Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM os.makedirs('../Models/Method1', exist_ok=True) # Creating Directory for Model Saving model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth' + optimizer = torch.optim.Adam(model_autoencoder.parameters(), lr=0.01) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.4) trainer_autoencoder_baseline = Trainer(model=model_autoencoder, loss_function=loss_mse, - optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), + optimizer=optimizer, + lr_scheduler=lr_scheduler, model_save_path=model_save_path_ae, rank=rank) if rank == 0: print('Method-1 AutoEncoder Trainer Initialized.') model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth' + optimizer = torch.optim.SGD(model_lstm.parameters(), lr=0.01, momentum=0.9) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.8) trainer_lstm_baseline = Trainer(model=model_lstm, loss_function=loss_mse, - optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), + optimizer=optimizer, + lr_scheduler=lr_scheduler, model_save_path=model_save_path_lstm, rank=rank) if rank == 0: @@ -113,9 +120,12 @@ def main(rank): # Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM os.makedirs('../Models/Method2', exist_ok=True) # Creating Directory for Model Saving model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth' + optimizer = torch.optim.Adam(model_autoencoder.parameters(), lr=0.01) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.4) trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, - optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), + optimizer=optimizer, + lr_scheduler=lr_scheduler, model_save_path=model_save_path_ae, rank=rank) if rank == 0: @@ -128,9 +138,12 @@ def main(rank): if rank == 0: print('Method-3 AutoEncoder == Method-1 AutoEncoder') model_save_path_lstm = '../Models/Method3/model_lstm_m3.pth' + optimizer = torch.optim.SGD(model_lstm.parameters(), lr=0.01, momentum=0.9) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.8) trainer_lstm_m3 = Trainer(model=model_lstm, loss_function=loss_ssim, - optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), + optimizer=optimizer, + lr_scheduler=lr_scheduler, model_save_path=model_save_path_lstm, rank=rank) if rank == 0: @@ -149,10 +162,11 @@ def main(rank): ''' # Method-1 try: - epochs = 5 + epochs = 100 if rank == 0: print('Method-1 AutoEncoder Training Start') - model_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val) + start_time = time.time() + model_autoencoder_m1, stats_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val) if rank == 0: print('Method-1 AutoEncoder Training Complete.') except Exception as e: @@ -161,14 +175,17 @@ def main(rank): traceback.print_exc() finally: if rank == 0: + end_time = time.time() + print(f"Execution time: {end_time - start_time} seconds") trainer_autoencoder_baseline.cleanup_ddp() if rank == 0: print('-'*10) # Makes Output Readable try: - epochs = 5 + epochs = 100 if rank == 0: print('Method-1 LSTM Training Start') - model_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val) + start_time = time.time() + model_lstm_m1, stats_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val) if rank == 0: print('Method-1 LSTM Training Complete.') except Exception as e: @@ -177,16 +194,19 @@ def main(rank): traceback.print_exc() finally: if rank == 0: + end_time = time.time() + print(f"Execution time: {end_time - start_time} seconds") trainer_lstm_baseline.cleanup_ddp() if rank == 0: print('-'*20) # Makes Output Readable # Method-2 try: - epochs = 5 + epochs = 100 if rank == 0: print('Method-2 AutoEncoder Training Start') - model_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val) + start_time = time.time() + model_autoencoder_m2, stats_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val) if rank == 0: print('Method-2 AutoEncoder Training Complete.') except Exception as e: @@ -194,7 +214,10 @@ def main(rank): print(f"Method-2 AutoEncoder Training Error : \n{e}") traceback.print_exc() finally: - trainer_autoencoder_m2.cleanup_ddp() + if rank == 0: + end_time = time.time() + print(f"Execution time: {end_time - start_time} seconds") + trainer_autoencoder_m2.cleanup_ddp() if rank == 0: print('-'*10) # Makes Output Readable print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.") @@ -205,10 +228,11 @@ def main(rank): print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.") print('-'*10) # Makes Output Readable try: - epochs = 5 + epochs = 100 if rank == 0: print('Method-3 LSTM Training Start.') - model_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val) + start_time = time.time() + model_lstm_m3, stats_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val) if rank == 0: print('Method-3 LSTM Training Complete.') except Exception as e: @@ -216,7 +240,10 @@ def main(rank): print(f"Method-3 LSTM Training Error : \n{e}") traceback.print_exc() finally: - trainer_lstm_m3.cleanup_ddp() + if rank == 0: + end_time = time.time() + print(f"Execution time: {end_time - start_time} seconds") + trainer_lstm_m3.cleanup_ddp() if rank == 0: print('-'*20) # Makes Output Readable @@ -227,6 +254,33 @@ def main(rank): print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.") print('-'*20) # Makes Output Readable + # Print Stats of Each Model + if rank == 0: + print('Best Stats for Method-1 AutoEncoder :') + epoch_num, train_loss, val_loss = stats_autoencoder_m1 + print(f'\tEpoch: {epoch_num} --- Training Loss: {train_loss} --- Validation Loss: {val_loss}') + print('-'*10) # Makes Output Readable + print('Best Stats for Method-1 LSTM :') + epoch_num, train_loss, val_loss = stats_lstm_m1 + print(f'\tEpoch: {epoch_num} --- Training Loss: {train_loss} --- Validation Loss: {val_loss}') + print('-'*20) # Makes Output Readable + print('Best Stats for Method-2 AutoEncoder :') + epoch_num, train_loss, val_loss = stats_autoencoder_m2 + print(f'\tEpoch: {epoch_num} --- Training Loss: {train_loss} --- Validation Loss: {val_loss}') + print('-'*10) # Makes Output Readable + print('Best Stats for Method-2 LSTM == Best Stats for Method-1 LSTM:') + print('-'*20) # Makes Output Readable + print('Best Stats for Method-3 AutoEncoder == Best Stats for Method-1 AutoEncoder:') + print('-'*10) # Makes Output Readable + print('Best Stats for Method-3 LSTM :') + epoch_num, train_loss, val_loss = stats_lstm_m3 + print(f'\tEpoch: {epoch_num} --- Training Loss: {train_loss} --- Validation Loss: {val_loss}') + print('-'*20) # Makes Output Readable + print('Best Stats for Method-4 AutoEncoder == Best Stats for Method-2 AutoEncoder') + print('-'*10) # Makes Output Readable + print('Best Stats for Method-4 LSTM == Best Stats for Method-3 LSTM') + print('-'*20) # Makes Output Readable + if __name__ == '__main__': world_size = torch.cuda.device_count() # Number of available GPUs diff --git a/Code/training.py b/Code/training.py index 373eb6e..4742453 100644 --- a/Code/training.py +++ b/Code/training.py @@ -15,7 +15,7 @@ # Define Training Class class Trainer(): - def __init__(self, model, loss_function, optimizer=None, model_save_path=None, rank=None): + def __init__(self, model, loss_function, optimizer=None, model_save_path=None, rank=None, lr_scheduler=None): self.rank = rank # Rank of the current process self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') self.model = model.to(self.device) @@ -23,6 +23,7 @@ def __init__(self, model, loss_function, optimizer=None, model_save_path=None, r self.loss_function = loss_function # Define the optimizer self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.model.parameters(), lr=0.001) + self.scheduler = lr_scheduler if lr_scheduler is not None else torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.1) # Wrap model with DDP if torch.cuda.device_count() > 1 and rank is not None: self.model = DDP(self.model, device_ids=[rank], find_unused_parameters=True) @@ -43,7 +44,9 @@ def train_autoencoder(self, epochs, train_loader, val_loader): if torch.cuda.device_count() > 0 and self.rank == 0: gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) print("\tGPUs being used for Training : ",gpu_names) - best_val_loss = float('inf') + best_val_loss = float('inf') + best_epoch = -1 + best_train_loss = float('inf') for epoch in range(epochs): self.model.train() # Set the Model to Training Mode # Training Loop @@ -54,6 +57,7 @@ def train_autoencoder(self, epochs, train_loader, val_loader): self.optimizer.zero_grad() # Zero gradients to prepare for Backward Pass loss.backward() # Backward Pass self.optimizer.step() # Update Model Parameters + self.scheduler.step() # Update Learning Rate # Validation Loss Calculation self.model.eval() # Set the Model to Evaluation Mode with torch.no_grad(): # Disable gradient computation @@ -62,13 +66,16 @@ def train_autoencoder(self, epochs, train_loader, val_loader): val_loss /= len(val_loader) # Compute Average Validation Loss # Print epochs and losses if self.rank == 0: - print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}') + lr = self.optimizer.param_groups[0]['lr'] + print(f'\tEpoch {epoch+1}/{epochs} --- Training Loss: {round(loss.item(),10)} --- Validation Loss: {round(val_loss,10)} --- Learning Rate: {round(lr,8)}') # If the current validation loss is lower than the best validation loss, save the model if val_loss < best_val_loss: best_val_loss = val_loss # Update the best validation loss + best_epoch = epoch+1 + best_train_loss = loss.item() self.save_model() # Save the model - # Return the Trained Model - return self.model + # Return the Trained Model and the best epoch's details + return self.model, (best_epoch, best_train_loss, best_val_loss) def train_lstm(self, epochs, train_loader, val_loader): # Print Names of All Available GPUs (if any) to Train the Model @@ -76,6 +83,8 @@ def train_lstm(self, epochs, train_loader, val_loader): gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) print("\tGPUs being used for Training : ",gpu_names) best_val_loss = float('inf') + best_epoch = -1 + best_train_loss = float('inf') for epoch in range(epochs): self.model.train() # Set the model to training mode # Training loop @@ -86,6 +95,7 @@ def train_lstm(self, epochs, train_loader, val_loader): loss = self.loss_function(output_sequence, target_sequence) # Compute loss loss.backward() # Backward pass self.optimizer.step() # Update parameters + self.scheduler.step() # Update Learning Rate # Validation loop self.model.eval() # Set the model to evaluation mode with torch.no_grad(): # Disable gradient computation @@ -97,10 +107,13 @@ def train_lstm(self, epochs, train_loader, val_loader): val_loss /= len(val_loader) # Average validation loss # Print epochs and losses if self.rank == 0: - print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}') + lr = self.optimizer.param_groups[0]['lr'] + print(f'\tEpoch {epoch+1}/{epochs} --- Training Loss: {round(loss.item(),10)} --- Validation Loss: {round(val_loss,10)} --- Learning Rate: {round(lr,8)}') # Model saving based on validation loss if val_loss < best_val_loss: best_val_loss = val_loss + best_epoch = epoch+1 + best_train_loss = loss.item() self.save_model() - # Return the trained model - return self.model + # Return the trained model and the best epoch's details + return self.model, (best_epoch, best_train_loss, best_val_loss)