Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/moverseai/moai
Browse files Browse the repository at this point in the history
  • Loading branch information
moversekostas committed Jul 31, 2024
2 parents 52f56ff + b77409b commit ebb68ae
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 17 deletions.
1 change: 1 addition & 0 deletions hydra_plugins/moai_dsl_plugin/moai_dsl_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
| "view" "(" name "," SIGNED_INT ("," SIGNED_INT)* ")" -> reshape
| "transpose" "(" name "," SIGNED_INT ("," SIGNED_INT)* ")" -> transpose
| "flatten" "(" name "," SIGNED_INT ["," SIGNED_INT] ")" -> flatten
| "repeat_interleave" "(" name "," SIGNED_INT "," SIGNED_INT ")" -> repeat
| "zeros" "(" name ")" -> zeros_like
| "ones" "(" name ")" -> ones_like
| "rand" "(" name ")" -> rand_like
Expand Down
5 changes: 3 additions & 2 deletions moai/conf/model/monads/human/hand/mano.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package model.monads._name_
# @package model.monads.mano

_target_: moai.monads.human.hand.MANO
model_path: ???
Expand All @@ -11,4 +11,5 @@ use_betas: true
use_pca: true
pca_components: 12
num_betas: 10
flat_hand_mean: false
flat_hand_mean: false
batch_size: 1
2 changes: 1 addition & 1 deletion moai/conf/model/monads/math/rad2deg.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# @package model.monads._name_
# @package model.monads.rad2deg

_target_: moai.monads.math.Rad2Deg
5 changes: 5 additions & 0 deletions moai/core/execution/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,11 @@ def flatten(self, key, *dims):
# self.results.append(f'result{self.index}')
# self.index += 1

def repeat(self, key, *dims):
key = self.extract(key)
dims = list(map(int, dims))
self._transform_operation("repeat_interleave", key, dims)

def unsqueeze(self, key, *dims):
if not isinstance(key, str) or isinstance(key, Token): # NOTE: is lark.Tree
key = self.extract(key)
Expand Down
22 changes: 16 additions & 6 deletions moai/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ def closure(tensors, index, steps, stage, optimizer, objective):
):
frequency = toolz.get(C._FREQUENCY_, iter_monitor_stage, 1)
should_monitor = iter % frequency == 0
if (
iter_tensor_metrics := iter_monitor_stage.get(C._METRICS_, None)
) and should_monitor:
for metric in (
toolz.get(C._METRICS_, iter_monitor_stage, None) or []
):
self.named_metrics[metric](batch)
if (
iter_tensor_monitor := iter_monitor_stage.get(
C._MONITORS_, None
Expand All @@ -369,10 +376,6 @@ def closure(tensors, index, steps, stage, optimizer, objective):
toolz.get(C._FLOWS_, iter_monitor_stage, None) or []
):
self.named_flows[step](batch)
for metric in (
toolz.get(C._METRICS_, iter_monitor_stage, None) or []
):
self.named_metrics[metric](batch)
extras = { # TODO: step => 'lightning_step'
"lightning_step": self.global_step,
"epoch": self.current_epoch,
Expand Down Expand Up @@ -462,7 +465,8 @@ def test_step(
def validation_step(
self,
batch: typing.Dict[str, torch.Tensor],
batch_nb: int,
# batch_nb: int,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
batch = benedict.benedict(batch, keyattr_enabled=False)
Expand All @@ -481,12 +485,18 @@ def validation_step(
self.monitor, f"{C._VAL_}.{C._DATASETS_}.{dataset_name}"
)
):
extras = {
"lightning_step": self.trainer.test_loop.batch_progress.current.completed, # NOTE: self.global_step does not increment correctly
"epoch": self.current_epoch,
"batch_idx": batch_idx,
}
with torch.no_grad():
for step in get_list(proc, C._FLOWS_):
batch = self.named_flows[step](batch)
for metric in get_list(monitor, C._METRICS_):
self.named_metrics[metric](batch)
# TODO add monitors/visualization
for tensor_monitor in get_list(monitor, C._MONITORS_):
self.named_monitors[tensor_monitor](batch, extras)
return batch

def configure_optimizers(
Expand Down
4 changes: 3 additions & 1 deletion moai/data/datasets/human/amass.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def __init__(
downsample_factor: int = 1,
reconstruct: bool = False,
model_type: str = "smplx",
file_type: str = None,
) -> None:
super().__init__()
assert_path(log, "AMASS data root path", data_root)
assert_path(log, "SMPLX data path", smplx_root)
file_type = "**_poses.npz" if model_type == "smpl" else "**_stageii.npz"
if file_type is None:
file_type = "**_poses.npz" if model_type == "smpl" else "**_stageii.npz"
gendered_shape_fn = "**shape.npz" if model_type == "smpl" else "**_stagei.npz"
is_all_parts = isinstance(parts, str) and ("all" == parts or "**" == parts)
parts = os.listdir(data_root) if is_all_parts else parts
Expand Down
3 changes: 2 additions & 1 deletion moai/monads/human/hand/mano.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
use_pca: bool = True,
flat_hand_mean: bool = False,
with_tips: bool = False,
batch_size: int = 1,
):
self.with_tips = with_tips
super(MANO, self).__init__(
Expand All @@ -47,7 +48,7 @@ def __init__(
create_transl=use_translation,
use_pca=use_pca,
dtype=torch.float32,
batch_size=1,
batch_size=batch_size,
num_pca_comps=pca_components,
num_betas=num_betas,
flat_hand_mean=flat_hand_mean,
Expand Down
2 changes: 1 addition & 1 deletion moai/serve/streaming_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
# get the dataloader for returned dict
dataloader = td["dataloader"]
# iterate over the dataloader
for batch_idx, batch in enumerate(dataloader):
for batch_idx, batch in enumerate(torch.utils.data.DataLoader(dataloader)):
self.model.optimization_step = 0
# batch should be send to correct device
batch = toolz.valmap(self.__to_device__, batch)
Expand Down
3 changes: 0 additions & 3 deletions moai/validation/metrics/generation/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def forward(
self,
gt: torch.Tensor,
pred: torch.Tensor,
# weights: torch.Tensor=None,
) -> torch.Tensor:
return {"gt": gt, "pred": pred}

Expand All @@ -26,6 +25,4 @@ def compute(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
c1, c2 = np.cov(gt[:num_samples]), np.cov(pred[:num_samples])
diff_sq = np.linalg.norm(mu1 - mu2, ord=2, axis=-1, keepdims=False)
fid = diff_sq + np.trace(c1 + c2 - 2 * np.sqrt(c1 * c2))
# if weights is not None:
# fid = fid * weights
return fid
13 changes: 11 additions & 2 deletions moai/visualization/rerun/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def multiframe_multiview_posed_image(
optimization_step: typing.Optional[int] = None,
lightning_step: typing.Optional[int] = None,
iteration: typing.Optional[int] = None,
jpeg_quality: int = 40,
) -> None:
if optimization_step is not None:
rr.set_time_sequence("optimization_step", optimization_step)
Expand Down Expand Up @@ -58,9 +59,12 @@ def multiframe_multiview_posed_image(
),
)
# log image
image = np.ascontiguousarray(
images[fr][i].transpose(-2, -1, 0) * 255
).astype(np.uint8)
rr.log(
path + f"/frame_{fr}/cam_{i}",
rr.Image(images[fr][i].transpose(-2, -1, 0)),
rr.Image(image).compress(jpeg_quality=jpeg_quality),
)


Expand All @@ -72,6 +76,7 @@ def multicam_posed_image(
optimization_step: typing.Optional[int] = None,
lightning_step: typing.Optional[int] = None,
iteration: typing.Optional[int] = None,
jpeg_quality: int = 40,
) -> None:
if optimization_step is not None:
rr.set_time_sequence("optimization_step", optimization_step)
Expand All @@ -88,6 +93,7 @@ def multicam_posed_image(
rr.Transform3D(
translation=poses[i][:3, 3],
mat3x3=poses[i][:3, :3],
from_parent=True,
),
)
rr.log(
Expand All @@ -97,9 +103,12 @@ def multicam_posed_image(
),
)
# log image
image = np.ascontiguousarray(images[i].transpose(-2, -1, 0) * 255).astype(
np.uint8
)
rr.log(
path + f"/cam_{i}",
rr.Image(images[i].transpose(-2, -1, 0)),
rr.Image(image).compress(jpeg_quality=jpeg_quality),
)


Expand Down
17 changes: 17 additions & 0 deletions tests/dsl/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,23 @@ def test_flatten(self, parser, highdim_tensors):
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 300.0

def test_repeat_interleave(self, parser, highdim_tensors):
expression = "repeat_interleave(fourdim, 2, 1)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert torch.equal(x, highdim_tensors["fourdim"].repeat_interleave(2, 1))
expression = "repeat_interleave(fourdim, 2, 0)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert torch.equal(x, highdim_tensors["fourdim"].repeat_interleave(2, 0))
expression = "repeat_interleave(single, 2, 1)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 3600 * 2 # 3600 is the sum of single and 2 is the repeat
expression = "repeat_interleave(fourdim, 2, 0)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert (x[5:] != highdim_tensors["fourdim"]).sum() == 0
expression = "repeat_interleave(fourdim, 2, 0) + ones(10, 3, 2, 6)"
x = self._parse_and_run(parser, expression, highdim_tensors)
assert x.sum() == 120 * 2 + 10 * 3 * 2 * 6

def test_trig(self, parser, trig_tensors):
expression = "sin(pi2)"
x = self._parse_and_run(parser, expression, trig_tensors)
Expand Down

0 comments on commit ebb68ae

Please sign in to comment.