From e925fe4031d08a229a4f17242d7bbb82f419b981 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Mon, 4 Sep 2023 11:58:24 +0000 Subject: [PATCH] fix: WandbMetricsLogger for torch backend --- wandb_addons/keras/metrics_logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wandb_addons/keras/metrics_logger.py b/wandb_addons/keras/metrics_logger.py index e80cc458..bdc87ac6 100644 --- a/wandb_addons/keras/metrics_logger.py +++ b/wandb_addons/keras/metrics_logger.py @@ -115,7 +115,8 @@ def _get_lr(self) -> Union[float, None]: return None elif torch_backend_available: if isinstance(self.model.optimizer.learning_rate, torch.Tensor): - return float(self.model.optimizer.learning_rate.numpy().item()) + lr = self.model.optimizer.learning_rate.to("cpu") + return float(lr.numpy().item()) else: wandb.termerror("Unable to log learning rate.", repeat=False) return None