diff --git a/docs/guides/cli.md b/docs/guides/cli.md
index 0c08e9b17..35ea52171 100644
--- a/docs/guides/cli.md
+++ b/docs/guides/cli.md
@@ -118,158 +118,166 @@ optional arguments:
If you specify how many identities there should be in a frame (i.e., the number of animals) with the {code}`--tracking.clean_instance_count` argument, then we will use a heuristic method to connect "breaks" in the track identities where we lose one identity and spawn another. This can be used as part of the inference pipeline (if models are specified), as part of the tracking-only pipeline (if the predictions file is specified and no models are specified), or by itself on predictions with pre-tracked identities (if you specify {code}`--tracking.tracker none`). See {ref}`proofreading` for more details on tracking.
```none
-usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames]
- [--only-suggested-frames] [-o OUTPUT] [--no-empty-frames]
- [--verbosity {none,rich,json}]
- [--video.dataset VIDEO.DATASET]
- [--video.input_format VIDEO.INPUT_FORMAT]
- [--video.index VIDEO.INDEX]
- [--cpu | --first-gpu | --last-gpu | --gpu GPU]
- [--peak_threshold PEAK_THRESHOLD] [--batch_size BATCH_SIZE]
- [--open-in-gui] [--tracking.tracker TRACKING.TRACKER]
- [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
- [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET]
- [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
+usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [--only-suggested-frames] [-o OUTPUT] [--no-empty-frames]
+ [--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
+ [--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
+ [--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
+ [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
+ [--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
+ [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
- [--tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT]
- [--tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD]
- [--tracking.similarity TRACKING.SIMILARITY]
- [--tracking.match TRACKING.MATCH]
- [--tracking.track_window TRACKING.TRACK_WINDOW]
- [--tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES]
- [--tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS]
- [--tracking.min_match_points TRACKING.MIN_MATCH_POINTS]
- [--tracking.img_scale TRACKING.IMG_SCALE]
- [--tracking.of_window_size TRACKING.OF_WINDOW_SIZE]
- [--tracking.of_max_levels TRACKING.OF_MAX_LEVELS]
- [--tracking.kf_node_indices TRACKING.KF_NODE_INDICES]
+ [--tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT] [--tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD]
+ [--tracking.similarity TRACKING.SIMILARITY] [--tracking.match TRACKING.MATCH] [--tracking.robust TRACKING.ROBUST]
+ [--tracking.track_window TRACKING.TRACK_WINDOW] [--tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS]
+ [--tracking.min_match_points TRACKING.MIN_MATCH_POINTS] [--tracking.img_scale TRACKING.IMG_SCALE]
+ [--tracking.of_window_size TRACKING.OF_WINDOW_SIZE] [--tracking.of_max_levels TRACKING.OF_MAX_LEVELS]
+ [--tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES] [--tracking.kf_node_indices TRACKING.KF_NODE_INDICES]
[--tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT]
[data_path]
positional arguments:
- data_path Path to data to predict on. This can be a labels
- (.slp) file or any supported video format.
+ data_path Path to data to predict on. This can be a labels (.slp) file or any supported video format.
optional arguments:
-h, --help show this help message and exit
-m MODELS, --model MODELS
- Path to trained model directory (with
- training_config.json). Multiple models can be
- specified, each preceded by --model.
- --frames FRAMES List of frames to predict when running on a video. Can
- be specified as a comma separated list (e.g. 1,2,3) or
- a range separated by hyphen (e.g., 1-3, for 1,2,3). If
- not provided, defaults to predicting on the entire
- video.
+ Path to trained model directory (with training_config.json). Multiple models can be specified, each preceded by --model.
+ --frames FRAMES List of frames to predict when running on a video. Can be specified as a comma separated list (e.g. 1,2,3) or a range
+ separated by hyphen (e.g., 1-3, for 1,2,3). If not provided, defaults to predicting on the entire video.
--only-labeled-frames
- Only run inference on user labeled frames when running
- on labels dataset. This is useful for generating
- predictions to compare against ground truth.
+ Only run inference on user labeled frames when running on labels dataset. This is useful for generating predictions to compare
+ against ground truth.
--only-suggested-frames
- Only run inference on unlabeled suggested frames when
- running on labels dataset. This is useful for
- generating predictions for initialization during
- labeling.
+ Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for
+ initialization during labeling.
-o OUTPUT, --output OUTPUT
- The output filename to use for the predicted data. If
- not provided, defaults to
- '[data_path].predictions.slp' if generating predictions or
- '[data_path].[tracker].[similarity method].[matching method].slp'
- if retracking predictions.
- --no-empty-frames Clear any empty frames that did not have any detected
- instances before saving to output.
- -n, --max_instances MAX_INSTANCES
- Limit maximum number of instances in multi-instance models.
- Not available for ID models. Defaults to None.
+ The output filename to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'.
+ --no-empty-frames Clear any empty frames that did not have any detected instances before saving to output.
--verbosity {none,rich,json}
- Verbosity of inference progress reporting. 'none' does
- not output anything during inference, 'rich' displays
- an updating progress bar, and 'json' outputs the
- progress as a JSON encoded response to the console.
+ Verbosity of inference progress reporting. 'none' does not output anything during inference, 'rich' displays an updating
+ progress bar, and 'json' outputs the progress as a JSON encoded response to the console.
--video.dataset VIDEO.DATASET
The dataset for HDF5 videos.
--video.input_format VIDEO.INPUT_FORMAT
The input_format for HDF5 videos.
--video.index VIDEO.INDEX
- The index of the video to run inference on. Only used if
- data_path points to a labels file.
- --cpu Run inference only on CPU. If not specified, will use
- available GPU.
+ Integer index of video in .slp file to predict on. To be used with an .slp path as an alternative to specifying the video
+ path.
+ --cpu Run inference only on CPU. If not specified, will use available GPU.
--first-gpu Run inference on the first GPU, if available.
--last-gpu Run inference on the last GPU, if available.
- --gpu GPU Run training on the i-th GPU on the system. If 'auto', run on
- the GPU with the highest percentage of available memory.
+ --gpu GPU Run training on the i-th GPU on the system. If 'auto', run on the GPU with the highest percentage of available memory.
--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO
- The maximum expected length of a connected pair of points as a
- fraction of the image size. Candidate connections longer than
- this length will be penalized during matching. Only applies to
- bottom-up (PAF) models.
+ The maximum expected length of a connected pair of points as a fraction of the image size. Candidate connections longer than
+ this length will be penalized during matching. Only applies to bottom-up (PAF) models.
--dist_penalty_weight DIST_PENALTY_WEIGHT
- A coefficient to scale weight of the distance penalty. Set to
- values greater than 1.0 to enforce the distance penalty more
+ A coefficient to scale weight of the distance penalty. Set to values greater than 1.0 to enforce the distance penalty more
strictly. Only applies to bottom-up (PAF) models.
- --peak_threshold PEAK_THRESHOLD
- Minimum confidence map value to consider a peak as
- valid.
--batch_size BATCH_SIZE
- Number of frames to predict at a time. Larger values
- result in faster inference speeds, but require more
- memory.
- --open-in-gui Open the resulting predictions in the GUI when
- finished.
+ Number of frames to predict at a time. Larger values result in faster inference speeds, but require more memory.
+ --open-in-gui Open the resulting predictions in the GUI when finished.
+ --peak_threshold PEAK_THRESHOLD
+ Minimum confidence map value to consider a peak as valid.
+ -n MAX_INSTANCES, --max_instances MAX_INSTANCES
+ Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
- Options: simple, flow, None (default: None)
+ Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
+ --tracking.max_tracking TRACKING.MAX_TRACKING
+ If true then the tracker will cap the max number of tracks. (default: False)
+ --tracking.max_tracks TRACKING.MAX_TRACKS
+ Maximum number of tracks to be tracked by the tracker. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
- Target number of instances to track per frame.
- (default: 0)
+ Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
- If non-zero and target_instance_count is also non-
- zero, then cull instances over target count per frame
- *before* tracking. (default: 0)
+ If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking.
+ (default: 0)
--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD
- If non-zero and pre_cull_to_target also set, then use
- IOU threshold to remove overlapping instances over
- count *before* tracking. (default: 0)
+ If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before*
+ tracking. (default: 0)
--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS
- If non-zero and target_instance_count is also non-
- zero, then connect track breaks when exactly one track
- is lost and exactly one track is spawned in frame.
- (default: 0)
+ If non-zero and target_instance_count is also non-zero, then connect track breaks when exactly one track is lost and exactly
+ one track is spawned in frame. (default: 0)
--tracking.clean_instance_count TRACKING.CLEAN_INSTANCE_COUNT
- Target number of instances to clean *after* tracking.
- (default: 0)
+ Target number of instances to clean *after* tracking. (default: 0)
--tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD
- IOU to use when culling instances *after* tracking.
- (default: 0)
+ IOU to use when culling instances *after* tracking. (default: 0)
--tracking.similarity TRACKING.SIMILARITY
Options: instance, centroid, iou (default: instance)
--tracking.match TRACKING.MATCH
Options: hungarian, greedy (default: greedy)
+ --tracking.robust TRACKING.ROBUST
+ Robust quantile of similarity score for instance matching. If equal to 1, keep the max similarity score (non-robust).
+ (default: 1)
--tracking.track_window TRACKING.TRACK_WINDOW
How many frames back to look for matches (default: 5)
- --tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES
- For optical-flow: Save the shifted instances between
- elapsed frames for optimal comparison (default: 0)
--tracking.min_new_track_points TRACKING.MIN_NEW_TRACK_POINTS
- Minimum number of instance points for spawning new
- track (default: 0)
+ Minimum number of instance points for spawning new track (default: 0)
--tracking.min_match_points TRACKING.MIN_MATCH_POINTS
Minimum points for match candidates (default: 0)
--tracking.img_scale TRACKING.IMG_SCALE
For optical-flow: Image scale (default: 1.0)
--tracking.of_window_size TRACKING.OF_WINDOW_SIZE
- For optical-flow: Optical flow window size to consider
- at each pyramid (default: 21)
+ For optical-flow: Optical flow window size to consider at each pyramid (default: 21)
--tracking.of_max_levels TRACKING.OF_MAX_LEVELS
- For optical-flow: Number of pyramid scale levels to
- consider (default: 3)
+ For optical-flow: Number of pyramid scale levels to consider (default: 3)
+ --tracking.save_shifted_instances TRACKING.SAVE_SHIFTED_INSTANCES
+ If non-zero and tracking.tracker is set to flow, save the shifted instances between elapsed frames (default: 0)
--tracking.kf_node_indices TRACKING.KF_NODE_INDICES
- For Kalman filter: Indices of nodes to track.
- (default: )
+ For Kalman filter: Indices of nodes to track. (default: )
--tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT
- For Kalman filter: Number of frames to track with
- other tracker. 0 means no Kalman filters will be used.
- (default: 0)
+ For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0)
+```
+
+#### Examples:
+
+**1. Simple inference without tracking:**
+
+```none
+sleap-track -m "models/my_model" -o "output_predictions.slp" "input_video.mp4"
+```
+
+**2. Inference with multi-model pipelines (e.g., top-down):**
+
+```none
+sleap-track -m "models/centroid" -m "models/centered_instance" -o "output_predictions.slp" "input_video.mp4"
+```
+
+**3. Inference on suggested frames of a labeling project:**
+
+```none
+sleap-track -m "models/my_model" --only-suggested-frames -o "labels_with_predictions.slp" "labels.v005.slp"
+```
+
+The resulting `labels_with_predictions.slp` can then merged into the base labels project from the SLEAP GUI via **File** --> **Merge into project...**.
+
+**4. Inference with simple tracking:**
+
+```none
+sleap-track -m "models/my_model" --tracking.tracker simple -o "output_predictions.slp" "input_video.mp4"
+```
+
+**5. Inference with max tracks limit:**
+
+```none
+sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
+```
+
+**6. Re-tracking without pose inference:**
+
+```none
+sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
+```
+
+**7. Select GPU for pose inference:**
+
+```none
+sleap-track --gpu 1 ...
+```
+
+**8. Select subset of frames to predict on:**
+
+```none
+sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4"
```
## Dataset files
diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml
index 77722f0d4..cbcea2be5 100644
--- a/sleap/config/pipeline_form.yaml
+++ b/sleap/config/pipeline_form.yaml
@@ -376,28 +376,39 @@ inference:
none:
flow:
- - type: text
- text: 'Pre-tracker data cleaning:'
- - name: tracking.target_instance_count
- label: Target Number of Instances Per Frame
- type: optional_int
- none_label: No target
- default_disabled: true
- range: 1,100
- default: 1
- - name: tracking.pre_cull_to_target
- label: Cull to Target Instance Count
- type: bool
- default: false
- - name: tracking.pre_cull_iou_threshold
- label: Cull using IoU Threshold
- type: double
- default: 0.8
+ # - type: text
+ # text: 'Pre-tracker data cleaning:'
+ # - name: tracking.target_instance_count
+ # label: Target Number of Instances Per Frame
+ # type: optional_int
+ # none_label: No target
+ # default_disabled: true
+ # range: 1,100
+ # default: 1
+ # - name: tracking.pre_cull_to_target
+ # label: Cull to Target Instance Count
+ # type: bool
+ # default: false
+ # - name: tracking.pre_cull_iou_threshold
+ # label: Cull using IoU Threshold
+ # type: double
+ # default: 0.8
- type: text
text: 'Tracking with optical flow:
This tracker "shifts" instances from previous frames using optical flow
before matching instances in each frame to the shifted instances from
prior frames.'
+ # - name: tracking.max_tracking
+ # label: Limit max number of tracks
+ # type: bool
+ default: false
+ - name: tracking.max_tracks
+ label: Max number of tracks
+ type: optional_int
+ none_label: No limit
+ default_disabled: true
+ range: 1,100
+ default: 1
- name: tracking.similarity
label: Similarity Method
type: list
@@ -422,10 +433,10 @@ inference:
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- - name: tracking.save_shifted_instances
- label: Save shifted instances
- type: bool
- default: false
+ # - name: tracking.save_shifted_instances
+ # label: Save shifted instances
+ # type: bool
+ # default: false
- type: text
text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial
@@ -449,27 +460,38 @@ inference:
default: false
simple:
+ # - type: text
+ # text: 'Pre-tracker data cleaning:'
+ # - name: tracking.target_instance_count
+ # label: Target Number of Instances Per Frame
+ # type: optional_int
+ # none_label: No target
+ # default_disabled: true
+ # range: 1,100
+ # default: 1
+ # - name: tracking.pre_cull_to_target
+ # label: Cull to Target Instance Count
+ # type: bool
+ # default: false
+ # - name: tracking.pre_cull_iou_threshold
+ # label: Cull using IoU Threshold
+ # type: double
+ # default: 0.8
- type: text
- text: 'Pre-tracker data cleaning:'
- - name: tracking.target_instance_count
- label: Target Number of Instances Per Frame
+ text: 'Tracking:
+ This tracker assigns track identities by matching instances from prior
+ frames to instances on subsequent frames.'
+ # - name: tracking.max_tracking
+ # label: Limit max number of tracks
+ # type: bool
+ # default: false
+ - name: tracking.max_tracks
+ label: Max number of tracks
type: optional_int
- none_label: No target
+ none_label: No limit
default_disabled: true
range: 1,100
default: 1
- - name: tracking.pre_cull_to_target
- label: Cull to Target Instance Count
- type: bool
- default: false
- - name: tracking.pre_cull_iou_threshold
- label: Cull using IoU Threshold
- type: double
- default: 0.8
- - type: text
- text: 'Tracking:
- This tracker assigns track identities by matching instances from prior
- frames to instances on subsequent frames.'
- name: tracking.similarity
label: Similarity Method
type: list
diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py
index 3909f1019..ca60c4127 100644
--- a/sleap/gui/learning/runners.py
+++ b/sleap/gui/learning/runners.py
@@ -224,6 +224,7 @@ def make_predict_cli_call(
optional_items_as_nones = (
"tracking.target_instance_count",
+ "tracking.max_tracks",
"tracking.kf_init_frame_count",
"tracking.robust",
"max_instances",
@@ -233,6 +234,16 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]
+ # Setting max_tracks to True means we want to use the max_tracking mode.
+ if "tracking.max_tracks" in self.inference_params:
+ self.inference_params["tracking.max_tracking"] = True
+
+ # Hacky: Update the tracker name to include "maxtracks" suffix.
+ if self.inference_params["tracking.tracker"] in ("simple", "flow"):
+ self.inference_params["tracking.tracker"] = (
+ self.inference_params["tracking.tracker"] + "maxtracks"
+ )
+
# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
if "tracking.kf_init_frame_count" not in self.inference_params:
@@ -241,6 +252,7 @@ def make_predict_cli_call(
bool_items_as_ints = (
"tracking.pre_cull_to_target",
+ "tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
)
diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index 222a80bda..6d7d24f8c 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -68,7 +68,7 @@
)
from sleap.nn.utils import reset_input_layer
from sleap.io.dataset import Labels
-from sleap.util import frame_list
+from sleap.util import frame_list, make_scoped_dictionary
from sleap.instance import PredictedInstance, LabeledFrame
from tensorflow.python.framework.convert_to_constants import (
@@ -4773,8 +4773,7 @@ def load_model(
be performed.
tracker_window: Number of frames of history to use when tracking. No effect when
`tracker` is `None`.
- tracker_max_instances: If not `None`, discard instances beyond this count when
- tracking. No effect when `tracker` is `None`.
+ tracker_max_instances: If not `None`, create at most this many tracks.
disable_gpu_preallocation: If `True` (the default), initialize the GPU and
disable preallocation of memory. This is necessary to prevent freezing on
some systems with low GPU memory and has negligible impact on performance.
@@ -4863,11 +4862,18 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
+ use_max_tracker = tracker_max_instances is not None
+ if use_max_tracker and not tracker.endswith("maxtracks"):
+ # Append maxtracks to the tracker name to use the right tracker variants.
+ tracker += "maxtracks"
+
predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
- clean_instance_count=tracker_max_instances,
+ max_tracking=use_max_tracker,
+ max_tracks=tracker_max_instances,
+ # clean_instance_count=tracker_max_instances,
)
# Remove temp dirs.
@@ -5335,7 +5341,7 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]:
Returns:
An instance of `Tracker` or `None` if tracking method was not specified.
"""
- policy_args = sleap.util.make_scoped_dictionary(vars(args), exclude_nones=True)
+ policy_args = make_scoped_dictionary(vars(args), exclude_nones=True)
if "tracking" in policy_args:
tracker = Tracker.make_tracker_by_name(**policy_args["tracking"])
return tracker
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index b861c359f..9865b7db5 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -88,6 +88,13 @@ class MatchedFrameInstances:
img_t: Optional[np.ndarray] = None
+@attr.s(auto_attribs=True, slots=True)
+class MatchedFrameInstance:
+ t: int
+ instance_t: InstanceType
+ img_t: Optional[np.ndarray] = None
+
+
@attr.s(auto_attribs=True, slots=True)
class MatchedShiftedFrameInstances:
ref_t: int
@@ -132,6 +139,66 @@ class FlowCandidateMaker:
def uses_image(self):
return True
+ def get_shifted_instances_from_earlier_time(
+ self, ref_t: int, ref_img: np.ndarray, ref_instances: List[InstanceType], t: int
+ ) -> (np.ndarray, List[InstanceType]):
+ """Generate shifted instances and corresponding image from earlier time.
+
+ Args:
+ ref_instances: Reference instances in the previous frame.
+ ref_img: Previous frame image as a numpy array.
+ ref_t: Previous frame time instance.
+ t: Current time instance.
+ """
+ for ti in reversed(range(ref_t, t)):
+ if (ref_t, ti) in self.shifted_instances:
+ ref_shifted_instances = self.shifted_instances[(ref_t, ti)]
+ # Use shifted instance as a reference
+ if len(ref_shifted_instances.instances_t) > 0:
+ ref_img = ref_shifted_instances.img_t
+ ref_instances = ref_shifted_instances.instances_t
+ break
+ return [ref_img, ref_instances]
+
+ def get_shifted_instances(
+ self,
+ ref_instances: List[InstanceType],
+ ref_img: np.ndarray,
+ ref_t: int,
+ img: np.ndarray,
+ t: int,
+ ) -> List[ShiftedInstance]:
+ """Returns a list of shifted instances and save shifted instances if needed.
+
+ Args:
+ ref_instances: Reference instances in the previous frame.
+ ref_img: Previous frame image as a numpy array.
+ ref_t: Previous frame time instance.
+ img: Current frame image as a numpy array.
+ t: Current time instance.
+ """
+ # Flow shift reference instances to current frame.
+ shifted_instances = self.flow_shift_instances(
+ ref_instances,
+ ref_img,
+ img,
+ min_shifted_points=self.min_points,
+ scale=self.img_scale,
+ window_size=self.of_window_size,
+ max_levels=self.of_max_levels,
+ )
+
+ # Save shifted instances.
+ if self.save_shifted_instances:
+ self.shifted_instances[(ref_t, t)] = MatchedShiftedFrameInstances(
+ ref_t,
+ t,
+ shifted_instances,
+ img,
+ )
+
+ return shifted_instances
+
def get_candidates(
self,
track_matching_queue: Deque[MatchedFrameInstances],
@@ -152,39 +219,15 @@ def get_candidates(
# Check if shifted instance was computed at earlier time
if self.save_shifted_instances:
- for ti in reversed(range(ref_t, t)):
- if (ref_t, ti) in self.shifted_instances:
- ref_shifted_instances = self.shifted_instances[(ref_t, ti)]
- # Use shifted instance as a reference
- if len(ref_shifted_instances.instances_t) > 0:
- ref_img = ref_shifted_instances.img_t
- ref_instances = ref_shifted_instances.instances_t
- break
+ ref_img, ref_instances = self.get_shifted_instances_from_earlier_time(
+ ref_t, ref_img, ref_instances, t
+ )
if len(ref_instances) > 0:
- # Flow shift reference instances to current frame.
- shifted_instances = self.flow_shift_instances(
- ref_instances,
- ref_img,
- img,
- min_shifted_points=self.min_points,
- scale=self.img_scale,
- window_size=self.of_window_size,
- max_levels=self.of_max_levels,
+ candidate_instances.extend(
+ self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t)
)
- # Add to candidate pool.
- candidate_instances.extend(shifted_instances)
-
- # Save shifted instances.
- if self.save_shifted_instances:
- self.shifted_instances[(ref_t, t)] = MatchedShiftedFrameInstances(
- ref_t,
- t,
- shifted_instances,
- img,
- )
-
return candidate_instances
def prune_shifted_instances(self, t: int):
@@ -311,6 +354,86 @@ def flow_shift_instances(
return shifted_instances
+@attr.s(auto_attribs=True)
+class FlowMaxTracksCandidateMaker(FlowCandidateMaker):
+ """Class for producing optical flow shift matching candidates with maximum tracks.
+
+ Attributes:
+ max_tracks: The maximum number of tracks to avoid redundant tracks.
+
+ """
+
+ max_tracks: int = None
+
+ @staticmethod
+ def get_ref_instances(
+ ref_t: int,
+ ref_img: np.ndarray,
+ track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
+ ) -> List[InstanceType]:
+ """Generates a list of instances based on the reference time and image.
+
+ Args:
+ ref_t: Previous frame time instance.
+ ref_img: Previous frame image as a numpy array.
+ track_matching_queue_dict: A dictionary of mapping between the tracks
+ and the corresponding instances associated with the track.
+ """
+ instances = []
+ for track, matched_items in track_matching_queue_dict.items():
+ instances += [
+ item.instance_t
+ for item in matched_items
+ if item.t == ref_t and np.all(item.img_t == ref_img)
+ ]
+ return instances
+
+ def get_candidates(
+ self,
+ track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
+ t: int,
+ img: np.ndarray,
+ *args,
+ **kwargs,
+ ) -> List[ShiftedInstance]:
+ candidate_instances = []
+
+ # Prune old shifted instances to save time and memory
+ self.prune_shifted_instances(t)
+ # Storing the tracks from the dictionary for counting purpose.
+ tracks = []
+
+ for track, matched_items in track_matching_queue_dict.items():
+ if len(tracks) <= self.max_tracks:
+ tracks.append(track)
+ for matched_item in matched_items:
+ ref_t, ref_img = (
+ matched_item.t,
+ matched_item.img_t,
+ )
+ ref_instances = self.get_ref_instances(
+ ref_t, ref_img, track_matching_queue_dict
+ )
+
+ # Check if shifted instance was computed at earlier time
+ if self.save_shifted_instances:
+ (
+ ref_img,
+ ref_instances,
+ ) = self.get_shifted_instances_from_earlier_time(
+ ref_t, ref_img, ref_instances, t
+ )
+
+ if len(ref_instances) > 0:
+ candidate_instances.extend(
+ self.get_shifted_instances(
+ ref_instances, ref_img, ref_t, img, t
+ )
+ )
+
+ return candidate_instances
+
+
@attr.s(auto_attribs=True)
class SimpleCandidateMaker:
"""Class for producing list of matching candidates from prior frames."""
@@ -334,9 +457,35 @@ def get_candidates(
return candidate_instances
+@attr.s(auto_attribs=True)
+class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker):
+ """Class to generate instances with maximum number of tracks from prior frames."""
+
+ max_tracks: int = None
+
+ def get_candidates(
+ self,
+ track_matching_queue_dict: Dict,
+ *args,
+ **kwargs,
+ ) -> List[InstanceType]:
+ # Create set of matchable candidate instances from each track.
+ candidate_instances = []
+ tracks = []
+ for track, matched_instances in track_matching_queue_dict.items():
+ if len(tracks) <= self.max_tracks:
+ tracks.append(track)
+ for ref_instance in matched_instances:
+ if ref_instance.instance_t.n_visible_points >= self.min_points:
+ candidate_instances.append(ref_instance.instance_t)
+ return candidate_instances
+
+
tracker_policies = dict(
simple=SimpleCandidateMaker,
flow=FlowCandidateMaker,
+ simplemaxtracks=SimpleMaxTracksCandidateMaker,
+ flowmaxtracks=FlowMaxTracksCandidateMaker,
)
similarity_policies = dict(
@@ -407,14 +556,17 @@ class Tracker(BaseTracker):
use a robust quantile similarity score for the track. If the value is 1,
use the max similarity (non-robust). For selecting a robust score,
0.95 is a good value.
+ max_tracking: Max tracking is incorporated when this is set to true.
"""
+ max_tracks: int = None
track_window: int = 5
similarity_function: Optional[Callable] = instance_similarity
matching_function: Callable = greedy_matching
candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
+ max_tracking: bool = False # To enable maximum tracking.
- cleaner: Optional[Callable] = None # todo: deprecate
+ cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
@@ -424,6 +576,10 @@ class Tracker(BaseTracker):
track_matching_queue: Deque[MatchedFrameInstances] = attr.ib()
+ # Hold track, instances with instances as a deque with length as track_window.
+ track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib(
+ factory=dict
+ )
spawned_tracks: List[Track] = attr.ib(factory=list)
save_tracked_instances: bool = False
@@ -443,7 +599,11 @@ def _init_matching_queue(self):
return deque(maxlen=self.track_window)
def reset_candidates(self):
- self.track_matching_queue = deque(maxlen=self.track_window)
+ if self.max_tracking:
+ for track in self.track_matching_queue_dict:
+ self.track_matching_queue_dict[track] = deque(maxlen=self.track_window)
+ else:
+ self.track_matching_queue = deque(maxlen=self.track_window)
@property
def unique_tracks_in_queue(self) -> List[Track]:
@@ -454,6 +614,10 @@ def unique_tracks_in_queue(self) -> List[Track]:
for instance in match_item.instances_t:
unique_tracks.add(instance.track)
+ if self.max_tracking:
+ for track in self.track_matching_queue_dict.keys():
+ unique_tracks.add(track)
+
return list(unique_tracks)
@property
@@ -482,13 +646,30 @@ def track(
# Infer timestep if not provided.
if t is None:
- if len(self.track_matching_queue) > 0:
-
- # Default to last timestep + 1 if available.
- t = self.track_matching_queue[-1].t + 1
+ if self.max_tracking:
+ if len(self.track_matching_queue_dict) > 0:
+
+ # Default to last timestep + 1 if available.
+ # Here we find the track that has the most instances.
+ track_with_max_instances = max(
+ self.track_matching_queue_dict,
+ key=lambda track: len(self.track_matching_queue_dict[track]),
+ )
+ t = (
+ self.track_matching_queue_dict[track_with_max_instances][-1].t
+ + 1
+ )
+ else:
+ t = 0
else:
- t = 0
+ if len(self.track_matching_queue) > 0:
+
+ # Default to last timestep + 1 if available.
+ t = self.track_matching_queue[-1].t + 1
+
+ else:
+ t = 0
# Initialize containers for tracked instances at the current timestep.
tracked_instances = []
@@ -503,11 +684,19 @@ def track(
self.pre_cull_function(untracked_instances)
# Build a pool of matchable candidate instances.
- candidate_instances = self.candidate_maker.get_candidates(
- track_matching_queue=self.track_matching_queue,
- t=t,
- img=img,
- )
+ if self.max_tracking:
+ candidate_instances = self.candidate_maker.get_candidates(
+ track_matching_queue_dict=self.track_matching_queue_dict,
+ max_tracks=self.max_tracks,
+ t=t,
+ img=img,
+ )
+ else:
+ candidate_instances = self.candidate_maker.get_candidates(
+ track_matching_queue=self.track_matching_queue,
+ t=t,
+ img=img,
+ )
# Determine matches for untracked instances in current frame.
frame_matches = FrameMatches.from_candidate_instances(
@@ -531,10 +720,26 @@ def track(
self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t)
)
- # Add the tracked instances to the matching buffer.
- self.track_matching_queue.append(
- MatchedFrameInstances(t, tracked_instances, img)
- )
+ # Add the tracked instances to the dictionary of matched instances.
+ if self.max_tracking:
+ for tracked_instance in tracked_instances:
+ if tracked_instance.track in self.track_matching_queue_dict:
+ self.track_matching_queue_dict[tracked_instance.track].append(
+ MatchedFrameInstance(t, tracked_instance, img)
+ )
+ elif len(self.track_matching_queue_dict) < self.max_tracks:
+ self.track_matching_queue_dict[tracked_instance.track] = deque(
+ maxlen=self.track_window
+ )
+ self.track_matching_queue_dict[tracked_instance.track].append(
+ MatchedFrameInstance(t, tracked_instance, img)
+ )
+
+ else:
+ # Add the tracked instances to the matching buffer.
+ self.track_matching_queue.append(
+ MatchedFrameInstances(t, tracked_instances, img)
+ )
# Save tracked instances internally.
if self.save_tracked_instances:
@@ -566,6 +771,13 @@ def spawn_for_untracked_instances(
if inst.n_visible_points < self.min_new_track_points:
continue
+ # Skip if we've reached the maximum number of tracks.
+ if (
+ self.max_tracking
+ and len(self.track_matching_queue_dict) >= self.max_tracks
+ ):
+ break
+
# Spawn new track.
new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}")
self.spawned_tracks.append(new_track)
@@ -598,6 +810,7 @@ def get_name(self):
@classmethod
def make_tracker_by_name(
cls,
+ # Tracker options
tracker: str = "flow",
similarity: str = "instance",
match: str = "greedy",
@@ -622,6 +835,9 @@ def make_tracker_by_name(
# Kalman filter options
kf_init_frame_count: int = 0,
kf_node_indices: Optional[list] = None,
+ # Max tracking options
+ max_tracks: Optional[int] = None,
+ max_tracking: bool = False,
**kwargs,
) -> BaseTracker:
@@ -652,6 +868,9 @@ def make_tracker_by_name(
candidate_maker.save_shifted_instances = save_shifted_instances
candidate_maker.track_window = track_window
+ if tracker == "simplemaxtracks" or tracker == "flowmaxtracks":
+ candidate_maker.max_tracks = max_tracks
+
cleaner = None
if clean_instance_count:
cleaner = TrackCleaner(
@@ -677,6 +896,8 @@ def pre_cull_function(inst_list):
candidate_maker=candidate_maker,
cleaner=cleaner,
pre_cull_function=pre_cull_function,
+ max_tracking=max_tracking,
+ max_tracks=max_tracks,
target_instance_count=target_instance_count,
post_connect_single_breaks=post_connect_single_breaks,
)
@@ -708,6 +929,16 @@ def get_by_name_factory_options(cls):
]
options.append(option)
+ option = dict(name="max_tracking", default=False)
+ option["type"] = bool
+ option["help"] = "If true then the tracker will cap the max number of tracks."
+ options.append(option)
+
+ option = dict(name="max_tracks", default=None)
+ option["type"] = int
+ option["help"] = "Maximum number of tracks to be tracked by the tracker."
+ options.append(option)
+
option = dict(name="target_instance_count", default=0)
option["type"] = int
option["help"] = "Target number of instances to track per frame."
@@ -854,6 +1085,19 @@ class FlowTracker(Tracker):
candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
+attr.s(auto_attribs=True)
+
+
+class FlowMaxTracker(Tracker):
+ """Pre-configured tracker to use optical flow shifted candidates with max tracks."""
+
+ max_tracks: int = attr.ib(kw_only=True)
+ similarity_function: Callable = instance_similarity
+ matching_function: Callable = greedy_matching
+ candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker)
+ max_tracking: bool = True
+
+
@attr.s(auto_attribs=True)
class SimpleTracker(Tracker):
"""A Tracker pre-configured to use simple, non-image-based candidates."""
@@ -863,6 +1107,17 @@ class SimpleTracker(Tracker):
candidate_maker: object = attr.ib(factory=SimpleCandidateMaker)
+@attr.s(auto_attribs=True)
+class SimpleMaxTracker(Tracker):
+ """Pre-configured tracker to use simple, non-image-based candidates with max tracks."""
+
+ max_tracks: int = attr.ib(kw_only=True)
+ similarity_function: Callable = instance_iou
+ matching_function: Callable = hungarian_matching
+ candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker)
+ max_tracking: bool = True
+
+
@attr.s(auto_attribs=True)
class KalmanInitSet:
init_frame_count: int
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index cc65ac3fe..fe848bb1c 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -51,7 +51,13 @@
main as sleap_track,
export_cli as sleap_export,
)
-from sleap.nn.tracking import FlowCandidateMaker, Tracker
+from sleap.nn.tracking import (
+ MatchedFrameInstance,
+ FlowCandidateMaker,
+ FlowMaxTracksCandidateMaker,
+ Tracker,
+)
+from sleap.instance import Track
sleap.nn.system.use_cpu_only()
@@ -1335,7 +1341,13 @@ def test_topdown_id_predictor_save(
@pytest.mark.parametrize(
- "output_path,tracker_method", [("not_default", "flow"), (None, "simple")]
+ "output_path,tracker_method",
+ [
+ ("not_default", "flow"),
+ ("not_default", "flowmaxtracks"),
+ (None, "simple"),
+ (None, "simplemaxtracks"),
+ ],
)
def test_retracking(
centered_pair_predictions: Labels, tmpdir, output_path, tracker_method
@@ -1350,6 +1362,9 @@ def test_retracking(
)
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
+ elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks":
+ cmd += " --tracking.max_tracking 1"
+ cmd += " --tracking.max_tracks 2"
if output_path == "not_default":
output_path = Path(tmpdir, "tracked_slp.slp")
cmd += f" --output {output_path}"
@@ -1477,6 +1492,58 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
assert abs(key[0] - key[1]) <= track_window # References within window
+@pytest.mark.parametrize(
+ "max_tracks, trackername",
+ [
+ (2, "flowmaxtracks"),
+ (2, "simplemaxtracks"),
+ ],
+)
+def test_max_tracks_matching_queue(
+ centered_pair_predictions: Labels, max_tracks, trackername
+):
+ """Test flow max tracks instance generation."""
+ labels: Labels = centered_pair_predictions
+ max_tracking = True
+ track_window = 5
+
+ # Setup flow max tracker
+ tracker: Tracker = Tracker.make_tracker_by_name(
+ tracker=trackername,
+ track_window=track_window,
+ save_shifted_instances=True,
+ max_tracking=max_tracking,
+ max_tracks=max_tracks,
+ )
+
+ tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker)
+
+ # Run tracking
+ frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
+
+ for lf in frames[:20]:
+
+ # Clear the tracks
+ for inst in lf.instances:
+ inst.track = None
+
+ track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
+ tracker.track(**track_args)
+
+ if trackername == "flowmaxtracks":
+ # Check that saved instances are pruned to track window
+ for key in tracker.candidate_maker.shifted_instances.keys():
+ assert lf.frame_idx - key[0] <= track_window # Keys are pruned
+ assert abs(key[0] - key[1]) <= track_window
+
+ # Check if the length of each of the tracks is not more than the track window
+ for track in tracker.track_matching_queue_dict.keys():
+ assert len(tracker.track_matching_queue_dict[track]) <= track_window
+
+ # Check if number of tracks that are generated are not more than the maximum tracks
+ assert len(tracker.track_matching_queue_dict) <= max_tracks
+
+
def test_movenet_inference(movenet_video):
inference_layer = MoveNetInferenceLayer(model_name="lightning")
inference_model = MoveNetInferenceModel(inference_layer)
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index 869ebc85c..f861241ee 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -14,7 +14,9 @@
from sleap.skeleton import Skeleton
-@pytest.mark.parametrize("tracker", ["simple", "flow"])
+@pytest.mark.parametrize(
+ "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
+)
@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
@pytest.mark.parametrize("count", [0, 2])
@@ -166,3 +168,222 @@ def test_frame_match_object():
assert matches[1].track == "track b"
assert matches[1].instance == "instance b"
+
+
+def make_insts(trx):
+ skel = Skeleton.from_names_and_edge_inds(
+ ["A", "B", "C"], edge_inds=[[0, 1], [1, 2]]
+ )
+
+ def make_inst(x, y):
+ pts = np.array([[-0.1, -0.1], [0.0, 0.0], [0.1, 0.1]]) + np.array([[x, y]])
+ return PredictedInstance.from_numpy(pts, [1, 1, 1], 1, skel)
+
+ insts = []
+ for frame in trx:
+ insts_frame = []
+ for x, y in frame:
+ insts_frame.append(make_inst(x, y))
+ insts.append(insts_frame)
+ return insts
+
+
+def test_max_tracking_large_gap_single_track():
+ # Track 2 instances with gap > window size
+ preds = make_insts(
+ [
+ [
+ (0, 0),
+ (0, 1),
+ ],
+ [
+ (0.1, 0),
+ (0.1, 1),
+ ],
+ [
+ (0.2, 0),
+ (0.2, 1),
+ ],
+ [
+ (0.3, 0),
+ ],
+ [
+ (0.4, 0),
+ ],
+ [
+ (0.5, 0),
+ (0.5, 1),
+ ],
+ [
+ (0.6, 0),
+ (0.6, 1),
+ ],
+ ]
+ )
+
+ tracker = Tracker.make_tracker_by_name(
+ tracker="simple",
+ # tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ # max_tracks=2,
+ # max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 3
+
+ tracker = Tracker.make_tracker_by_name(
+ # tracker="simple",
+ tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ max_tracks=2,
+ max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 2
+
+
+def test_max_tracking_small_gap_on_both_tracks():
+ # Test 2 instances with both tracks with gap > window size
+ preds = make_insts(
+ [
+ [
+ (0, 0),
+ (0, 1),
+ ],
+ [
+ (0.1, 0),
+ (0.1, 1),
+ ],
+ [
+ (0.2, 0),
+ (0.2, 1),
+ ],
+ [],
+ [],
+ [
+ (0.5, 0),
+ (0.5, 1),
+ ],
+ [
+ (0.6, 0),
+ (0.6, 1),
+ ],
+ ]
+ )
+
+ tracker = Tracker.make_tracker_by_name(
+ tracker="simple",
+ # tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ # max_tracks=2,
+ # max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 4
+
+ tracker = Tracker.make_tracker_by_name(
+ # tracker="simple",
+ tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ max_tracks=2,
+ max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 2
+
+
+def test_max_tracking_extra_detections():
+ # Test having more than 2 detected instances in a frame
+ preds = make_insts(
+ [
+ [
+ (0, 0),
+ (0, 1),
+ ],
+ [
+ (0.1, 0),
+ (0.1, 1),
+ ],
+ [
+ (0.2, 0),
+ (0.2, 1),
+ ],
+ [
+ (0.3, 0),
+ ],
+ [
+ (0.4, 0),
+ ],
+ [
+ (0.5, 0),
+ (0.5, 1),
+ ],
+ [
+ (0.6, 0),
+ (0.6, 1),
+ (0.6, 0.5),
+ ],
+ ]
+ )
+
+ tracker = Tracker.make_tracker_by_name(
+ tracker="simple",
+ # tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ # max_tracks=2,
+ # max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 4
+
+ tracker = Tracker.make_tracker_by_name(
+ # tracker="simple",
+ tracker="simplemaxtracks",
+ match="hungarian",
+ track_window=2,
+ max_tracks=2,
+ max_tracking=True,
+ )
+
+ tracked = []
+ for insts in preds:
+ tracked_insts = tracker.track(insts)
+ tracked.append(tracked_insts)
+ all_tracks = list(set([inst.track for frame in tracked for inst in frame]))
+
+ assert len(all_tracks) == 2
diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py
index 829b7c3cb..a6592dc4d 100644
--- a/tests/nn/test_tracking_integration.py
+++ b/tests/nn/test_tracking_integration.py
@@ -3,10 +3,42 @@
import os
import time
+import sleap
+from sleap.nn.inference import main as inference_cli
import sleap.nn.tracker.components
from sleap.io.dataset import Labels, LabeledFrame
+def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
+ cli = (
+ "--tracking.tracker simple "
+ "--frames 200-300 "
+ f"-o {tmpdir}/simpletracks.slp "
+ f"{centered_pair_predictions_slp_path}"
+ )
+ inference_cli(cli.split(" "))
+
+ labels = sleap.load_file(f"{tmpdir}/simpletracks.slp")
+ assert len(labels.tracks) == 27
+
+
+def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path):
+ cli = (
+ "--tracking.tracker simplemaxtracks "
+ "--tracking.max_tracking 1 --tracking.max_tracks 2 "
+ "--frames 200-300 "
+ f"-o {tmpdir}/simplemaxtracks.slp "
+ f"{centered_pair_predictions_slp_path}"
+ )
+ inference_cli(cli.split(" "))
+
+ labels = sleap.load_file(f"{tmpdir}/simplemaxtracks.slp")
+ assert len(labels.tracks) == 2
+
+
+# TODO: Refactor the below things into a real test suite.
+
+
def make_ground_truth(frames, tracker, gt_filename):
t0 = time.time()
new_labels = run_tracker(frames, tracker)
@@ -95,6 +127,8 @@ def main(f, dir):
trackers = dict(
simple=sleap.nn.tracker.simple.SimpleTracker,
flow=sleap.nn.tracker.flow.FlowTracker,
+ simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker,
+ flowmaxtracks=sleap.nn.tracker.FlowMaxTracker,
)
matchers = dict(
hungarian=sleap.nn.tracker.components.hungarian_matching,
@@ -110,11 +144,21 @@ def main(f, dir):
0.25,
)
- def make_tracker(tracker_name, matcher_name, sim_name, scale=0):
- tracker = trackers[tracker_name](
- matching_function=matchers[matcher_name],
- similarity_function=similarities[sim_name],
- )
+ def make_tracker(
+ tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0
+ ):
+ if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks":
+ tracker = trackers[tracker_name](
+ matching_function=matchers[matcher_name],
+ similarity_function=similarities[sim_name],
+ max_tracks=max_tracks,
+ max_tracking=max_tracking,
+ )
+ else:
+ tracker = trackers[tracker_name](
+ matching_function=matchers[matcher_name],
+ similarity_function=similarities[sim_name],
+ )
if scale:
tracker.candidate_maker.img_scale = scale
return tracker
@@ -145,6 +189,28 @@ def make_tracker_and_filename(*args, **kwargs):
scale=scale,
)
f(frames, tracker, gt_filename)
+ elif tracker_name == "flowmaxtracks":
+ # If this tracker supports scale, try multiple scales
+ for scale in scales:
+ tracker, gt_filename = make_tracker_and_filename(
+ tracker_name=tracker_name,
+ matcher_name=matcher_name,
+ sim_name=sim_name,
+ max_tracks=2,
+ max_tracking=True,
+ scale=scale,
+ )
+ f(frames, tracker, gt_filename)
+ elif tracker_name == "simplemaxtracks":
+ tracker, gt_filename = make_tracker_and_filename(
+ tracker_name=tracker_name,
+ matcher_name=matcher_name,
+ sim_name=sim_name,
+ max_tracks=2,
+ max_tracking=True,
+ scale=0,
+ )
+ f(frames, tracker, gt_filename)
else:
tracker, gt_filename = make_tracker_and_filename(
tracker_name=tracker_name,