From 4c4428b306f17d6ca2036aab3b1a8ac5532ab8c0 Mon Sep 17 00:00:00 2001 From: iSiddharth20 Date: Sat, 23 Dec 2023 20:31:59 -0800 Subject: [PATCH] Added support for Cuda GPU --- Code/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Code/training.py b/Code/training.py index a308373..1ef6880 100644 --- a/Code/training.py +++ b/Code/training.py @@ -14,8 +14,10 @@ # Define Training Class class Trainer(): def __init__(self, model, loss_function, model_save_path): - # Define the model - self.model = model + # Define the device + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Define the model and move it to the device + self.model = model.to(self.device) # Define the loss function self.loss_function = loss_function # Define the optimizer