Skip to content

Commit

Permalink
Merge pull request from Dev Branch
Browse files Browse the repository at this point in the history
Updated Model Architecture, Loss Functions, Trainer Initializations
  • Loading branch information
iSiddharth20 authored Dec 23, 2023
2 parents ea4c35b + ae32f3d commit 3a60c41
Show file tree
Hide file tree
Showing 45 changed files with 1,223 additions and 221 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.tif filter=lfs diff=lfs merge=lfs -text
56 changes: 56 additions & 0 deletions Code/RequiredResults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Import Necessary Libraries
import os
from PIL import Image
import torch

# Function that generates RGB Image Sequence with Interpolated Frames from a Grayscale Image Sequence
def generate_rgb_sequence(model_lstm, model_autoencoder, grey_sequence, n_interpolate_frames,
                          model_save_path_lstm, model_save_path_ae, generated_sequence_dir):
    """Interpolate a grayscale sequence, colorize each frame and save the frames as TIFFs.

    Args:
        model_lstm: Module called as ``model_lstm(grey_sequence, n_interpolate_frames)``.
        model_autoencoder: Module called on one frame of the LSTM output at a time.
        grey_sequence: Input handed to ``model_lstm`` unchanged.
        n_interpolate_frames: Second argument handed to ``model_lstm``.
        model_save_path_lstm: Checkpoint path for ``model_lstm``.
        model_save_path_ae: Checkpoint path for ``model_autoencoder``.
        generated_sequence_dir: Output directory; created if it does not exist.
    """
    for model, checkpoint in ((model_lstm, model_save_path_lstm),
                              (model_autoencoder, model_save_path_ae)):
        if os.path.exists(checkpoint):
            # map_location keeps a GPU-saved checkpoint loadable on a CPU-only machine;
            # load_state_dict then copies the values onto the model's own device.
            model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        else:
            # Best effort is kept (no crash), but the fallback is no longer silent.
            print(f'Warning: no checkpoint at {checkpoint}; using the current weights.')
        # Inference only, so always switch to evaluation mode.
        model.eval()

    # Both forward passes run without autograd so no graph is built for the sequence.
    with torch.no_grad():
        full_sequence_gray = model_lstm(grey_sequence, n_interpolate_frames)
        full_sequence_rgb = [
            model_autoencoder(full_sequence_gray[:, i, :, :].unsqueeze(dim=0))
            for i in range(full_sequence_gray.size(1))
        ]

    os.makedirs(generated_sequence_dir, exist_ok=True)
    for idx, rgb_tensor in enumerate(full_sequence_rgb):
        # Assumes the squeezed output is channel-first (C, H, W) with values in [0, 1]
        # - TODO confirm against the autoencoder. Clamping prevents uint8 wrap-around
        # for values slightly outside that range.
        image_data = (
            rgb_tensor.squeeze()
            .permute(1, 2, 0)
            .clamp(0.0, 1.0)
            .mul(255)
            .to(torch.uint8)
            .contiguous()
            .cpu()
            .numpy()
        )
        image = Image.fromarray(image_data)

        image_path = os.path.join(generated_sequence_dir, f'generated_frame_{idx:04d}.tif')
        image.save(image_path)

    print('The generated sequence of RGB images has been saved.')


'''
Pass Output of LSTM Model to AutoEncoder Model to Obtain Final Output
'''
# NOTE(review): model_lstm_mlp, model_autoencoder_mlp, model_lstm_mep, model_autoencoder_mep,
# grey_sequence and n_interpolate_frames are not defined or imported in the visible part of
# this file - confirm where they are created before running this script.

# Maximum Likelihood Principle: MLP checkpoints, frames written to the MLP output folder
model_save_path_ae = '../Models/model_autoencoder_mlp.pth'
model_save_path_lstm = '../Models/model_lstm_mlp.pth'
generated_sequence_dir = '../Dataset/GeneratedSequence/MLP'
generate_rgb_sequence(model_lstm_mlp, model_autoencoder_mlp, grey_sequence, n_interpolate_frames,
model_save_path_lstm, model_save_path_ae, generated_sequence_dir)

# Maximum Entropy Principle: MEP checkpoints, frames written to the MEP output folder
# (the three path variables above are deliberately re-bound for the second run)
model_save_path_ae = '../Models/model_autoencoder_mep.pth'
model_save_path_lstm = '../Models/model_lstm_mep.pth'
generated_sequence_dir = '../Dataset/GeneratedSequence/MEP'
generate_rgb_sequence(model_lstm_mep, model_autoencoder_mep, grey_sequence, n_interpolate_frames,
model_save_path_lstm, model_save_path_ae, generated_sequence_dir)
7 changes: 4 additions & 3 deletions Code/autoencoder_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
'''
Module that specifies AutoEncoder Architecture using PyTorch
Module for AutoEncoder
Generates 3-Channel RGB Image from 1-Channel Grayscale Image
--------------------------------------------------------------------------------
'''

# Import Necessary Libraries
import torch.nn as nn

# Define AutoEncoder Architecture
class AutoEncoder(nn.Module):
class Grey2RGBAutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
super(Grey2RGBAutoEncoder, self).__init__()

'''
# Define the Encoder
Expand Down
41 changes: 28 additions & 13 deletions Code/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'''
Module that specifies Data Pre-Processing
Importing Dataset, Converting it to PyTorch Tensors, Splitting it into Training and Validation Sets
--------------------------------------------------------------------------------
'''

Expand All @@ -10,8 +11,8 @@
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torch.utils.data import DataLoader, TensorDataset

# Define a class for Importing dataset and Storing it as NumPy Array
class Dataset:
def __init__(self, grayscale_dir, rgb_dir, image_size, batch_size):
self.grayscale_dir = grayscale_dir # Directory for grayscale images
Expand Down Expand Up @@ -49,17 +50,31 @@ def load_images_to_tensor(self, directory):
return images

# Function to get batches of input-target pairs from data (This Functionality is for AutoEncoder Component of the Program)
def get_autoencoder_batches(self, val_split):
    """Build training and validation loaders of (grayscale, rgb) pairs.

    Args:
        val_split: Fraction of the samples, between 0 and 1, held out for validation.

    Returns:
        Tuple ``(train_loader, val_loader)``; both use ``self.batch_size``.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f'val_split must be between 0 and 1, got {val_split}')

    # Pair every grayscale image with its RGB target
    dataset = TensorDataset(self.grayscale_images, self.rgb_images)

    # int() truncates, so the validation set never exceeds the requested fraction
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
    # Validation order does not matter; leaving it unshuffled also keeps an empty
    # validation split (val_split == 0) from going through a random sampler.
    val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
    return train_loader, val_loader

# Function to get batches of original_sequence-interpolated_sequence from data (This Functionality is for LSTM Component of the Program)
def get_lstm_batches(self):
    """Build training and validation loaders holding the grayscale image sequence.

    Returns:
        Tuple ``(train_loader, val_loader)``. Each loader wraps a single sequence:
        the training one keeps every second frame, the validation one keeps all
        frames so interpolated frames can be compared against the real ones.
    """
    # Add a leading dimension so the whole image stack becomes one sequence sample
    grayscale_image_sequence = self.grayscale_images.unsqueeze(0)

    # Frames at odd 0-based positions (1, 3, 5, ...) are used for training
    train_sequence = grayscale_image_sequence[:, 1::2]
    # The complete sequence is used for validation
    val_sequence = grayscale_image_sequence

    train_data = TensorDataset(train_sequence)
    val_data = TensorDataset(val_sequence)

    # Each dataset holds exactly one sample, so shuffling has no effect on order
    train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=False)
    return train_loader, val_loader
10 changes: 3 additions & 7 deletions Code/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ def __init__(self, model, loss_fn):
self.loss_fn = loss_fn

def convert_to_image(self, tensor):
    """Convert a tensor of normalized images into a uint8 NumPy array.

    Values are scaled by 255 and clamped to [0, 255] before the cast, so
    out-of-range values saturate instead of wrapping around. The input tensor
    is not modified.

    Args:
        tensor: Image tensor, expected to be normalized to [0, 1].

    Returns:
        A uint8 array. For a 4D input with a single channel the channel axis is
        dropped (needed for the SSIM computation); otherwise the shape is kept.
    """
    # The multiplication already yields a new tensor, so no explicit clone is needed;
    # detach() allows conversion even when the tensor is attached to an autograd graph.
    scaled = (tensor.detach() * 255.0).clamp(0.0, 255.0)
    images = scaled.cpu().numpy().astype(np.uint8)
    if images.ndim == 4 and images.shape[1] == 1:
        # Single-channel batch: remove the channel dimension
        return images[:, 0]
    return images

def evaluate(self, test_loader):
Expand All @@ -34,11 +33,8 @@ def evaluate(self, test_loader):
output = self.convert_to_image(output)
mse = ((batch_data - output) ** 2).mean(axis=None)
mse_total += mse

# Compute SSIM over each image in batch and average
batch_ssim = np.mean([ssim(x, y, data_range=255) for x, y in zip(batch_data, output)])
ssim_total += batch_ssim

mse_avg = mse_total / len(test_loader)
ssim_avg = ssim_total / len(test_loader)
print('Test MSE: {:.4f}'.format(mse_avg))
Expand Down
72 changes: 42 additions & 30 deletions Code/losses.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,57 @@
'''
Module that specifies Loss Functions
--------------------------------------------------------------------------------
Module for Loss Functions :
- Maximum Entropy Principle (MEP)
- Maximum Likelihood Principle (MLP)
- Structural Similarity Index Measure (SSIM)
'''

# Import Necessary Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_msssim import SSIM

# Define a class for the Maximum Entropy Principle (MEP) Loss
'''
Class for Composite Loss with MaxEnt Regularization Term
- Maximum Entropy Principle
'''
class LossMEP(nn.Module):
    """Composite loss: MSE combined with an entropy-style regularization term.

    The result is ``alpha * mse + (1 - alpha) * entropy`` where ``entropy`` is
    ``-sum(target * log(output + 1e-8))`` over the last dimension, averaged over
    the remaining dimensions.

    Args:
        alpha: Weight of the MSE term; the entropy term is weighted ``1 - alpha``.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha  # Weighting factor between the two terms
        self.mse = nn.MSELoss()  # Base loss (Mean Squared Error)

    def forward(self, output, target):
        """Return the composite loss for ``output`` against ``target``."""
        mse_loss = self.mse(output, target)
        # Negative outputs would make log() return NaN and poison the whole loss;
        # clamping at zero leaves non-negative outputs unchanged.
        log_output = torch.log(output.clamp_min(0.0) + 1e-8)
        entropy = -torch.sum(target * log_output, dim=-1).mean()
        return self.alpha * mse_loss + (1 - self.alpha) * entropy

# Define a class for the Maximum Likelihood Principle (MLP) Loss
class LossMLP(nn.Module):
    """Maximum Likelihood Principle loss: MSE scaled by ``alpha``.

    Kept as a complete class so code that still refers to ``LossMLP`` keeps working.

    Args:
        alpha: Factor applied to the MSE value.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha  # Scaling factor for the loss
        self.mse = nn.MSELoss()  # Mean Squared Error loss

    def forward(self, output, target):
        """Return ``alpha * MSE(output, target)``."""
        likelihood_loss = self.mse(output, target)
        return self.alpha * likelihood_loss


class LossMSE(nn.Module):
    """Mean Squared Error (MSE) loss.

    - Maximum Likelihood Principle
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()  # Mean Squared Error loss

    def forward(self, output, target):
        """Return the unscaled ``MSE(output, target)``."""
        return self.mse(output, target)

'''
Class for Structural Similarity Index Measure (SSIM) Loss
- Maximum Likelihood Principle
- In PyTorch, loss is minimized, by doing 1 - SSIM, minimizing the loss function will lead to maximization of SSIM
'''
class SSIMLoss(nn.Module):
    """Structural Similarity Index Measure (SSIM) loss.

    Returns ``1 - SSIM`` so that minimizing the loss maximizes SSIM.

    Args:
        data_range: Value range of the input images (usually 1.0 or 255).
        size_average: Forwarded unchanged to the ``SSIM`` module.
        channel: Number of image channels the ``SSIM`` module is built for.
            The default of 3 keeps existing calls unchanged; pass 1 for
            single-channel (grayscale) images.
    """

    def __init__(self, data_range=1, size_average=True, channel=3):
        super().__init__()
        self.data_range = data_range
        self.size_average = size_average
        self.channel = channel
        # Initialize SSIM module
        self.ssim_module = SSIM(
            data_range=self.data_range,
            size_average=self.size_average,
            channel=self.channel,
        )

    def forward(self, img1, img2):
        """Return ``1 - SSIM(img1, img2)``."""
        ssim_value = self.ssim_module(img1, img2)
        return 1 - ssim_value
Loading

0 comments on commit 3a60c41

Please sign in to comment.