Skip to content

Commit

Permalink
Add the option to visualize the trajectory of the predictions of the …
Browse files Browse the repository at this point in the history
…forecasting model
  • Loading branch information
kaseris committed Jan 5, 2024
1 parent 8142108 commit aa3f88c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
21 changes: 20 additions & 1 deletion src/skelcast/primitives/visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

from enum import Enum
from typing import Union
from typing import Union, Optional

import numpy as np
import torch
Expand All @@ -26,6 +26,7 @@ class Colors(Enum):


def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor],
trajectory: Optional[Union[np.ndarray, torch.Tensor]] = None,
framerate: int = 30,
skeleton_type: str = 'kinect'):
assert isinstance(skeleton, (np.ndarray, torch.Tensor)), f'Expected a numpy array or a PyTorch tensor, got {type(skeleton)} instead.'
Expand All @@ -48,6 +49,21 @@ def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor],
vis = o3d.visualization.Visualizer()
vis.create_window()

if trajectory is not None:
trajectory_line_set = o3d.geometry.LineSet()
print(f'trajectory shape: {trajectory.shape}')
for timestep in range(trajectory.shape[0] - 1):
if trajectory is not None:
for joint in range(n_joints):
# Create line segment for each joint connecting its position at this timestep to the next
start_point = trajectory[timestep, joint]
end_point = trajectory[timestep + 1, joint]
trajectory_line_set.points.append(start_point)
trajectory_line_set.points.append(end_point)
index = len(trajectory_line_set.points) - 2
trajectory_line_set.lines.append([index, index + 1])
trajectory_line_set.colors.append([0, 1, 0]) # Set color for trajectory lines, e.g., green

for timestep in range(seq_len):
# Update point cloud for the current timestep
point_cloud.points = o3d.utility.Vector3dVector(skeleton[timestep])
Expand All @@ -58,6 +74,9 @@ def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor],
line_set.points = o3d.utility.Vector3dVector(skeleton[timestep])
line_set.colors = o3d.utility.Vector3dVector([Colors.BLUE.value for _ in connections]) # Blue color for connections

if trajectory is not None:
vis.add_geometry(trajectory_line_set)

if timestep == 0:
vis.add_geometry(point_cloud)
vis.add_geometry(line_set)
Expand Down
33 changes: 12 additions & 21 deletions tools/visualize_skel_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import torch.nn as nn

from skelcast.data.dataset import NTURGBDDataset
from skelcast.data.transforms import MinMaxScaleTransform
from skelcast.primitives.visualize import visualize_skeleton
from skelcast.models.rnn.pvred import PositionalVelocityRecurrentEncoderDecoder
from skelcast.models.transformers.sttf import SpatioTemporalTransformer

argparser = argparse.ArgumentParser(description='Visualize skeleton movement.')
argparser.add_argument('--dataset', type=str, required=True, help='Path to the dataset.')
Expand All @@ -20,33 +21,23 @@
log_format = '[%(asctime)s] %(levelname)s: %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

tf = MinMaxScaleTransform(feature_scale=[0.0, 1.0])
dataset = NTURGBDDataset(args.dataset, missing_files_dir='data/missing/', label_file='data/labels.txt',
cache_file=args.cache_file,
max_number_of_bodies=1)
model = PositionalVelocityRecurrentEncoderDecoder(input_dim=75,
enc_hidden_dim=64,
dec_hidden_dim=64,
enc_type='lstm',
dec_type='lstm',
include_velocity=False,
pos_enc=None,
batch_first=True,
use_padded_len_mask=False,
observe_until=20,
use_std_mask=False,
loss_fn=nn.MSELoss())
max_number_of_bodies=1, transforms=tf)
model = SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=3, n_heads=8, d_head=16, mlp_dim=512, loss_fn=nn.SmoothL1Loss(), dropout=0.5)
# TODO: Remove the hard coding of the checkpoint path
checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/heather-head/checkpoint_epoch_99_2023-12-13_115017.pt')
checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)
model = model.to('cpu')
skeleton, label = dataset[args.sample]
seq_len, n_bodies, n_joints, n_dims = skeleton.shape
input_to_model = skeleton.unsqueeze(0)
preds, _ = model(input_to_model.to(torch.float32), y=None, masks=None)
# input_to_model = skeleton.unsqueeze(0)
# preds, _ = model(input_to_model.to(torch.float32), y=None, masks=None)
preds = model.predict(skeleton.to(torch.float32), n_steps=30, observe_from_to=[1, 11])
logging.info(f'preds shape: {preds.shape}')
visualize_skeleton(skeleton.squeeze(1))
preds = preds.view(preds.shape[1], 1, 25, 3)
visualize_skeleton(skeleton.squeeze(1), trajectory=preds.squeeze(0), framerate=5)
# preds = preds.view(preds.shape[1], 1, 25, 3)
# TODO: Visualize the prediction superimposed on the skeleton
visualize_skeleton(preds.detach().squeeze(1), framerate=5)
# visualize_skeleton(preds.detach().squeeze(0))

0 comments on commit aa3f88c

Please sign in to comment.