From 7d687cc2714a5bd56cb018ec017d59d0aec8e5cd Mon Sep 17 00:00:00 2001
From: Nick
Date: Fri, 6 Sep 2024 10:35:16 +0300
Subject: [PATCH] Update model.py

predict_step: build the step `extras` dict before running the flows and lower
the per-batch log message to debug level. test_step: build `extras`
unconditionally and gate processing and monitoring on two independent config
lookups instead of a single combined condition, dropping the unused
toolz-based monitor lookup and the commented-out monitor code.
---
 moai/core/model.py | 53 ++++++++++++++++++----------------------------
 1 file changed, 21 insertions(+), 32 deletions(-)

diff --git a/moai/core/model.py b/moai/core/model.py
index 52f27a3..f34b4da 100644
--- a/moai/core/model.py
+++ b/moai/core/model.py
@@ -229,18 +229,18 @@ def predict_step(
         batch_idx: int,
         dataset_idx: int = 0,
     ) -> typing.Dict[str, typing.Union[torch.Tensor, typing.Dict[str, torch.Tensor]]]:
-        log.info(f"Predicting batch {batch_idx} ...")
+        log.debug(f"Predicting batch {batch_idx} ...")
         batch = benedict.benedict(batch, keyattr_enabled=False)
+        extras = {
+            "stage": "predict",
+            "lightning_step": self.global_step,
+            "batch_idx": batch_idx,
+            "optimization_step": 0,  # TODO: add this for fitting case
+        }
         if proc := get_dict(self.process, C._PREDICT_):
             with torch.no_grad():  # TODO: probably this is not needed
                 for step in get_list(proc, C._FLOWS_):
                     batch = self.named_flows[step](batch)
-        extras = {
-            "stage": "predict",
-            "lightning_step": self.global_step,
-            "batch_idx": batch_idx,
-            "optimization_step": 0,  # TODO: add this for fitting case
-        }
         if monitor := get_dict(self.monitor, f"{C._PREDICT_}.{C._BATCH_}"):
             for metric in get_list(monitor, C._METRICS_):
                 self.named_metrics[metric](batch)
@@ -427,39 +427,28 @@ def test_step(
         batch = benedict.benedict(batch, keyattr_enabled=False)
         batch[C._MOAI_METRICS_] = {}
         datasets = list(self.data.test.iterator.datasets.keys())
-        # if dataset is zipped we should follow a differet approach
-        dataset_name = (
+        dataset_name = (  # if dataset is zipped we should follow a different approach
             datasets[dataloader_idx]
             if "Zipped" not in self.data.test.iterator._target_
             else "zipped"
         )
-        monitor = (
-            toolz.get_in([C._TEST_, C._DATASETS_, dataset_name], self.monitor) or []
-        )
-        if (
-            proc := get_dict(self.process, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}")
-        ) and (
-            monitor := get_dict(
-                self.monitor, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"
-            )
-        ):
-            extras = {
-                "lightning_step": self.trainer.test_loop.batch_progress.current.completed,  # NOTE: self.global_step does not increment correctly
-                "epoch": self.current_epoch,
-                "batch_idx": batch_idx,
-            }
+        extras = {
+            "lightning_step": self.trainer.test_loop.batch_progress.current.completed,  # NOTE: self.global_step does not increment correctly
+            "epoch": self.current_epoch,
+            "batch_idx": batch_idx,
+        }
+        if proc := get_dict(self.process, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"):
             with torch.no_grad():  # TODO: probably this is not needed
                 # for iter in range(iters): #NOTE: is this necessary?
                 for step in get_list(proc, C._FLOWS_):
                     batch = self.named_flows[step](batch)
-            for metric in get_list(monitor, C._METRICS_):  # Metrics monitoring
-                self.named_metrics[metric](batch)
-            # Tensor monitoring for visualization
-            # tensor_monitors = toolz.get(C._MONITORS_, monitor, None) or []
-            # for tensor_monitor in tensor_monitors:
-            #     self.named_monitors[tensor_monitor](batch)
-            for tensor_monitor in get_list(monitor, C._MONITORS_):
-                self.named_monitors[tensor_monitor](batch, extras)
+        if monitor := get_dict(
+            self.monitor, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"
+        ):
+            for metric in get_list(monitor, C._METRICS_):  # Metrics monitoring
+                self.named_metrics[metric](batch)
+            for tensor_monitor in get_list(monitor, C._MONITORS_):
+                self.named_monitors[tensor_monitor](batch, extras)
 
     @torch.no_grad
     def validation_step(
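
Review note (illustrative only): below is a minimal standalone sketch of the control flow this patch moves test_step to, where `extras` is computed unconditionally and the flow and metric/monitor passes are gated by two independent config lookups. The `get_dict`/`get_list` helpers here are simplified stand-ins built on toolz, not moai's actual implementations, and `test_like_step`, the dotted config paths, and the "dummy" dataset name are hypothetical names used only for this example.

import typing

import toolz


def get_dict(cfg: dict, path: str) -> typing.Optional[dict]:
    # Stand-in helper: resolve a dotted path to a nested dict, or None when absent.
    return toolz.get_in(path.split("."), cfg, None)


def get_list(cfg: dict, key: str) -> list:
    # Stand-in helper: fetch a list under `key`, defaulting to an empty list.
    return toolz.get(key, cfg, None) or []


def test_like_step(
    batch: dict,
    batch_idx: int,
    process: dict,
    monitor: dict,
    named_flows: typing.Dict[str, typing.Callable],
    named_metrics: typing.Dict[str, typing.Callable],
    named_monitors: typing.Dict[str, typing.Callable],
    dataset_name: str = "dummy",
) -> dict:
    # `extras` is built once, up front, so monitors can receive it even when
    # no processing flows are configured for this dataset.
    extras = {"batch_idx": batch_idx}
    if proc := get_dict(process, f"test.datasets.{dataset_name}"):
        for step in get_list(proc, "flows"):
            batch = named_flows[step](batch)
    # Monitoring is gated independently of processing.
    if mon := get_dict(monitor, f"test.datasets.{dataset_name}"):
        for metric in get_list(mon, "metrics"):
            named_metrics[metric](batch)
        for tensor_monitor in get_list(mon, "monitors"):
            named_monitors[tensor_monitor](batch, extras)
    return batch


# Example invocation with toy configs and a single "double" flow.
out = test_like_step(
    batch={"x": 1},
    batch_idx=0,
    process={"test": {"datasets": {"dummy": {"flows": ["double"]}}}},
    monitor={"test": {"datasets": {"dummy": {"metrics": [], "monitors": []}}}},
    named_flows={"double": lambda b: {**b, "x": b["x"] * 2}},
    named_metrics={},
    named_monitors={},
)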