Skip to content

Commit

Permalink
Fix #2127 move to ema device
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Apr 11, 2024
1 parent e25bbfc commit 24f6d4f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions timm/utils/model_ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def apply_update_(self, model, decay: float):
else:
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_v.lerp_(model_v, weight=1. - decay)
ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
else:
ema_v.copy_(model_v)
ema_v.copy_(model_v.to(device=self.device))

def apply_update_no_buffers_(self, model, decay: float):
# interpolate parameters, copy buffers
Expand All @@ -246,7 +246,7 @@ def apply_update_no_buffers_(self, model, decay: float):
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)

for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))
Expand Down

0 comments on commit 24f6d4f

Please sign in to comment.