-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update motile based tracking implementation
- Loading branch information
1 parent
1a9e8af
commit a07bfd7
Showing
4 changed files
with
216 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1 @@ | ||
from .naive_tracking import naive_tracking | ||
from .visualization import visualize_tracks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,67 @@ | ||
import numpy as np | ||
from . import motile_utils as utils | ||
|
||
|
||
# TODO | ||
def track_with_motile( | ||
segmentation, | ||
relabel_segmentation=True, | ||
node_cost_function=None, | ||
edge_cost_function=None, | ||
node_selection_cost=0.95, | ||
**problem_kwargs, | ||
): | ||
pass | ||
"""Yadda yadda | ||
Note: this will relabel the segmentation unless `relabel_segmentation=False` | ||
""" | ||
# relabel sthe segmentation so that the ids are unique across time. | ||
# if `relabel_segmentation is False` the segmentation has to be in the correct format already | ||
if relabel_segmentation: | ||
segmentation = utils.relabel_segmentation_across_time(segmentation) | ||
|
||
# compute the node selection costs. | ||
# if `node_cost_function` is passed it is used to compute the costs. | ||
# otherwise we set a fixed node selection cost. | ||
if node_cost_function is None: | ||
n_nodes = int(segmentation.max()) | ||
node_costs = np.full(n_nodes, node_selection_cost) | ||
else: | ||
node_costs = node_cost_function(segmentation) | ||
|
||
# compute the edges and edge selection cost. | ||
# if `edge_cost_function` is not given we use the default approach. | ||
# (currently based on overlap of adjacent slices) | ||
if edge_cost_function is None: | ||
edge_cost_function = utils.compute_edges_from_overlap | ||
edges_and_costs = edge_cost_function(segmentation) | ||
|
||
# construct the problem | ||
solver, graph = utils.construct_problem(segmentation, node_costs, edges_and_costs, **problem_kwargs) | ||
|
||
# solver the problem | ||
solver.solve() | ||
|
||
# parse solution | ||
lineage_graph, lineages = utils.parse_result(solver, graph) | ||
track_graph, tracks = utils.lineage_graph_to_track_graph(lineage_graph, lineages) | ||
|
||
return segmentation, lineage_graph, lineages, track_graph, tracks | ||
|
||
|
||
def get_representation_for_napari(segmentation, lineage_graph, lineages, tracks, color_by_lineage=True): | ||
|
||
node_ids = np.unique(segmentation)[1:] | ||
node_to_track = utils.get_node_assignment(node_ids, tracks) | ||
node_to_lineage = utils.get_node_assignment(node_ids, lineages) | ||
|
||
# create label layer and track data for visualization in napari | ||
tracking_result = utils.recolor_segmentation( | ||
segmentation, node_to_lineage if color_by_lineage else node_to_track | ||
) | ||
|
||
# create the track data and corresponding parent graph | ||
track_data, parent_graph = utils.create_data_for_track_layer( | ||
segmentation, lineage_graph, node_to_track | ||
) | ||
|
||
return tracking_result, track_data, parent_graph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
from glob import glob | ||
|
||
import imageio.v3 as imageio | ||
import napari | ||
import numpy as np | ||
|
||
from elf.tracking.motile_tracking import track_with_motile, get_representation_for_napari | ||
|
||
|
||
def get_ctc_hela_data(): | ||
# load the data. you can download it from TODO | ||
image_folder = "/home/pape/Work/my_projects/micro-sam/examples/data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01" | ||
images = np.stack([imageio.imread(path) for path in sorted(glob(os.path.join(image_folder, "*.tif")))]) | ||
|
||
seg_folder = "/home/pape/Work/my_projects/micro-sam/examples/finetuning/data/hela-ctc-01-gt.zip.unzip/masks" | ||
segmentation = np.stack([imageio.imread(path) for path in sorted(glob(os.path.join(seg_folder, "*.tif")))]) | ||
assert images.shape == segmentation.shape | ||
|
||
return images, segmentation | ||
|
||
|
||
def default_tracking(): | ||
images, segmentation = get_ctc_hela_data() | ||
|
||
# run the tracking and get visualization data for napari | ||
segmentation, lineage_graph, lineages, track_graph, tracks = track_with_motile(segmentation) | ||
tracking_result, track_data, parent_graph = get_representation_for_napari( | ||
segmentation, lineage_graph, lineages, tracks | ||
) | ||
|
||
# visualize with napari | ||
v = napari.Viewer() | ||
v.add_image(images) | ||
v.add_labels(tracking_result) | ||
v.add_tracks(track_data, name="tracks", graph=parent_graph) | ||
napari.run() | ||
|
||
|
||
def tracking_with_custom_edge_function(): | ||
from functools import partial | ||
from elf.tracking.motile_utils import compute_edges_from_centroid_distance | ||
|
||
images, segmentation = get_ctc_hela_data() | ||
|
||
# run the tracking and get visualization data for napari | ||
edge_cost_function = partial(compute_edges_from_centroid_distance, max_distance=50) | ||
segmentation, lineage_graph, lineages, track_graph, tracks = track_with_motile( | ||
segmentation, edge_cost_function=edge_cost_function | ||
) | ||
tracking_result, track_data, parent_graph = get_representation_for_napari( | ||
segmentation, lineage_graph, lineages, tracks | ||
) | ||
|
||
# visualize with napari | ||
v = napari.Viewer() | ||
v.add_image(images) | ||
v.add_labels(tracking_result) | ||
v.add_tracks(track_data, name="tracks", graph=parent_graph) | ||
napari.run() | ||
|
||
|
||
def main(): | ||
# default_tracking() | ||
tracking_with_custom_edge_function() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |