-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_cnn.py
97 lines (83 loc) · 2.67 KB
/
simple_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
# --- Run configuration --------------------------------------------------------
# Images are resized to a square of this many pixels per side.
img_width = img_height = 128
# Images per batch; shared by the training and validation generators.
batch_size = 64
# Number of passes over the training data.
num_epochs = 40
# Where the trained model is written; Keras picks HDF5 format for a .h5 path.
model_save_path = '/content/drive/MyDrive/tomato/modified_cnn_model.h5'
# --- Input pipelines ----------------------------------------------------------
# Training images are scaled by 1/255 and randomly sheared, zoomed and
# horizontally flipped; the augmentation is intended to reduce overfitting.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# Validation images are only scaled by 1/255 (no augmentation).
val_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Both splits use the same resize target, batch size and one-hot labels.
# NOTE: Keras reads target_size as (height, width); with the square
# img_width == img_height setting the order makes no difference here.
_flow_options = dict(
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
)
# One sub-folder per class under each split directory.
train_generator = train_datagen.flow_from_directory(
    '/content/drive/MyDrive/tomato/train', **_flow_options
)
val_generator = val_datagen.flow_from_directory(
    '/content/drive/MyDrive/tomato/val', **_flow_options
)
# --- Preview a batch of (augmented) training images ---------------------------
sample_images, sample_labels = next(train_generator)
# class_indices maps folder name -> label index; order the names by that index
# explicitly so class_names[i] is always the class for one-hot position i.
class_names = sorted(
    train_generator.class_indices, key=train_generator.class_indices.get
)

plt.figure(figsize=(10, 10))
# A batch can hold fewer than 9 images (e.g. a tiny training set); only draw
# what is there instead of raising IndexError.
num_preview = min(9, len(sample_images))
for i in range(num_preview):
    plt.subplot(3, 3, i + 1)
    plt.imshow(sample_images[i])
    # Labels are one-hot (class_mode='categorical'); argmax recovers the index.
    class_index = np.argmax(sample_labels[i])
    class_name = class_names[class_index]
    plt.title(f'Class: {class_name}')
    plt.axis('off')
plt.show()
# --- Model --------------------------------------------------------------------
# The softmax head needs one unit per class folder found on disk. A hard-coded
# width makes categorical_crossentropy fail with a shape mismatch as soon as
# the dataset has a different number of classes.
num_classes = len(train_generator.class_indices)

# Four conv/max-pool stages (32 -> 256 filters) feed a dropout-regularized
# dense head.
# NOTE: input_shape must match the images the generators yield (target_size
# plus 3 RGB channels); both are written as (img_width, img_height) here.
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(img_width, img_height, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(256, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])
# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical'.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# --- Training -------------------------------------------------------------------
# len(generator) is ceil(samples / batch_size), so every image is used each
# epoch. Floor division (samples // batch_size) would drop the final partial
# batch. It would also give 0 steps for any split smaller than batch_size.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=num_epochs,
    validation_data=val_generator,
    validation_steps=len(val_generator),
)
model.save(model_save_path)
# --- Learning curve: accuracy per epoch ----------------------------------------
# Draw the training curve first, then validation, matching the legend order.
for curve_key in ('accuracy', 'val_accuracy'):
    plt.plot(history.history[curve_key])
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
# --- Learning curve: loss per epoch --------------------------------------------
# Draw the training curve first, then validation, matching the legend order.
for curve_key in ('loss', 'val_loss'):
    plt.plot(history.history[curve_key])
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()