diff --git a/src/skelcast/primitives/visualize.py b/src/skelcast/primitives/visualize.py index 5abe4da..73656dd 100644 --- a/src/skelcast/primitives/visualize.py +++ b/src/skelcast/primitives/visualize.py @@ -1,7 +1,7 @@ import time from enum import Enum -from typing import Union +from typing import Union, Optional import numpy as np import torch @@ -26,6 +26,7 @@ class Colors(Enum): def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor], + trajectory: Optional[Union[np.ndarray, torch.Tensor]] = None, framerate: int = 30, skeleton_type: str = 'kinect'): assert isinstance(skeleton, (np.ndarray, torch.Tensor)), f'Expected a numpy array or a PyTorch tensor, got {type(skeleton)} instead.' @@ -48,6 +49,21 @@ def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor], vis = o3d.visualization.Visualizer() vis.create_window() + if trajectory is not None: + trajectory_line_set = o3d.geometry.LineSet() + print(f'trajectory shape: {trajectory.shape}') + for timestep in range(trajectory.shape[0] - 1): + if trajectory is not None: + for joint in range(n_joints): + # Create line segment for each joint connecting its position at this timestep to the next + start_point = trajectory[timestep, joint] + end_point = trajectory[timestep + 1, joint] + trajectory_line_set.points.append(start_point) + trajectory_line_set.points.append(end_point) + index = len(trajectory_line_set.points) - 2 + trajectory_line_set.lines.append([index, index + 1]) + trajectory_line_set.colors.append([0, 1, 0]) # Set color for trajectory lines, e.g., green + for timestep in range(seq_len): # Update point cloud for the current timestep point_cloud.points = o3d.utility.Vector3dVector(skeleton[timestep]) @@ -58,6 +74,9 @@ def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor], line_set.points = o3d.utility.Vector3dVector(skeleton[timestep]) line_set.colors = o3d.utility.Vector3dVector([Colors.BLUE.value for _ in connections]) # Blue color for connections + if trajectory is not None: + vis.add_geometry(trajectory_line_set) + if timestep == 0: vis.add_geometry(point_cloud) vis.add_geometry(line_set) diff --git a/tools/visualize_skel_movement.py b/tools/visualize_skel_movement.py index 2ec9101..8c7709a 100644 --- a/tools/visualize_skel_movement.py +++ b/tools/visualize_skel_movement.py @@ -5,8 +5,9 @@ import torch.nn as nn from skelcast.data.dataset import NTURGBDDataset +from skelcast.data.transforms import MinMaxScaleTransform from skelcast.primitives.visualize import visualize_skeleton -from skelcast.models.rnn.pvred import PositionalVelocityRecurrentEncoderDecoder +from skelcast.models.transformers.sttf import SpatioTemporalTransformer argparser = argparse.ArgumentParser(description='Visualize skeleton movement.') argparser.add_argument('--dataset', type=str, required=True, help='Path to the dataset.') @@ -20,33 +21,23 @@ log_format = '[%(asctime)s] %(levelname)s: %(message)s' date_format = '%Y-%m-%d %H:%M:%S' logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format) - + tf = MinMaxScaleTransform(feature_scale=[0.0, 1.0]) dataset = NTURGBDDataset(args.dataset, missing_files_dir='data/missing/', label_file='data/labels.txt', cache_file=args.cache_file, - max_number_of_bodies=1) - model = PositionalVelocityRecurrentEncoderDecoder(input_dim=75, - enc_hidden_dim=64, - dec_hidden_dim=64, - enc_type='lstm', - dec_type='lstm', - include_velocity=False, - pos_enc=None, - batch_first=True, - use_padded_len_mask=False, - observe_until=20, - use_std_mask=False, - loss_fn=nn.MSELoss()) + max_number_of_bodies=1, transforms=tf) + model = SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=3, n_heads=8, d_head=16, mlp_dim=512, loss_fn=nn.SmoothL1Loss(), dropout=0.5) # TODO: Remove the hard coding of the checkpoint path - checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/heather-head/checkpoint_epoch_99_2023-12-13_115017.pt') + checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') model_state_dict = checkpoint['model_state_dict'] model.load_state_dict(model_state_dict) model = model.to('cpu') skeleton, label = dataset[args.sample] seq_len, n_bodies, n_joints, n_dims = skeleton.shape - input_to_model = skeleton.unsqueeze(0) - preds, _ = model(input_to_model.to(torch.float32), y=None, masks=None) + # input_to_model = skeleton.unsqueeze(0) + # preds, _ = model(input_to_model.to(torch.float32), y=None, masks=None) + preds = model.predict(skeleton.to(torch.float32), n_steps=30, observe_from_to=[1, 11]) logging.info(f'preds shape: {preds.shape}') - visualize_skeleton(skeleton.squeeze(1)) - preds = preds.view(preds.shape[1], 1, 25, 3) + visualize_skeleton(skeleton.squeeze(1), trajectory=preds.squeeze(0), framerate=5) + # preds = preds.view(preds.shape[1], 1, 25, 3) # TODO: Visualize the prediction superimposed on the skeleton - visualize_skeleton(preds.detach().squeeze(1), framerate=5) \ No newline at end of file + # visualize_skeleton(preds.detach().squeeze(0)) \ No newline at end of file