-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#49 - ✨ Added GAN model to generate new data for class balancing and …
…improved model training and generalisation.
- Loading branch information
1 parent
2f8c3f2
commit fba6907
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import os | ||
|
||
import numpy as np | ||
from tensorflow.keras.preprocessing.image import img_to_array, load_img | ||
|
||
from skinvestigatorai.config.data import DataConfig | ||
from skinvestigatorai.config.model import ModelConfig | ||
from skinvestigatorai.models.gan_model import GAN | ||
|
||
|
||
def load_images_from_folder(folder, img_size): | ||
images = [] | ||
for subdir, dirs, files in os.walk(folder): | ||
for file in files: | ||
if file.lower().endswith(('.JPG', '.jpg', '.jpeg', '.png')): | ||
img_path = os.path.join(subdir, file) | ||
img = load_img(img_path, target_size=img_size) | ||
img_array = img_to_array(img) | ||
img_array = (img_array - 127.5) / 127.5 # Normalize to [-1, 1] | ||
images.append(img_array) | ||
return np.array(images) | ||
|
||
|
||
def train_gan_model(root_folder, img_size=(160, 160), latent_dim=100, epochs=10000, batch_size=128, sample_interval=200): | ||
gan = GAN(img_shape=img_size + (3,), latent_dim=latent_dim) | ||
|
||
# Load and preprocess images | ||
dataset = load_images_from_folder(root_folder, img_size=img_size) | ||
print(f"Loaded {dataset.shape[0]} images from {root_folder}.") | ||
|
||
gan.train(X_train=dataset, epochs=epochs, batch_size=batch_size, sample_interval=sample_interval) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Set up the paths and parameters | ||
root_folder = os.path.join(DataConfig.OUTPUT_DIR, "train") | ||
img_size = ModelConfig.IMG_SIZE | ||
latent_dim = 100 | ||
epochs = 10000 | ||
batch_size = 128 | ||
sample_interval = 200 | ||
|
||
# Start training the GAN model | ||
train_gan_model(root_folder, img_size=img_size, latent_dim=latent_dim, epochs=epochs, batch_size=batch_size, sample_interval=sample_interval) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras import layers | ||
|
||
|
||
class GAN: | ||
def __init__(self, img_shape, latent_dim): | ||
self.img_shape = img_shape | ||
self.latent_dim = latent_dim | ||
|
||
self.generator = self.build_generator() | ||
self.discriminator = self.build_discriminator() | ||
|
||
self.gan = self.build_gan() | ||
|
||
def build_generator(self): | ||
model = tf.keras.Sequential() | ||
model.add(layers.Dense(256, input_dim=self.latent_dim)) | ||
model.add(layers.LeakyReLU(alpha=0.2)) | ||
model.add(layers.BatchNormalization(momentum=0.8)) | ||
model.add(layers.Dense(512)) | ||
model.add(layers.LeakyReLU(alpha=0.2)) | ||
model.add(layers.BatchNormalization(momentum=0.8)) | ||
model.add(layers.Dense(1024)) | ||
model.add(layers.LeakyReLU(alpha=0.2)) | ||
model.add(layers.BatchNormalization(momentum=0.8)) | ||
model.add(layers.Dense(np.prod(self.img_shape), activation='tanh')) | ||
model.add(layers.Reshape(self.img_shape)) | ||
return model | ||
|
||
def build_discriminator(self): | ||
model = tf.keras.Sequential() | ||
model.add(layers.Flatten(input_shape=self.img_shape)) | ||
model.add(layers.Dense(512)) | ||
model.add(layers.LeakyReLU(alpha=0.2)) | ||
model.add(layers.Dense(256)) | ||
model.add(layers.LeakyReLU(alpha=0.2)) | ||
model.add(layers.Dense(1, activation='sigmoid')) | ||
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) | ||
return model | ||
|
||
def build_gan(self): | ||
self.discriminator.trainable = False | ||
model = tf.keras.Sequential([self.generator, self.discriminator]) | ||
model.compile(loss='binary_crossentropy', optimizer='adam') | ||
return model | ||
|
||
def train(self, X_train, epochs, batch_size=128, sample_interval=200): | ||
valid = np.ones((batch_size, 1)) | ||
fake = np.zeros((batch_size, 1)) | ||
|
||
for epoch in range(epochs): | ||
idx = np.random.randint(0, X_train.shape[0], batch_size) | ||
imgs = X_train[idx] | ||
|
||
noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) | ||
gen_imgs = self.generator.predict(noise) | ||
|
||
d_loss_real = self.discriminator.train_on_batch(imgs, valid) | ||
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) | ||
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) | ||
|
||
noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) | ||
g_loss = self.gan.train_on_batch(noise, valid) | ||
|
||
if epoch % sample_interval == 0: | ||
print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import numpy as np | ||
|
||
from skinvestigatorai.models.gan_model import GAN | ||
|
||
|
||
class GANAugmentor: | ||
def __init__(self): | ||
self.gan = GAN(img_shape=(160, 160, 3), latent_dim=100) | ||
|
||
def generate_samples(self, num_samples): | ||
noise = np.random.normal(0, 1, (num_samples, self.gan.latent_dim)) | ||
generated_images = self.gan.generator.predict(noise) | ||
return generated_images | ||
|
||
def augment_data(self, original_data, augmentation_factor=2): | ||
augmented_data = [] | ||
num_generated = original_data.shape[0] * augmentation_factor | ||
|
||
generated_images = self.generate_samples(num_generated) | ||
augmented_data.extend(generated_images) | ||
return np.array(augmented_data) |