Skip to content

Commit

Permalink
Update motile based tracking implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 6, 2023
1 parent 1a9e8af commit a07bfd7
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 57 deletions.
1 change: 0 additions & 1 deletion elf/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .naive_tracking import naive_tracking
from .visualization import visualize_tracks
63 changes: 61 additions & 2 deletions elf/tracking/motile_tracking.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,67 @@
import numpy as np
from . import motile_utils as utils


# TODO
def track_with_motile(
segmentation,
relabel_segmentation=True,
node_cost_function=None,
edge_cost_function=None,
node_selection_cost=0.95,
**problem_kwargs,
):
pass
"""Yadda yadda
Note: this will relabel the segmentation unless `relabel_segmentation=False`
"""
# relabel sthe segmentation so that the ids are unique across time.
# if `relabel_segmentation is False` the segmentation has to be in the correct format already
if relabel_segmentation:
segmentation = utils.relabel_segmentation_across_time(segmentation)

# compute the node selection costs.
# if `node_cost_function` is passed it is used to compute the costs.
# otherwise we set a fixed node selection cost.
if node_cost_function is None:
n_nodes = int(segmentation.max())
node_costs = np.full(n_nodes, node_selection_cost)
else:
node_costs = node_cost_function(segmentation)

# compute the edges and edge selection cost.
# if `edge_cost_function` is not given we use the default approach.
# (currently based on overlap of adjacent slices)
if edge_cost_function is None:
edge_cost_function = utils.compute_edges_from_overlap
edges_and_costs = edge_cost_function(segmentation)

# construct the problem
solver, graph = utils.construct_problem(segmentation, node_costs, edges_and_costs, **problem_kwargs)

# solver the problem
solver.solve()

# parse solution
lineage_graph, lineages = utils.parse_result(solver, graph)
track_graph, tracks = utils.lineage_graph_to_track_graph(lineage_graph, lineages)

return segmentation, lineage_graph, lineages, track_graph, tracks


def get_representation_for_napari(segmentation, lineage_graph, lineages, tracks, color_by_lineage=True):

node_ids = np.unique(segmentation)[1:]
node_to_track = utils.get_node_assignment(node_ids, tracks)
node_to_lineage = utils.get_node_assignment(node_ids, lineages)

# create label layer and track data for visualization in napari
tracking_result = utils.recolor_segmentation(
segmentation, node_to_lineage if color_by_lineage else node_to_track
)

# create the track data and corresponding parent graph
track_data, parent_graph = utils.create_data_for_track_layer(
segmentation, lineage_graph, node_to_track
)

return tracking_result, track_data, parent_graph
140 changes: 86 additions & 54 deletions elf/tracking/motile_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Utility functionality for tracking with [motile](ttps://github.com/funkelab/motile).
"""Utility functionality for tracking with [motile](https://github.com/funkelab/motile).
"""
from copy import deepcopy

import motile
import networkx as nx
import nifty.gt as ngt
import nifty.ground_truth as ngt
import numpy as np

from motile import costs, constraints
from nifty.tools import takeDict
from scipy.spatial.distance import cdist
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from tqdm import trange
Expand All @@ -18,8 +20,6 @@
#

def compute_edges_from_overlap(segmentation, verbose=True):
"""
"""
def compute_overlap_between_frames(frame_a, frame_b):
overlap_function = ngt.overlap(frame_a, frame_b)

Expand All @@ -45,20 +45,59 @@ def compute_overlap_between_frames(frame_a, frame_b):
for t in trange(segmentation.shape[0] - 1, disable=not verbose, desc="Compute edges via overlap"):
this_frame = segmentation[t]
next_frame = segmentation[t + 1]
frame_edges = compute_edges_from_overlap(this_frame, next_frame)
frame_edges = compute_overlap_between_frames(this_frame, next_frame)
edges.extend(frame_edges)
return edges


# TODO
def compute_edges_from_centroid_distance(segmentation, max_distance, verbose=True):
pass
def compute_edges_from_centroid_distance(segmentation, max_distance, normalize_distances=True, verbose=True):
nt = segmentation.shape[0]
props = regionprops(segmentation)
centroids_and_labels = [[prop.centroid[0], prop.centroid[1:], prop.label] for prop in props]

centroids, labels = {}, {}
for t, centroid, label in centroids_and_labels:
centroids[t] = centroids.get(t, []) + [centroid]
labels[t] = labels.get(t, []) + [label]
centroids = {t: np.stack(np.array(val)) for t, val in centroids.items()}
labels = {t: np.array(val) for t, val in labels.items()}

def compute_dist_between_frames(t):
centers_a, centers_b = centroids[t], centroids[t + 1]
labels_a, labels_b = labels[t], labels[t + 1]
assert len(centers_a) == len(labels_a)
assert len(centers_b) == len(labels_b)

distances = cdist(centers_a, centers_b)
edge_mask = distances <= max_distance
distance_values = distances[edge_mask]

idx_a, idx_b = np.where(edge_mask)
source_ids, target_ids = labels_a[idx_a], labels_b[idx_b]
assert len(distance_values) == len(source_ids) == len(target_ids)

return source_ids, target_ids, distance_values
# return edges

source_ids, target_ids, distances = [], [], []
for t in trange(nt - 1, disable=not verbose, desc="Compute edges via centroid distance"):
this_src, this_tgt, this_dist = compute_dist_between_frames(t)
source_ids.extend(this_src), target_ids.extend(this_tgt), distances.extend(this_dist)

if normalize_distances:
distances = np.array(distances)
max_dist = distances.max()
distances = 1.0 - distances / max_dist

edges = [
{"source": source_id, "target": target_id, "score": distance}
for source_id, target_id, distance in zip(source_ids, target_ids, distances)
]
return edges


# TODO does this work for 4d data (time + 3d)? if no we need to iterate over the time axis
def compute_node_costs_from_foreground_probabilities(segmentation, probabilities, cost_attribute="mean_intensity"):
"""
"""
props = regionprops(segmentation, probabilities)
costs = [getattr(prop, cost_attribute) for prop in props]
return costs
Expand All @@ -69,8 +108,6 @@ def compute_node_costs_from_foreground_probabilities(segmentation, probabilities


def relabel_segmentation_across_time(segmentation):
"""
"""
offset = 0
relabeled = []
for frame in segmentation:
Expand All @@ -88,8 +125,6 @@ def construct_problem(
max_parents=1,
max_children=2,
):
"""
"""
node_ids, indexes = np.unique(segmentation, return_index=True)
indexes = np.unravel_index(indexes, shape=segmentation.shape)
timeframes = indexes[0]
Expand Down Expand Up @@ -129,7 +164,7 @@ def construct_problem(
solver.add_costs(costs.Appear(constant=1.0))
solver.add_costs(costs.Split(constant=1.0))

return graph, solver
return solver, graph


#
Expand All @@ -138,8 +173,6 @@ def construct_problem(


def parse_result(solver, graph):
"""
"""
lineage_graph = nx.DiGraph()

node_indicators = solver.get_variables(motile.variables.NodeSelected)
Expand All @@ -160,24 +193,7 @@ def parse_result(solver, graph):
return lineage_graph, lineages


def recolor_segmentation(segmentation, lineages):
# generate a dictionary that maps each node id (= segment id) to its lineage
node_to_lineage = {
node_id: lineage_id for lineage_id, nodes in lineages.items() for node_id in nodes
}

# everything that was not selected gets mapped to 0
seg_ids = np.unique(segmentation)
not_selected = list(set(seg_ids) - set(node_to_lineage.keys()))
node_to_lineage.update({not_select: 0 for not_select in not_selected})

# relabel based on the dict
# (this can also be with a numpy function, but I only know my convenience function by heart...)
recolored_segmentation = takeDict(node_to_lineage, segmentation)
return recolored_segmentation, node_to_lineage


def lineage_graph_to_track_graph(lineage_graph, lineages, node_to_lineage):
def lineage_graph_to_track_graph(lineage_graph, lineages):
# create a new graph that only contains the tracks by not connecting nodes with a degree of 2
track_graph = nx.DiGraph()
track_graph.add_nodes_from(lineage_graph.nodes)
Expand All @@ -188,37 +204,53 @@ def lineage_graph_to_track_graph(lineage_graph, lineages, node_to_lineage):
# normal track continuation
if len(out_edges) == 1:
track_graph.add_edge(u, v)
# otherwise track ends at division and we don't contnue
# otherwise track ends at division and we don't continue

# use connected components to find the tracks
tracks = nx.weakly_connected_components(track_graph)
tracks = {track_id: list(track) for track_id, track in enumerate(tracks, 1)}

# find the mapping of nodes to tracks and tracks to lineage
track_to_nodes = {track_id: list(track) for track_id, track in enumerate(tracks, 1)}
node_to_track = {
node_id: track_id for track_id, nodes in track_to_nodes.items() for node_id in nodes
return track_graph, tracks


def get_node_assignment(node_ids, assignments):
# generate a dictionary that maps each node id (= segment id) to its assignment
node_assignment = {
node_id: assignment_id for assignment_id, nodes in assignments.items() for node_id in nodes
}
track_to_lineage = {lineage_id: node_to_track[lineage[0]] for lineage_id, lineage in lineages.items()}

return node_to_track, track_to_lineage
# everything that was not selected gets mapped to 0
not_selected = list(set(node_ids) - set(node_assignment.keys()))
node_assignment.update({not_select: 0 for not_select in not_selected})

return node_assignment


def recolor_segmentation(segmentation, node_to_assignment):
# we need to add a value for mapping 0, otherwise the function fails
node_to_assignment_ = deepcopy(node_to_assignment)
node_to_assignment_[0] = 0
recolored_segmentation = takeDict(node_to_assignment_, segmentation)
return recolored_segmentation

def create_napari_track_layer(segmentation, lineage_graph, lineages, node_to_lineage):
# extract the graph with only tracks (without divisions)
# and the graph of tracks into lineages
node_to_track, track_to_lineage = lineage_graph_to_track_graph(lineage_graph, lineages, node_to_lineage)

# get the regionproperties for centroids
def create_data_for_track_layer(segmentation, lineage_graph, node_to_track):
# compute regionpros and extract centroids
props = regionprops(segmentation)
centroids = {
prop.label: prop.centroid for prop in props
}
centroids = {prop.label: prop.centroid for prop in props}

# create the track data representation for napari
track_data = [
[node_to_track[node_id]] + list(centroid) for node_id, centroid in centroids.items()
if node_id in node_to_track
]

lineage_to_tracks = {}
for track, lineage in track_to_lineage.items():
lineage_to_tracks[lineage] = lineage_to_tracks.get(lineage, []) + [track]
# create the parent graph for tracks
parent_graph = {}
for (u, v), features in lineage_graph.edges.items():
out_edges = lineage_graph.out_edges(u)
if len(out_edges) == 2:
track_u, track_v = node_to_track[u], node_to_track[v]
parent_graph[track_v] = parent_graph.get(track_v, []) + [track_u]

return track_data, lineage_to_tracks
return track_data, parent_graph
69 changes: 69 additions & 0 deletions example/tracking/ctc_hela.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
from glob import glob

import imageio.v3 as imageio
import napari
import numpy as np

from elf.tracking.motile_tracking import track_with_motile, get_representation_for_napari


def get_ctc_hela_data():
# load the data. you can download it from TODO
image_folder = "/home/pape/Work/my_projects/micro-sam/examples/data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01"
images = np.stack([imageio.imread(path) for path in sorted(glob(os.path.join(image_folder, "*.tif")))])

seg_folder = "/home/pape/Work/my_projects/micro-sam/examples/finetuning/data/hela-ctc-01-gt.zip.unzip/masks"
segmentation = np.stack([imageio.imread(path) for path in sorted(glob(os.path.join(seg_folder, "*.tif")))])
assert images.shape == segmentation.shape

return images, segmentation


def default_tracking():
images, segmentation = get_ctc_hela_data()

# run the tracking and get visualization data for napari
segmentation, lineage_graph, lineages, track_graph, tracks = track_with_motile(segmentation)
tracking_result, track_data, parent_graph = get_representation_for_napari(
segmentation, lineage_graph, lineages, tracks
)

# visualize with napari
v = napari.Viewer()
v.add_image(images)
v.add_labels(tracking_result)
v.add_tracks(track_data, name="tracks", graph=parent_graph)
napari.run()


def tracking_with_custom_edge_function():
from functools import partial
from elf.tracking.motile_utils import compute_edges_from_centroid_distance

images, segmentation = get_ctc_hela_data()

# run the tracking and get visualization data for napari
edge_cost_function = partial(compute_edges_from_centroid_distance, max_distance=50)
segmentation, lineage_graph, lineages, track_graph, tracks = track_with_motile(
segmentation, edge_cost_function=edge_cost_function
)
tracking_result, track_data, parent_graph = get_representation_for_napari(
segmentation, lineage_graph, lineages, tracks
)

# visualize with napari
v = napari.Viewer()
v.add_image(images)
v.add_labels(tracking_result)
v.add_tracks(track_data, name="tracks", graph=parent_graph)
napari.run()


def main():
# default_tracking()
tracking_with_custom_edge_function()


if __name__ == "__main__":
main()

0 comments on commit a07bfd7

Please sign in to comment.