Skip to content

Commit

Permalink
Merge branch 'master' into fix-gpu-tests-and-failing-metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Nov 7, 2024
2 parents 65a2fce + 0d8b8a1 commit 0b72cf5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tests/ignite/engine/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch
import torch.nn as nn
from packaging.version import Version
from torch.optim import SGD
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

Expand Down Expand Up @@ -737,7 +738,11 @@ def write_data_grads_weights(e):
grad_norms.append([i, total[1]] + out2)

if sd is not None:
sd = torch.load(sd)
if Version(torch.__version__) >= Version("1.13.0"):
kwargs = {"weights_only": False}
else:
kwargs = {}
sd = torch.load(sd, **kwargs)
model.load_state_dict(sd[0])
opt.load_state_dict(sd[1])
from ignite.engine.deterministic import _repr_rng_state
Expand Down
6 changes: 5 additions & 1 deletion tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,11 @@ def test_torch_save_load(dirname):

filepath = Path(dirname) / "dummy_lambda_state_parameter_scheduler.pt"
torch.save(lambda_state_parameter_scheduler, filepath)
loaded_lambda_state_parameter_scheduler = torch.load(filepath)
if Version(torch.__version__) >= Version("1.13.0"):
kwargs = {"weights_only": False}
else:
kwargs = {}
loaded_lambda_state_parameter_scheduler = torch.load(filepath, **kwargs)

engine1 = Engine(lambda e, b: None)
lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)
Expand Down

0 comments on commit 0b72cf5

Please sign in to comment.