-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
85 lines (68 loc) · 2.94 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from utils.evaluation_visualization import EvaluationVisualization, evaluate_and_visualize
from factories.vision_transformer_factory import VisionTransformerFactory
def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, class_names=None):
    """Run the training loop with per-epoch validation and visualization.

    Args:
        model: network to train (caller is expected to have moved it to `device`).
        train_loader: DataLoader yielding (inputs, labels) training batches.
        val_loader: DataLoader used for per-epoch evaluation.
        criterion: loss function (e.g. ``nn.CrossEntropyLoss``).
        optimizer: optimizer updating ``model.parameters()``.
        device: torch device the batches are moved to.
        num_epochs: number of full passes over `train_loader`.
        class_names: optional list of class labels passed to the evaluation
            helper. Defaults to the module-level ``classes`` global so that
            existing call sites keep working unchanged.
    """
    # Backward-compatible fallback to the global set up by main().
    if class_names is None:
        class_names = classes
    train_losses = []
    val_losses = []
    accuracies = []
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        for inputs, labels in train_loader:
            # Fixed: original had a stray trailing comma in the unpacking target.
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        # Validate and render per-epoch plots via the project helper.
        val_loss, accuracy, conf_matrix = evaluate_and_visualize(
            model, val_loader, criterion, device, class_names, epoch
        )
        val_losses.append(val_loss)
        accuracies.append(accuracy)
        EvaluationVisualization.plot_loss_accuracy(train_losses, val_losses, accuracies, epoch)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
def main():
    """Train a Vision Transformer on MNIST, save the weights, and reload them for inference."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    input_dim = 28 * 28        # flattened 28x28 MNIST image
    num_classes = 10
    num_blocks = 4
    hidden_dim = 256
    ffn_hidden_dim = 512
    dropout_prob = 0.1
    batch_size = 128
    learning_rate = 0.0001
    num_epochs = 10

    # `classes` is read as a module-level global by train(); kept global
    # for backward compatibility with that call contract.
    global classes
    classes = [str(i) for i in range(10)]  # MNIST class labels

    # Create model
    model = VisionTransformerFactory.create_model(
        input_dim, num_classes, num_blocks, hidden_dim, ffn_hidden_dim, dropout_prob
    )
    model.to(device)

    # Create data loaders
    train_loader, val_loader = VisionTransformerFactory.create_dataloader(batch_size)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train loop
    train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs)

    # Save, then reload, the trained weights. The path is bound to one local
    # so the save and load locations can never drift apart (the original
    # repeated the literal twice).
    model_path = '/output/path/to/here/vision_transformer_model.pth'
    torch.save(model.state_dict(), model_path)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Example: inference on a single image
    # example_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    # example_image = example_transform(Image.open('/input/image/path/to/here/mnist5.jpg')).unsqueeze(0).to(device)
    # with torch.no_grad():
    #     output = model(example_image.view(example_image.size(0), -1))  # Flatten the input tensor for inference
    #     predicted_class = torch.argmax(output, dim=1).item()
    #     print(f'Predicted Class: {predicted_class}')
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()