Optimized Version of Base Code
iSiddharth20 authored Dec 29, 2023
2 parents 57aadd4 + 394a7e4 commit 8220501
Showing 6 changed files with 119 additions and 88 deletions.
30 changes: 10 additions & 20 deletions Code/autoencoder_model.py
@@ -2,6 +2,10 @@
Module for AutoEncoder
Generates 3-Channel RGB Image from 1-Channel Grayscale Image
--------------------------------------------------------------------------------
For each pair of consecutive values in the channels list, a Convolutional or Transposed Convolutional layer is created.
The number of input channels is the first value, and the number of output channels is the second value.
A Batch Normalization layer and a LeakyReLU activation function are added after each Convolutional or Transposed Convolutional layer.
In the case of the decoder, the final layer uses a Sigmoid activation function instead of LeakyReLU.
'''

# Import Necessary Libraries
@@ -11,36 +15,23 @@
class Grey2RGBAutoEncoder(nn.Module):
def __init__(self):
super(Grey2RGBAutoEncoder, self).__init__()
'''
# Define the Encoder
The Encoder consists of 4 Convolutional layers with ReLU activation function
Encoder takes a 1-Channel Grayscale image as input and outputs a High-Dimensional Representation
'''
self.encoder = self._make_layers([1, 64, 128, 256, 512])

'''
self.encoder = self._make_layers([1, 64, 128, 256])
# Define the Decoder
The Decoder consists of 4 Transpose Convolutional layers with ReLU activation function
Decoder takes the High-Dimensional Representation as input and outputs a 3-Channel RGB image
The last layer uses a Sigmoid activation function instead of ReLU
'''
self.decoder = self._make_layers([512, 256, 128, 64, 3], decoder=True)
self.decoder = self._make_layers([256, 128, 64, 3], decoder=True)

# Helper function to create the encoder or decoder layers.
def _make_layers(self, channels, decoder=False):
layers = []
for i in range(len(channels) - 1):
'''
For each pair of consecutive values in the channels list, a Convolutional or Transposed Convolutional layer is created.
The number of input channels is the first value, and the number of output channels is the second value.
A ReLU activation function is added after each Convolutional layer.
'''
if decoder:
layers += [nn.ConvTranspose2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True)]
nn.BatchNorm2d(channels[i+1]),
nn.LeakyReLU(inplace=True)]
else:
layers += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True)]
nn.BatchNorm2d(channels[i+1]),
nn.LeakyReLU(inplace=True)]
if decoder:
layers[-1] = nn.Sigmoid()
return nn.Sequential(*layers)
@@ -50,4 +41,3 @@ def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
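
A minimal sanity-check sketch (not part of the commit): with kernel size 3, stride 1, and padding 1 throughout, the new channel lists [1, 64, 128, 256] and [256, 128, 64, 3] should map a grayscale batch to an RGB batch of the same spatial size. Tensor sizes below are illustrative.

import torch
from autoencoder_model import Grey2RGBAutoEncoder  # assumes Code/ is on the import path

model = Grey2RGBAutoEncoder()
grayscale_batch = torch.randn(4, 1, 64, 64)   # (batch, 1 channel, H, W); sizes are illustrative
rgb_batch = model(grayscale_batch)
print(rgb_batch.shape)                        # expected: torch.Size([4, 3, 64, 64])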

21 changes: 11 additions & 10 deletions Code/data.py
@@ -11,6 +11,7 @@
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
import torch
import os

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -26,14 +27,15 @@ def __init__(self, grayscale_dir, rgb_dir, image_size, batch_size, valid_exts=['
self.valid_exts = valid_exts # Valid file extensions
# Get list of valid image filenames
self.filenames = [f for f in self.grayscale_dir.iterdir() if f.suffix in self.valid_exts]
self.length = len(self.filenames)
# Define transformations: resize and convert to tensor
self.transform = transforms.Compose([
transforms.Resize(self.image_size),
transforms.ToTensor()])

# Return the total number of images
def __len__(self):
return len(self.filenames)
return self.length

# Get a single item or a slice from the dataset
def __getitem__(self, idx):
@@ -43,8 +45,12 @@ def __getitem__(self, idx):
grayscale_path = self.filenames[idx]
rgb_path = self.rgb_dir / grayscale_path.name
# Open images
grayscale_img = Image.open(grayscale_path)
rgb_img = Image.open(rgb_path)
try:
grayscale_img = Image.open(grayscale_path)
rgb_img = Image.open(rgb_path)
except IOError:
print(f"Error opening images {grayscale_path} or {rgb_path}")
return None
# Apply transformations
grayscale_img = self.transform(grayscale_img)
rgb_img = self.transform(rgb_img)
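
Since a failed image load now prints an error and returns None instead of raising, any DataLoader built on this dataset has to tolerate None entries. One hedged option (the function name is illustrative; default_collate is the public helper in recent PyTorch) is a filtering collate_fn:

from torch.utils.data import DataLoader, default_collate

def skip_none_collate(batch):
    # Drop samples for which __getitem__ returned None because an image failed to open.
    batch = [item for item in batch if item is not None]
    return default_collate(batch)

# Usage sketch: loader = DataLoader(dataset, batch_size=8, collate_fn=skip_none_collate)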
@@ -72,32 +78,27 @@ def transform_sequence(self, filenames):
# Get batches for LSTM training
def get_lstm_batches(self, val_split, sequence_length, sequence_stride=2):
assert sequence_length % 2 == 0, "The sequence length must be even."

# Compute the total number of sequences that can be formed, given the stride and length
sequence_indices = range(0, len(self.filenames) - sequence_length + 1, sequence_stride)
sequence_indices = range(0, self.length - sequence_length + 1, sequence_stride)
total_sequences = len(sequence_indices)

# Divide the sequences into training and validation
train_size = int((1.0 - val_split) * total_sequences)
train_indices = sequence_indices[:train_size]
val_indices = sequence_indices[train_size:]

# Create dataset with valid sequences only
train_dataset = self.create_sequence_pairs(train_indices, sequence_length)
val_dataset = self.create_sequence_pairs(val_indices, sequence_length)

# Create the data loaders for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

return train_loader, val_loader

def create_sequence_pairs(self, indices, sequence_length):
sequence_pairs = []
for start in indices:
end = start + sequence_length
# Make sure we don't go out of bounds
if end < len(self.filenames):
if end < self.length:
sequence_input = self.transform_sequence(self.filenames[start:end])
sequence_target = self.transform_sequence(self.filenames[start + 1:end + 1])
sequence_pairs.append((sequence_input, sequence_target))
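
To make the index arithmetic concrete, here is a standalone sketch with stand-in values (no image I/O; integers replace real filenames) that mirrors get_lstm_batches and create_sequence_pairs with sequence_length=4 and sequence_stride=2:

# Standalone illustration of the index arithmetic above.
filenames = list(range(10))          # stand-ins for 10 frame filenames
sequence_length, sequence_stride = 4, 2

sequence_indices = range(0, len(filenames) - sequence_length + 1, sequence_stride)
print(list(sequence_indices))        # [0, 2, 4, 6]

pairs = []
for start in sequence_indices:
    end = start + sequence_length
    if end < len(filenames):         # same bounds check as create_sequence_pairs
        pairs.append((filenames[start:end], filenames[start + 1:end + 1]))

# Three input/target pairs survive; start=6 is dropped because end (10) is not < 10:
# ([0, 1, 2, 3], [1, 2, 3, 4]), ([2, 3, 4, 5], [3, 4, 5, 6]), ([4, 5, 6, 7], [5, 6, 7, 8])
print(pairs)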
4 changes: 1 addition & 3 deletions Code/losses.py
@@ -24,9 +24,7 @@ def forward(self, output, target):
mse_loss = F.mse_loss(output, target)
# Assume output to be raw logits: calculate log_probs and use it to compute entropy
log_probs = F.log_softmax(output, dim=1) # dim 1 is the channel dimension
probs = torch.exp(log_probs)
entropy_loss = -torch.sum(probs * log_probs, dim=1).mean()

entropy_loss = -torch.sum(torch.exp(log_probs) * log_probs, dim=1).mean()
# Combine MSE with entropy loss scaled by alpha factor
composite_loss = (1 - self.alpha) * mse_loss + self.alpha * entropy_loss
return composite_loss
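
The combined objective can be checked in isolation. This sketch reproduces the same arithmetic on dummy tensors; only the formula comes from the diff, while alpha and the tensor sizes are illustrative:

import torch
import torch.nn.functional as F

alpha = 0.5                                   # illustrative weighting, not taken from the diff
output = torch.randn(2, 3, 8, 8)              # raw logits: (batch, channels, H, W)
target = torch.rand(2, 3, 8, 8)

mse_loss = F.mse_loss(output, target)
log_probs = F.log_softmax(output, dim=1)      # dim 1 is the channel dimension
entropy_loss = -torch.sum(torch.exp(log_probs) * log_probs, dim=1).mean()
composite_loss = (1 - alpha) * mse_loss + alpha * entropy_loss
print(composite_loss.item())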
70 changes: 42 additions & 28 deletions Code/lstm_model.py
@@ -3,14 +3,16 @@
Generate Intermediate Images and Return the Complete Image Sequence with Interpolated Images
--------------------------------------------------------------------------------
'''
# Import Necessary Libraries
# Importing Necessary Libraries
import torch
from torch import nn
from torch.nn import functional as F

# Define ConvLSTMCell class
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, num_features):
super(ConvLSTMCell, self).__init__()
# Define the convolutional layer
self.hidden_dim = hidden_dim
padding = kernel_size[0] // 2, kernel_size[1] // 2
self.conv = nn.Conv2d(in_channels=input_dim + hidden_dim,
@@ -19,77 +21,89 @@ def __init__(self, input_dim, hidden_dim, kernel_size, num_features):
padding=padding)

def forward(self, input_tensor, cur_state):
# Unpack the current state into hidden state (h_cur) and cell state (c_cur)
h_cur, c_cur = cur_state
# Concatenate the input tensor and the current hidden state along the channel dimension
combined = torch.cat([input_tensor, h_cur], dim=1)
# Apply the convolution to the combined tensor
combined_conv = self.conv(combined)
# Split the convolution output into four parts for input gate, forget gate, output gate, and cell gate
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
# Apply sigmoid activation to the input, forget, and output gates
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
# Apply tanh activation to the cell gate
g = torch.tanh(cc_g)

# Compute the next cell state as a combination of the forget gate, current cell state, input gate, and cell gate
c_next = f * c_cur + i * g
# Compute the next hidden state as the output gate times the tanh of the next cell state
h_next = o * torch.tanh(c_next)

# Return the next hidden state and cell state
return h_next, c_next
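
The four-way split above implies that the cell's convolution emits 4 * hidden_dim output channels (the out_channels line sits in the collapsed part of the diff, so the sketch below assumes it). A single-step exercise of the cell with illustrative sizes:

import torch
from lstm_model import ConvLSTMCell   # assumes Code/ is on the import path

cell = ConvLSTMCell(input_dim=1, hidden_dim=64, kernel_size=(3, 3), num_features=4)
x = torch.randn(2, 1, 32, 32)          # (batch, input_dim, H, W)
h0 = torch.zeros(2, 64, 32, 32)        # previous hidden state
c0 = torch.zeros(2, 64, 32, 32)        # previous cell state
h1, c1 = cell(x, (h0, c0))
print(h1.shape, c1.shape)              # both expected: torch.Size([2, 64, 32, 32])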

# Define the ConvLSTM class
class ConvLSTM(nn.Module):
def __init__(self, input_dim, hidden_dims, kernel_size, num_layers, alpha=0.5):
super(ConvLSTM, self).__init__()
# Set the number of layers, alpha parameter, and hidden dimensions
self.num_layers = num_layers
self.alpha = alpha
self.hidden_dims = hidden_dims
# Initialize a ModuleList to hold the ConvLSTM cells
self.cells = nn.ModuleList()

# Loop over the number of layers and create a ConvLSTM cell for each layer
for i in range(num_layers):
# The input dimension for the first layer is input_dim, for other layers it is the hidden dimension of the previous layer
cur_input_dim = input_dim if i == 0 else hidden_dims[i - 1]
# Append a new ConvLSTM cell to the cells list
self.cells.append(ConvLSTMCell(input_dim=cur_input_dim,
hidden_dim=hidden_dims[i],
kernel_size=kernel_size,
num_features=4)) # LSTM has 4 gates (features)

def init_hidden(self, batch_size, image_height, image_width):
# Initialize a list to hold the initial hidden and cell states
init_states = []
# Loop over the number of layers
for i in range(self.num_layers):
# Note the change from self.hidden_dim to self.hidden_dims
'''
For each layer, create a zero tensor for the hidden state and the cell state
The size of the tensor is (batch_size, hidden_dim, image_height, image_width)
The tensor is moved to the same device as the weights of the convolutional layer of the corresponding ConvLSTM cell
'''
init_states.append([torch.zeros(batch_size, self.hidden_dims[i], image_height, image_width, device=self.cells[i].conv.weight.device),
torch.zeros(batch_size, self.hidden_dims[i], image_height, image_width, device=self.cells[i].conv.weight.device)])
# Return the initial states
return init_states


def forward(self, input_tensor, cur_state=None):
# Extract the batch size, sequence length, height, and width from the input tensor
b, seq_len, _, h, w = input_tensor.size()

# If no current state is provided, initialize it using the init_hidden method
if cur_state is None:
cur_state = self.init_hidden(b, h, w)

# Initialize output tensors for each sequence element
# Initialize the output sequence tensor with zeros
output_sequence = torch.zeros((b, seq_len - 1, self.hidden_dims[-1], h, w), device=input_tensor.device)

# Loop over each ConvLSTM cell (layer) in the model
for layer_idx, cell in enumerate(self.cells):

# Fix: Unpack hidden and cell states for the current layer
# Extract the hidden state and cell state for the current layer
h, c = cur_state[layer_idx]

# For handling the sequence of images
# Loop over each time step in the input sequence
for t in range(seq_len - 1):
# Perform forward pass through the cell
h, c = cell(input_tensor[:, t, :, :, :], (h, c)) # Updated to pass tuple `(h, c)`

if layer_idx == self.num_layers - 1: # Only store output from the last layer
# Pass the input and current state through the cell to get the next state
h, c = cell(input_tensor[:, t, :, :, :], (h, c))
# If this is the last layer, add the hidden state to the output sequence
if layer_idx == self.num_layers - 1:
output_sequence[:, t, :, :, :] = h

# Generate the next input from alpha-blending
# If this is not the last time step, generate the next input by alpha-blending the current and next input
if t != seq_len - 2:
next_input = (1 - self.alpha) * input_tensor[:, t, :, :, :] + self.alpha * input_tensor[:, t + 1, :, :, :]
h, c = cell(next_input, (h, c)) # Updated to pass tuple `(h, c)`

h, c = cell(next_input, (h, c))
# Update the current state for this layer
cur_state[layer_idx] = (h, c)

# No need to stack since we're assigning the results in the output tensor

# Predict an extra frame beyond the last input frame
# After processing all time steps, predict an extra frame beyond the last input frame
h, c = cell(input_tensor[:, -1, :, :, :], (h, c))
output_sequence = torch.cat([output_sequence, h.unsqueeze(1)], dim=1)
return output_sequence, cur_state
# Return the output sequence and the final state
return output_sequence, cur_state
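
A hedged end-to-end sketch of the updated forward pass, using a single layer so the visible logic (which feeds the raw input frames to each cell) applies directly; all dimensions, the kernel size, and alpha are illustrative:

import torch
from lstm_model import ConvLSTM        # assumes Code/ is on the import path

model = ConvLSTM(input_dim=1, hidden_dims=[64], kernel_size=(3, 3), num_layers=1, alpha=0.5)
frames = torch.randn(2, 5, 1, 32, 32)  # (batch, seq_len, channels, H, W)
output_sequence, state = model(frames)
# seq_len - 1 per-step outputs plus the extra predicted frame appended at the end:
print(output_sequence.shape)           # expected: torch.Size([2, 5, 64, 32, 32])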