Skip to content

Commit

Permalink
Added support for Cuda GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
iSiddharth20 committed Dec 24, 2023
1 parent 747e92c commit 4c4428b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions Code/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# Define Training Class
class Trainer():
def __init__(self, model, loss_function, model_save_path):
# Define the model
self.model = model
# Define the device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define the model and move it to the device
self.model = model.to(self.device)
# Define the loss function
self.loss_function = loss_function
# Define the optimizer
Expand Down

0 comments on commit 4c4428b

Please sign in to comment.