Skip to content

Commit

Permalink
Limit max tracks via track-local queues (#1447)
Browse files Browse the repository at this point in the history
* Initial commit

* format files

* [wip] adding local deque for tracks

* format files

* [wip] adding local deque for tracks

* [wip] Add max tracking for simpletracker

* [wip] Add max tracking for simple tracker

* [wip] add missing argument

* [wip] Add and modify test functions

* [wip] Add and modify test functions

* Bug fix and refactoring code

* [wip] Add max tracking for flow tracker.

* [wip] Including suggested changes

* [wip] refactor code

* Add test function to check max tracks

* Added suggestions and feedback

* Prevent the creation of more than max tracks when we have unmatched detections

* Add tests

* Use maximum tracking by default when loading model via high level API

* Lint

* Fix integration test

* Refactor max tracker tests

* Add integration test for CLI

* typo

* Add max tracks to the tracking GUI

* Update CLI docs and add examples

---------

Co-authored-by: Talmo Pereira <talmo@princeton.edu>
Co-authored-by: Talmo Pereira <talmo@salk.edu>
  • Loading branch information
3 people authored Sep 8, 2023
1 parent 64655d6 commit 93ef288
Show file tree
Hide file tree
Showing 8 changed files with 858 additions and 201 deletions.
220 changes: 114 additions & 106 deletions docs/guides/cli.md

Large diffs are not rendered by default.

96 changes: 59 additions & 37 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -376,28 +376,39 @@ inference:
none:

flow:
- type: text
text: '<b>Pre-tracker data cleaning</b>:'
- name: tracking.target_instance_count
label: Target Number of Instances Per Frame
type: optional_int
none_label: No target
default_disabled: true
range: 1,100
default: 1
- name: tracking.pre_cull_to_target
label: Cull to Target Instance Count
type: bool
default: false
- name: tracking.pre_cull_iou_threshold
label: Cull using IoU Threshold
type: double
default: 0.8
# - type: text
# text: '<b>Pre-tracker data cleaning</b>:'
# - name: tracking.target_instance_count
# label: Target Number of Instances Per Frame
# type: optional_int
# none_label: No target
# default_disabled: true
# range: 1,100
# default: 1
# - name: tracking.pre_cull_to_target
# label: Cull to Target Instance Count
# type: bool
# default: false
# - name: tracking.pre_cull_iou_threshold
# label: Cull using IoU Threshold
# type: double
# default: 0.8
- type: text
text: '<b>Tracking with optical flow</b>:<br />
This tracker "shifts" instances from previous frames using optical flow
before matching instances in each frame to the <i>shifted</i> instances from
prior frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
none_label: No limit
default_disabled: true
range: 1,100
default: 1
- name: tracking.similarity
label: Similarity Method
type: list
Expand All @@ -422,10 +433,10 @@ inference:
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- name: tracking.save_shifted_instances
label: Save shifted instances
type: bool
default: false
# - name: tracking.save_shifted_instances
# label: Save shifted instances
# type: bool
# default: false
- type: text
text: '<b>Kalman filter-based tracking</b>:<br />
Uses the above tracking options to track instances for an initial
Expand All @@ -449,27 +460,38 @@ inference:
default: false

simple:
# - type: text
# text: '<b>Pre-tracker data cleaning</b>:'
# - name: tracking.target_instance_count
# label: Target Number of Instances Per Frame
# type: optional_int
# none_label: No target
# default_disabled: true
# range: 1,100
# default: 1
# - name: tracking.pre_cull_to_target
# label: Cull to Target Instance Count
# type: bool
# default: false
# - name: tracking.pre_cull_iou_threshold
# label: Cull using IoU Threshold
# type: double
# default: 0.8
- type: text
text: '<b>Pre-tracker data cleaning</b>:'
- name: tracking.target_instance_count
label: Target Number of Instances Per Frame
text: '<b>Tracking</b>:<br />
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
# default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
none_label: No target
none_label: No limit
default_disabled: true
range: 1,100
default: 1
- name: tracking.pre_cull_to_target
label: Cull to Target Instance Count
type: bool
default: false
- name: tracking.pre_cull_iou_threshold
label: Cull using IoU Threshold
type: double
default: 0.8
- type: text
text: '<b>Tracking</b>:<br />
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
- name: tracking.similarity
label: Similarity Method
type: list
Expand Down
12 changes: 12 additions & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def make_predict_cli_call(

optional_items_as_nones = (
"tracking.target_instance_count",
"tracking.max_tracks",
"tracking.kf_init_frame_count",
"tracking.robust",
"max_instances",
Expand All @@ -233,6 +234,16 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]

# Setting max_tracks to True means we want to use the max_tracking mode.
if "tracking.max_tracks" in self.inference_params:
self.inference_params["tracking.max_tracking"] = True

# Hacky: Update the tracker name to include "maxtracks" suffix.
if self.inference_params["tracking.tracker"] in ("simple", "flow"):
self.inference_params["tracking.tracker"] = (
self.inference_params["tracking.tracker"] + "maxtracks"
)

# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
if "tracking.kf_init_frame_count" not in self.inference_params:
Expand All @@ -241,6 +252,7 @@ def make_predict_cli_call(

bool_items_as_ints = (
"tracking.pre_cull_to_target",
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
)
Expand Down
16 changes: 11 additions & 5 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)
from sleap.nn.utils import reset_input_layer
from sleap.io.dataset import Labels
from sleap.util import frame_list
from sleap.util import frame_list, make_scoped_dictionary
from sleap.instance import PredictedInstance, LabeledFrame

from tensorflow.python.framework.convert_to_constants import (
Expand Down Expand Up @@ -4773,8 +4773,7 @@ def load_model(
be performed.
tracker_window: Number of frames of history to use when tracking. No effect when
`tracker` is `None`.
tracker_max_instances: If not `None`, discard instances beyond this count when
tracking. No effect when `tracker` is `None`.
tracker_max_instances: If not `None`, create at most this many tracks.
disable_gpu_preallocation: If `True` (the default), initialize the GPU and
disable preallocation of memory. This is necessary to prevent freezing on
some systems with low GPU memory and has negligible impact on performance.
Expand Down Expand Up @@ -4863,11 +4862,18 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
use_max_tracker = tracker_max_instances is not None
if use_max_tracker and not tracker.endswith("maxtracks"):
# Append maxtracks to the tracker name to use the right tracker variants.
tracker += "maxtracks"

predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
clean_instance_count=tracker_max_instances,
max_tracking=use_max_tracker,
max_tracks=tracker_max_instances,
# clean_instance_count=tracker_max_instances,
)

# Remove temp dirs.
Expand Down Expand Up @@ -5335,7 +5341,7 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]:
Returns:
An instance of `Tracker` or `None` if tracking method was not specified.
"""
policy_args = sleap.util.make_scoped_dictionary(vars(args), exclude_nones=True)
policy_args = make_scoped_dictionary(vars(args), exclude_nones=True)
if "tracking" in policy_args:
tracker = Tracker.make_tracker_by_name(**policy_args["tracking"])
return tracker
Expand Down
Loading

0 comments on commit 93ef288

Please sign in to comment.