
Commit

Merge branch 'main' of https://github.com/moverseai/moai
tzole1155 committed Sep 6, 2024
2 parents 57bd6ed + 7d687cc commit 4517691
Showing 1 changed file with 21 additions and 32 deletions.
53 changes: 21 additions & 32 deletions moai/core/model.py
@@ -229,18 +229,18 @@ def predict_step(
         batch_idx: int,
         dataset_idx: int = 0,
     ) -> typing.Dict[str, typing.Union[torch.Tensor, typing.Dict[str, torch.Tensor]]]:
-        log.info(f"Predicting batch {batch_idx} ...")
+        log.debug(f"Predicting batch {batch_idx} ...")
         batch = benedict.benedict(batch, keyattr_enabled=False)
+        extras = {
+            "stage": "predict",
+            "lightning_step": self.global_step,
+            "batch_idx": batch_idx,
+            "optimization_step": 0,  # TODO: add this for fitting case
+        }
         if proc := get_dict(self.process, C._PREDICT_):
             with torch.no_grad():  # TODO: probably this is not needed
                 for step in get_list(proc, C._FLOWS_):
                     batch = self.named_flows[step](batch)
-        extras = {
-            "stage": "predict",
-            "lightning_step": self.global_step,
-            "batch_idx": batch_idx,
-            "optimization_step": 0,  # TODO: add this for fitting case
-        }
         if monitor := get_dict(self.monitor, f"{C._PREDICT_}.{C._BATCH_}"):
             for metric in get_list(monitor, C._METRICS_):
                 self.named_metrics[metric](batch)
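
The predict-step hunk above moves the construction of `extras` ahead of the flow execution and lowers the per-batch log call to debug level; both stages are gated on walrus-bound config lookups, so a missing `predict` entry simply skips the corresponding block. The following is a minimal, runnable sketch of that gating pattern, with simplified stand-ins for moai's `get_dict`/`get_list` helpers (assumed behaviour for illustration only, not the library's actual implementation):

import typing

def get_dict(cfg: dict, key: str) -> typing.Optional[dict]:
    # Simplified stand-in: resolve a dotted key, returning None when absent.
    node: typing.Any = cfg
    for part in key.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node if isinstance(node, dict) else None

def get_list(cfg: dict, key: str) -> list:
    # Simplified stand-in: return the list stored under `key`, or an empty list.
    value = cfg.get(key)
    return list(value) if value else []

process = {"predict": {"flows": ["preprocess", "infer"]}}

# `if proc := get_dict(...)` binds the subtree and only enters the block
# when a predict process is actually configured.
if proc := get_dict(process, "predict"):
    for step in get_list(proc, "flows"):
        print(f"running flow: {step}")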
@@ -427,39 +427,28 @@ def test_step(
         batch = benedict.benedict(batch, keyattr_enabled=False)
         batch[C._MOAI_METRICS_] = {}
         datasets = list(self.data.test.iterator.datasets.keys())
-        # if dataset is zipped we should follow a differet approach
-        dataset_name = (
+        dataset_name = (  # if dataset is zipped we should follow a differet approach
             datasets[dataloader_idx]
             if "Zipped" not in self.data.test.iterator._target_
             else "zipped"
         )
-        monitor = (
-            toolz.get_in([C._TEST_, C._DATASETS_, dataset_name], self.monitor) or []
-        )
-        if (
-            proc := get_dict(self.process, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}")
-        ) and (
-            monitor := get_dict(
-                self.monitor, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"
-            )
-        ):
-            extras = {
-                "lightning_step": self.trainer.test_loop.batch_progress.current.completed,  # NOTE: self.global_step does not increment correctly
-                "epoch": self.current_epoch,
-                "batch_idx": batch_idx,
-            }
+        extras = {
+            "lightning_step": self.trainer.test_loop.batch_progress.current.completed,  # NOTE: self.global_step does not increment correctly
+            "epoch": self.current_epoch,
+            "batch_idx": batch_idx,
+        }
+        if proc := get_dict(self.process, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"):
             with torch.no_grad():  # TODO: probably this is not needed
                 # for iter in range(iters): #NOTE: is this necessary?
                 for step in get_list(proc, C._FLOWS_):
                     batch = self.named_flows[step](batch)
-            for metric in get_list(monitor, C._METRICS_):  # Metrics monitoring
-                self.named_metrics[metric](batch)
-            # Tensor monitoring for visualization
-            # tensor_monitors = toolz.get(C._MONITORS_, monitor, None) or []
-            # for tensor_monitor in tensor_monitors:
-            #     self.named_monitors[tensor_monitor](batch)
-            for tensor_monitor in get_list(monitor, C._MONITORS_):
-                self.named_monitors[tensor_monitor](batch, extras)
+        if monitor := get_dict(
+            self.monitor, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"
+        ):
+            for metric in get_list(monitor, C._METRICS_):  # Metrics monitoring
+                self.named_metrics[metric](batch)
+            for tensor_monitor in get_list(monitor, C._MONITORS_):
+                self.named_monitors[tensor_monitor](batch, extras)
 
     @torch.no_grad
     def validation_step(
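
Functionally, the test_step rewrite is more than a reshuffle: in the removed code, metrics and tensor monitors only ran when both a process entry and a monitor entry existed for the dataset (the combined `and` condition), whereas the added code builds `extras` unconditionally and gates flows and monitoring independently, so metrics still fire for datasets that define monitors but no extra flows. A minimal sketch of that decoupling, using `toolz.get_in` over an assumed, simplified config layout (the dataset and metric names are hypothetical, not moai's actual schema):

import toolz

config_process = {"test": {"datasets": {}}}  # no flows configured for this dataset
config_monitor = {"test": {"datasets": {"mocap": {"metrics": ["mpjpe"]}}}}
dataset_name = "mocap"
keys = ["test", "datasets", dataset_name]

# Old behaviour (sketch): one combined condition, so a dataset without
# configured flows silently skipped its metrics as well.
if (proc := toolz.get_in(keys, config_process)) and (
    monitor := toolz.get_in(keys, config_monitor)
):
    print("old: run flows, then metrics")  # not reached here: proc is None

# New behaviour (sketch): flows and monitoring are gated independently,
# so the metrics branch still runs for this dataset.
if proc := toolz.get_in(keys, config_process):
    print("new: run flows")
if monitor := toolz.get_in(keys, config_monitor):
    print("new: compute metrics", monitor["metrics"])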
