Skip to content

Commit

Permalink
#49 - ✨ Added GAN model to generate new data for class balancing and …
Browse files Browse the repository at this point in the history
…improved model training and generalisation.
  • Loading branch information
Thomasbehan committed Aug 9, 2024
1 parent 2f8c3f2 commit fba6907
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
44 changes: 44 additions & 0 deletions commands/run_gan_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os

import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, load_img

from skinvestigatorai.config.data import DataConfig
from skinvestigatorai.config.model import ModelConfig
from skinvestigatorai.models.gan_model import GAN


def load_images_from_folder(folder, img_size):
images = []
for subdir, dirs, files in os.walk(folder):
for file in files:
if file.lower().endswith(('.JPG', '.jpg', '.jpeg', '.png')):
img_path = os.path.join(subdir, file)
img = load_img(img_path, target_size=img_size)
img_array = img_to_array(img)
img_array = (img_array - 127.5) / 127.5 # Normalize to [-1, 1]
images.append(img_array)
return np.array(images)


def train_gan_model(root_folder, img_size=(160, 160), latent_dim=100, epochs=10000, batch_size=128, sample_interval=200):
gan = GAN(img_shape=img_size + (3,), latent_dim=latent_dim)

# Load and preprocess images
dataset = load_images_from_folder(root_folder, img_size=img_size)
print(f"Loaded {dataset.shape[0]} images from {root_folder}.")

gan.train(X_train=dataset, epochs=epochs, batch_size=batch_size, sample_interval=sample_interval)


if __name__ == "__main__":
# Set up the paths and parameters
root_folder = os.path.join(DataConfig.OUTPUT_DIR, "train")
img_size = ModelConfig.IMG_SIZE
latent_dim = 100
epochs = 10000
batch_size = 128
sample_interval = 200

# Start training the GAN model
train_gan_model(root_folder, img_size=img_size, latent_dim=latent_dim, epochs=epochs, batch_size=batch_size, sample_interval=sample_interval)
67 changes: 67 additions & 0 deletions skinvestigatorai/models/gan_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class GAN:
def __init__(self, img_shape, latent_dim):
self.img_shape = img_shape
self.latent_dim = latent_dim

self.generator = self.build_generator()
self.discriminator = self.build_discriminator()

self.gan = self.build_gan()

def build_generator(self):
model = tf.keras.Sequential()
model.add(layers.Dense(256, input_dim=self.latent_dim))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(1024))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.BatchNormalization(momentum=0.8))
model.add(layers.Dense(np.prod(self.img_shape), activation='tanh'))
model.add(layers.Reshape(self.img_shape))
return model

def build_discriminator(self):
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=self.img_shape))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model

def build_gan(self):
self.discriminator.trainable = False
model = tf.keras.Sequential([self.generator, self.discriminator])
model.compile(loss='binary_crossentropy', optimizer='adam')
return model

def train(self, X_train, epochs, batch_size=128, sample_interval=200):
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]

noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)

d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
g_loss = self.gan.train_on_batch(noise, valid)

if epoch % sample_interval == 0:
print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
21 changes: 21 additions & 0 deletions skinvestigatorai/services/gan_augmentor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np

from skinvestigatorai.models.gan_model import GAN


class GANAugmentor:
def __init__(self):
self.gan = GAN(img_shape=(160, 160, 3), latent_dim=100)

def generate_samples(self, num_samples):
noise = np.random.normal(0, 1, (num_samples, self.gan.latent_dim))
generated_images = self.gan.generator.predict(noise)
return generated_images

def augment_data(self, original_data, augmentation_factor=2):
augmented_data = []
num_generated = original_data.shape[0] * augmentation_factor

generated_images = self.generate_samples(num_generated)
augmented_data.extend(generated_images)
return np.array(augmented_data)

0 comments on commit fba6907

Please sign in to comment.