Skip to content

Commit

Permalink
Update forward_kinematics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nmvrs committed Jun 20, 2024
1 parent 9ddcffe commit 89dac3c
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions moai/monads/human/pose/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,43 @@ def forward(
self, # TODO: add parents tensor input?
rotation: torch.Tensor, # [B, (T), J, 3, 3]
position: torch.Tensor, # [B, (T), 3]
offset: typing.Optional[torch.Tensor] = None, # [B, (T), J, 3]
offsets: typing.Optional[torch.Tensor] = None, # [B, (T), J, 3]
parents: typing.Optional[torch.Tensor] = None, # [B, J]
) -> typing.Dict[str, torch.Tensor]: # { [B, (T), J, 3], [B, (T), J, 3, 3] }
joints = torch.empty(rotation.shape[:-1], device=rotation.device)
joints[..., 0, :] = position.clone() # first joint according to global position
offset = (
offset[:, np.newaxis, ..., np.newaxis]
if offset is not None
offsets = (
offsets[:, np.newaxis, ..., np.newaxis]
if offsets is not None
else self.offsets[:, np.newaxis, ..., np.newaxis]
) # NOTE: careful, col vs row major order
# offset = offset[np.newaxis, :, np.newaxis, :] #NOTE: careful, col vs row major order
global_rotation = rotation.clone()
# global_rotation = torch.empty(rotation.shape, device=rotation.device)
# global_rotation[..., 0, :3, :3] = rotation[..., 0, :3, :3].clone()
transforms = torch.empty(*rotation.shape[:-2], 4, 4, device=rotation.device)
transforms[..., :3, :3] = rotation.clone()
transforms[..., :3, 3] = offsets[..., 0].clone()
transforms[..., 0, :3, 3] = position.clone()
transforms[..., 3, 3] = 1.0
# NOTE: currently the op does not support per batch item parents
parent_indices = (
parents[0].detach().cpu()
if parents is not None
else (self.parents[0].detach().cpu())
)
if (
parent_indices.shape[-1] == offset.shape[-3]
parent_indices.shape[-1] == offsets.shape[-3]
): # NOTE: to support using the same parents everywhere
parent_indices = parent_indices[1:]
composed = [transforms[..., 0, :, :]]
for current_idx, parent_idx in enumerate(
parent_indices, start=1
): # NOTE: assumes parents exclude root
joints[..., current_idx, :] = torch.matmul(
global_rotation[..., parent_idx, :, :], offset[..., current_idx, :, :]
).squeeze(-1)
global_rotation[..., current_idx, :, :] = torch.matmul(
global_rotation[..., parent_idx, :, :].clone(),
rotation[..., current_idx, :, :].clone(),
composed.append(
torch.matmul(
composed[parent_idx],
transforms[..., current_idx, :, :],
)
)
joints[..., current_idx, :] += joints[..., parent_idx, :]
composed = torch.stack(composed, dim=-3)
joints = composed[..., :3, 3]

return {
"positions": joints,
"rotations": global_rotation,
"bone_transforms": transforms,
}

0 comments on commit 89dac3c

Please sign in to comment.