import numpy as np
from . import motile_utils as utils


def track_with_motile(
    segmentation,
    relabel_segmentation=True,
    node_cost_function=None,
    edge_cost_function=None,
    node_selection_cost=0.95,
    **problem_kwargs,
):
    """Track the objects of a segmentation over time by solving a motile problem.

    Note: this will relabel the segmentation unless `relabel_segmentation=False`.

    Args:
        segmentation: Label array that is indexed by time along its first axis. The id 0 is background.
        relabel_segmentation: Whether to relabel the segmentation so that the ids are unique across time.
            If False, the segmentation has to be in this format already.
        node_cost_function: Optional callable that takes the segmentation and returns one selection
            cost per object. If it is not given, `node_selection_cost` is used for every object.
        edge_cost_function: Optional callable that takes the segmentation and returns the candidate edges
            as a list of dictionaries with the keys "source", "target" and "score".
            If it is not given, the edges are derived from the overlap of objects in adjacent frames.
        node_selection_cost: The fixed selection cost that is used if `node_cost_function` is not given.
        problem_kwargs: Additional keyword arguments for `motile_utils.construct_problem`.

    Returns:
        A tuple of the (relabeled) segmentation, the lineage graph, the lineages
        (dict: lineage id -> node ids), the track graph and the tracks (dict: track id -> node ids).
    """
    # Relabel the segmentation so that the ids are unique across time.
    if relabel_segmentation:
        segmentation = utils.relabel_segmentation_across_time(segmentation)

    # Compute the node selection costs.
    if node_cost_function is None:
        # We need one cost per object. Count the ids instead of taking the maximal id,
        # so that this also holds for non-consecutive ids (possible for `relabel_segmentation=False`).
        n_nodes = int(np.count_nonzero(np.unique(segmentation)))
        node_costs = np.full(n_nodes, node_selection_cost)
    else:
        node_costs = node_cost_function(segmentation)

    # Compute the edges and edge selection costs.
    # The default approach is currently based on the overlap of adjacent frames.
    if edge_cost_function is None:
        edge_cost_function = utils.compute_edges_from_overlap
    edges_and_costs = edge_cost_function(segmentation)

    # Construct and solve the problem.
    solver, graph = utils.construct_problem(segmentation, node_costs, edges_and_costs, **problem_kwargs)
    solver.solve()

    # Parse the solution into lineages (with divisions) and tracks (split at divisions).
    lineage_graph, lineages = utils.parse_result(solver, graph)
    track_graph, tracks = utils.lineage_graph_to_track_graph(lineage_graph, lineages)

    return segmentation, lineage_graph, lineages, track_graph, tracks


def get_representation_for_napari(segmentation, lineage_graph, lineages, tracks, color_by_lineage=True):
    """Convert a tracking result into the data needed for a napari label and track layer.

    Args:
        segmentation: The segmentation returned by `track_with_motile`.
        lineage_graph: The lineage graph returned by `track_with_motile`.
        lineages: The lineages (dict: lineage id -> node ids) returned by `track_with_motile`.
        tracks: The tracks (dict: track id -> node ids) returned by `track_with_motile`.
        color_by_lineage: Whether to recolor the segmentation by lineage id (True) or by track id (False).

    Returns:
        A tuple of the recolored segmentation, the track data and the parent graph of the tracks.
    """
    # Filter out the background explicitly; slicing off the first unique value
    # would drop a real object if the segmentation does not contain background.
    node_ids = np.unique(segmentation)
    node_ids = node_ids[node_ids != 0]

    node_to_track = utils.get_node_assignment(node_ids, tracks)
    node_to_lineage = utils.get_node_assignment(node_ids, lineages)

    # Create the label layer data.
    tracking_result = utils.recolor_segmentation(
        segmentation, node_to_lineage if color_by_lineage else node_to_track
    )

    # Create the track data and the corresponding parent graph.
    track_data, parent_graph = utils.create_data_for_track_layer(
        segmentation, lineage_graph, node_to_track
    )

    return tracking_result, track_data, parent_graph
from copy import deepcopy

import motile
import networkx as nx
import nifty.ground_truth as ngt
import numpy as np

from motile import costs, constraints
from nifty.tools import takeDict
from scipy.spatial.distance import cdist
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from tqdm import trange


#
# Simple functionality for computing edges and costs
#

def compute_edges_from_overlap(segmentation, verbose=True):
    """Compute candidate edges between objects in adjacent frames based on their overlap.

    Args:
        segmentation: Label array that is indexed by time along its first axis. The id 0 is background.
        verbose: Whether to show a progress bar.

    Returns:
        A list of dictionaries with the keys "source", "target" and "score", where the score
        is the normalized overlap value computed by `nifty.ground_truth.overlap`.
    """
    def compute_overlap_between_frames(frame_a, frame_b):
        overlap_function = ngt.overlap(frame_a, frame_b)

        # Filter the background explicitly, the frame may not contain any background pixel.
        node_ids = np.unique(frame_a)
        node_ids = node_ids[node_ids != 0]

        frame_edges = []
        for node_id in node_ids:
            target_ids, overlap_values = overlap_function.overlapArraysNormalized(node_id)
            # Overlaps with the background (target id 0) are not valid edges.
            frame_edges.extend(
                {"source": node_id, "target": target_id, "score": ovlp}
                for target_id, ovlp in zip(target_ids, overlap_values)
                if target_id != 0
            )
        return frame_edges

    edges = []
    for t in trange(segmentation.shape[0] - 1, disable=not verbose, desc="Compute edges via overlap"):
        edges.extend(compute_overlap_between_frames(segmentation[t], segmentation[t + 1]))
    return edges


def compute_edges_from_centroid_distance(segmentation, max_distance, normalize_distances=True, verbose=True):
    """Compute candidate edges between objects in adjacent frames based on their centroid distance.

    Args:
        segmentation: Label array that is indexed by time along its first axis.
            NOTE(review): the frame of an object is taken from the first centroid coordinate,
            so this presumably requires ids that are unique across time - confirm against callers.
        max_distance: The maximal centroid distance for connecting two objects by an edge.
        normalize_distances: Whether to map the distances to scores via `1 - distance / max(distances)`.
            If False, the raw distances are used as scores.
        verbose: Whether to show a progress bar.

    Returns:
        A list of dictionaries with the keys "source", "target" and "score".
    """
    nt = segmentation.shape[0]

    # Group the spatial centroids and the labels by the time coordinate of the centroid.
    centroids, labels = {}, {}
    for prop in regionprops(segmentation):
        t = prop.centroid[0]
        centroids.setdefault(t, []).append(prop.centroid[1:])
        labels.setdefault(t, []).append(prop.label)
    centroids = {t: np.array(val) for t, val in centroids.items()}
    labels = {t: np.array(val) for t, val in labels.items()}

    def compute_dist_between_frames(t):
        # Frames without objects have no entry and cannot contribute any edges.
        if t not in centroids or t + 1 not in centroids:
            return [], [], []

        centers_a, centers_b = centroids[t], centroids[t + 1]
        labels_a, labels_b = labels[t], labels[t + 1]
        assert len(centers_a) == len(labels_a)
        assert len(centers_b) == len(labels_b)

        distances = cdist(centers_a, centers_b)
        edge_mask = distances <= max_distance
        distance_values = distances[edge_mask]

        idx_a, idx_b = np.where(edge_mask)
        source_ids, target_ids = labels_a[idx_a], labels_b[idx_b]
        assert len(distance_values) == len(source_ids) == len(target_ids)

        return source_ids, target_ids, distance_values

    source_ids, target_ids, distances = [], [], []
    for t in trange(nt - 1, disable=not verbose, desc="Compute edges via centroid distance"):
        this_src, this_tgt, this_dist = compute_dist_between_frames(t)
        source_ids.extend(this_src)
        target_ids.extend(this_tgt)
        distances.extend(this_dist)

    # Skip the normalization if there are no edges, otherwise `max` fails for the empty array.
    if normalize_distances and len(distances) > 0:
        distances = np.array(distances)
        max_dist = distances.max()
        if max_dist > 0:
            distances = 1.0 - distances / max_dist
        else:
            # All centroids coincide: every edge gets the best score instead of NaN from 0 / 0.
            distances = np.ones_like(distances)

    edges = [
        {"source": source_id, "target": target_id, "score": distance}
        for source_id, target_id, distance in zip(source_ids, target_ids, distances)
    ]
    return edges
# TODO does this work for 4d data (time + 3d)? if no we need to iterate over the time axis
def compute_node_costs_from_foreground_probabilities(segmentation, probabilities, cost_attribute="mean_intensity"):
    """Compute one node cost per object from a foreground probability map.

    Args:
        segmentation: The label array.
        probabilities: The intensity image passed to `regionprops`, with the foreground probabilities.
        cost_attribute: The name of the region property that is used as cost.

    Returns:
        A list with the value of `cost_attribute` for each region returned by `regionprops`.
    """
    props = regionprops(segmentation, probabilities)
    # Don't call this 'costs', that would shadow the `motile.costs` import of this module.
    node_costs = [getattr(prop, cost_attribute) for prop in props]
    return node_costs

#
# Utility functions for constructing tracking problems
#


def relabel_segmentation_across_time(segmentation):
    """Relabel a segmentation so that the object ids are consecutive and unique across time.

    Args:
        segmentation: Label array that is indexed by time along its first axis. The id 0 is background.

    Returns:
        The relabeled segmentation, stacked along the first axis.
    """
    offset = 0
    relabeled = []
    for frame in segmentation:
        # Let `relabel_sequential` apply the offset: it leaves 0 untouched and chooses an output
        # dtype that can hold the shifted ids, which adding the offset in-place does not guarantee.
        frame, _, _ = relabel_sequential(frame, offset=offset + 1)
        # A frame without objects has max id 0 and must not reset the offset,
        # otherwise ids from the previous frames would be assigned again.
        offset = max(offset, int(frame.max()))
        relabeled.append(frame)
    return np.stack(relabeled)


def construct_problem(
    segmentation,
    node_costs,
    edges_and_costs,
    max_parents=1,
    max_children=2,
    node_weight=-1.0,
    edge_weight=-1.0,
    appear_cost=1.0,
    split_cost=1.0,
):
    """Construct the motile tracking problem for a segmentation.

    The node and edge costs are reweighted linearly by motile: weight * score + constant,
    where the constant is 0.

    Args:
        segmentation: Label array with ids that are unique across time. The id 0 is background.
        node_costs: The node scores, one per non-zero id, in the order of the sorted ids.
        edges_and_costs: The candidate edges, dictionaries with the keys "source", "target" and "score".
        max_parents: The maximal number of parents per node.
        max_children: The maximal number of children per node (2 allows for divisions).
        node_weight: The weight applied to the node scores.
        edge_weight: The weight applied to the edge scores.
        appear_cost: The constant cost for the appearance of a track.
        split_cost: The constant cost for a division.

    Returns:
        The motile solver and the motile track graph.

    Raises:
        ValueError: If the number of node costs does not match the number of objects.
    """
    # The first occurrence of each id in the flattened array yields its time frame.
    node_ids, indexes = np.unique(segmentation, return_index=True)
    timeframes = np.unravel_index(indexes, shape=segmentation.shape)[0]

    # Get rid of the background. Using a mask also works if there are no ids at all.
    foreground = node_ids != 0
    node_ids, timeframes = node_ids[foreground], timeframes[foreground]

    if len(node_costs) != len(node_ids):
        raise ValueError(
            f"Expected one node cost per object, got {len(node_costs)} costs for {len(node_ids)} objects."
        )

    graph = nx.DiGraph()
    graph.add_nodes_from(
        (node_id, {"id": node_id, "score": score, "t": t})
        for node_id, score, t in zip(node_ids, node_costs, timeframes)
    )
    graph.add_edges_from((edge["source"], edge["target"], edge) for edge in edges_and_costs)

    # Construct the motile graph and solver.
    graph = motile.TrackGraph(graph)
    solver = motile.Solver(graph)

    # Add the selection costs with linear reweighting: a * x + b, where a=weight and b=constant.
    solver.add_costs(costs.NodeSelection(weight=node_weight, attribute="score", constant=0))
    solver.add_costs(costs.EdgeSelection(weight=edge_weight, attribute="score", constant=0))

    # Add the constraints: by default we allow for divisions (max children = 2).
    solver.add_constraints(constraints.MaxParents(max_parents))
    solver.add_constraints(constraints.MaxChildren(max_children))

    # Add the costs for appearances and divisions.
    solver.add_costs(costs.Appear(constant=appear_cost))
    solver.add_costs(costs.Split(constant=split_cost))

    return solver, graph


#
# Solution parsing and visualization
#


def parse_result(solver, graph):
    """Extract the selected nodes and edges from a solved motile problem.

    Args:
        solver: The motile solver, after `solve` was called.
        graph: The motile track graph the solver was constructed with.

    Returns:
        The lineage graph (a `networkx.DiGraph` of the selected nodes and edges) and
        the lineages, a dict that maps each lineage id (starting at 1) to its list of node ids.
    """
    lineage_graph = nx.DiGraph()

    node_indicators = solver.get_variables(motile.variables.NodeSelected)
    edge_indicators = solver.get_variables(motile.variables.EdgeSelected)

    # Only keep the nodes and edges whose indicator variable is set in the solution.
    for node, index in node_indicators.items():
        if solver.solution[index] > 0.5:
            lineage_graph.add_node(node, **graph.nodes[node])

    for edge, index in edge_indicators.items():
        if solver.solution[index] > 0.5:
            lineage_graph.add_edge(*edge, **graph.edges[edge])

    # Use connected components to find the lineages.
    lineages = nx.weakly_connected_components(lineage_graph)
    lineages = {lineage_id: list(lineage) for lineage_id, lineage in enumerate(lineages, 1)}
    return lineage_graph, lineages


def lineage_graph_to_track_graph(lineage_graph, lineages):
    """Split the lineages into tracks by cutting them at divisions.

    Args:
        lineage_graph: The lineage graph returned by `parse_result`.
        lineages: The lineages returned by `parse_result`. This argument is not used.

    Returns:
        The track graph (the lineage graph without the edges that leave a dividing node) and
        the tracks, a dict that maps each track id (starting at 1) to its list of node ids.
    """
    track_graph = nx.DiGraph()
    track_graph.add_nodes_from(lineage_graph.nodes)

    # Only keep normal track continuations, i.e. edges whose source has a single child.
    # Tracks end at divisions, the children start new tracks.
    track_graph.add_edges_from(
        (u, v) for u, v in lineage_graph.edges if lineage_graph.out_degree(u) == 1
    )

    # Use connected components to find the tracks.
    tracks = nx.weakly_connected_components(track_graph)
    tracks = {track_id: list(track) for track_id, track in enumerate(tracks, 1)}

    return track_graph, tracks


def get_node_assignment(node_ids, assignments):
    """Map each node id (= segment id) to the id of the track or lineage it is assigned to.

    Args:
        node_ids: All node ids.
        assignments: Dict that maps an assignment id to the node ids that belong to it.

    Returns:
        Dict that maps each node id to its assignment id. Nodes without an assignment are mapped to 0.
    """
    node_assignment = {
        node_id: assignment_id for assignment_id, nodes in assignments.items() for node_id in nodes
    }

    # Everything that was not selected gets mapped to 0.
    for node_id in node_ids:
        node_assignment.setdefault(node_id, 0)

    return node_assignment


def recolor_segmentation(segmentation, node_to_assignment):
    """Replace each id in the segmentation by the id it is assigned to.

    Args:
        segmentation: The label array.
        node_to_assignment: Dict that maps the node ids to the new ids. It is not modified.

    Returns:
        The recolored segmentation.
    """
    # We need to add a value for mapping 0, otherwise `takeDict` fails.
    # A shallow copy is enough to keep the input dict unchanged, since only the top level is changed.
    mapping = {**node_to_assignment, 0: 0}
    return takeDict(mapping, segmentation)


def create_data_for_track_layer(segmentation, lineage_graph, node_to_track):
    """Create the track data and the parent graph for a napari track layer.

    Args:
        segmentation: Label array with ids that are unique across time.
        lineage_graph: The lineage graph returned by `parse_result`.
        node_to_track: Dict that maps the node ids to track ids.

    Returns:
        The track data, a list with one row `[track_id, *centroid]` per object, and the
        parent graph, a dict that maps the track id of each division child to the track ids of its parents.
    """
    # Compute the regionprops and extract the centroids.
    props = regionprops(segmentation)
    centroids = {prop.label: prop.centroid for prop in props}

    # Create the track data representation for napari.
    # NOTE(review): nodes that map to track id 0 (not selected) are included here - confirm this is intended.
    track_data = [
        [node_to_track[node_id]] + list(centroid) for node_id, centroid in centroids.items()
        if node_id in node_to_track
    ]

    # Create the parent graph for the tracks. Every node with more than one child is a division;
    # this matches `lineage_graph_to_track_graph`, which ends tracks at all nodes without a unique child.
    parent_graph = {}
    for u, v in lineage_graph.edges:
        if lineage_graph.out_degree(u) > 1:
            track_u, track_v = node_to_track[u], node_to_track[v]
            parent_graph.setdefault(track_v, []).append(track_u)

    return track_data, parent_graph
distance_transform_edt(np.logical_not(mask), return_indices=True) - current_ids = np.unique(current_t) - - def _find_distance_match(current_id): - mask = current_t == current_id - masked_distances = distances[mask] - min_dist_point = np.argmin(masked_distances) - min_dist = masked_distances[min_dist_point] - if min_dist > max_assignment_distance: - return None, None - index = indices[:, mask] - index = tuple(ind[min_dist_point] for ind in index) - next_id = next_t[index] - return next_id, min_dist - - with futures.ThreadPoolExecutor(n_threads) as tp: - matched_next = list(tp.map(_find_distance_match, current_ids[1:])) - matched_next = {curr_id: match for curr_id, match in zip(current_ids[1:], matched_next)} - - distance_matches = {} - match_distances = {} - for curr_id, (next_id, dist) in matched_next.items(): - if next_id is None: - continue - if next_id in distance_matches: - prev_dist = match_distances[next_id] - if dist < prev_dist: - distance_matches[next_id] = curr_id - match_distances[next_id] = dist - else: - distance_matches[next_id] = curr_id - match_distances[next_id] = dist - return distance_matches - - -def naive_tracking(time_series, max_assignment_distance, n_threads=-1, verbose=False): - """Naive tracking without divisions. 
- - Arguments: - time_series [np.ndarray] - - max_assignment_distance [float] - - allow_divisions [bool] - - n_threads [int] - - verbose [bool] - - """ - track_ids = {} - nt = len(time_series) - - # initialize the track ids with the first time point - current_t = time_series[0] - current_ids = np.setdiff1d(np.unique(current_t), [0]) - track_ids = {0: {curr_id: track_id for track_id, curr_id in enumerate(current_ids, 1)}} - next_track_id = len(track_ids[0]) + 1 - - if n_threads == -1: - n_threads = mp.cpu_count() - - range_ = tqdm.trange(1, nt, desc="Naive tracking") if verbose else range(1, nt) - for t in range_: - next_t = time_series[t] - next_ids = np.setdiff1d(np.unique(next_t), [0]) - - # compute the area overlaps beween current and next time point - ovlp_comp = ngt.overlap(next_t, current_t) - ovlps = {next_id: ovlp_comp.overlapArrays(next_id, sorted=True) for next_id in next_ids} - - ovlp_ids, ovlp_counts = {}, {} - for next_id, (labels, counts) in ovlps.items(): - if labels[0] == 0: - labels, counts = labels[1:], counts[1:] - if len(labels) == 0: - continue - ovlp_ids[next_id] = labels[0] - ovlp_counts[next_id] = counts[0] - - # assign track ids based on maximum overlap - prev_track_assignments = track_ids[t - 1] - track_matches = {next_id: prev_track_assignments[matched] for next_id, matched in ovlp_ids.items()} - unmatched = list(set(next_ids) - set(track_matches.keys())) - if unmatched and max_assignment_distance > 0: - distance_matches = _compute_distance_matches( - next_t, current_t, unmatched, max_assignment_distance, n_threads, t - ) - distance_matches = { - next_id: prev_track_assignments[matched] for next_id, matched in distance_matches.items() - } - # don't distance match to previous overlap matches - ovlp_tracks = set(track_matches.values()) - distance_matches = {k: v for k, v in distance_matches.items() if v not in ovlp_tracks} - track_matches = {**track_matches, **distance_matches} - - unmatched = list(set(next_ids) - set(track_matches.keys())) 
- new_tracks = {next_id: track_id for track_id, next_id in enumerate(unmatched, next_track_id)} - next_track_id += len(new_tracks) - - track_ids_t = {**track_matches, **new_tracks} - track_ids[t] = track_ids_t - current_t = next_t - current_ids = next_ids - - return track_ids diff --git a/elf/tracking/visualization.py b/elf/tracking/visualization.py deleted file mode 100644 index 2ce4a37..0000000 --- a/elf/tracking/visualization.py +++ /dev/null @@ -1,90 +0,0 @@ -import numpy as np -import vigra - -from nifty.tools import takeDict -from skimage.measure import regionprops - - -def compute_centers(labels, use_eccentricity=False): - # this is more accurate, but extremely expensive - if use_eccentricity: - centers = vigra.filters.eccentricityCenters(labels.astype("uint32")) - # TODO need to process this further - else: - props = regionprops(labels) - centers = {prop.label: prop.centroid for prop in props} - return centers - - -def color_by_tracking(segmentation, track_assignments, size_filter=0): - tracking = np.zeros_like(segmentation) - track_ids = np.unique([val for assignments in track_assignments.values() for val in assignments.values()]) - tracks_to_times = {track_id: [] for track_id in track_ids} - for t in range(tracking.shape[0]): - assignments = track_assignments[t] - assignments[0] = 0 - track_t = takeDict(assignments, segmentation[t]) - if size_filter > 0: - ids, sizes = np.unique(track_t, return_counts=True) - too_small = ids[sizes < size_filter] - track_t[np.isin(track_t, too_small)] = 0 - tracking[t] = track_t - ids_t = np.unique(track_t)[1:] - for track_id in ids_t: - tracks_to_times[track_id].append(t) - assert tracking.shape == segmentation.shape - return tracking, tracks_to_times - - -def visualize_tracks(viewer, segmentation, track_assignments, - edge_width=4, size_filter=0, show_full_tracks=False, - selected_tracks=None): - tracking, tracks_to_times = color_by_tracking(segmentation, track_assignments) - track_ids = np.unique(tracking)[1:] - - 
import os
from glob import glob

import imageio.v3 as imageio
import napari
import numpy as np

from elf.tracking.motile_tracking import track_with_motile, get_representation_for_napari


def _load_tif_stack(folder):
    """Load all tif files in the folder, in sorted order, and stack them along a new first axis."""
    paths = sorted(glob(os.path.join(folder, "*.tif")))
    if not paths:
        # Fail with a clear message instead of the opaque error from stacking an empty list.
        raise FileNotFoundError(f"Could not find any tif files in {folder}")
    return np.stack([imageio.imread(path) for path in paths])


def get_ctc_hela_data(image_folder=None, seg_folder=None):
    """Load the images and the segmentation of the CTC HeLa example data.

    Args:
        image_folder: The folder with the image tif files. By default a hard-coded local path is used.
        seg_folder: The folder with the mask tif files. By default a hard-coded local path is used.

    Returns:
        The images and the segmentation, stacked over the sorted files of the respective folder.

    Raises:
        FileNotFoundError: If one of the folders does not contain tif files.
        ValueError: If the shapes of images and segmentation do not agree.
    """
    # load the data. you can download it from TODO
    if image_folder is None:
        image_folder = "/home/pape/Work/my_projects/micro-sam/examples/data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01"
    if seg_folder is None:
        seg_folder = "/home/pape/Work/my_projects/micro-sam/examples/finetuning/data/hela-ctc-01-gt.zip.unzip/masks"

    images = _load_tif_stack(image_folder)
    segmentation = _load_tif_stack(seg_folder)
    if images.shape != segmentation.shape:
        raise ValueError(f"Shapes of images and segmentation differ: {images.shape} != {segmentation.shape}")

    return images, segmentation


def _track_and_show(images, segmentation, **tracking_kwargs):
    """Run the tracking with motile and show the result in napari."""
    # run the tracking and get visualization data for napari
    segmentation, lineage_graph, lineages, track_graph, tracks = track_with_motile(segmentation, **tracking_kwargs)
    tracking_result, track_data, parent_graph = get_representation_for_napari(
        segmentation, lineage_graph, lineages, tracks
    )

    # visualize with napari
    v = napari.Viewer()
    v.add_image(images)
    v.add_labels(tracking_result)
    v.add_tracks(track_data, name="tracks", graph=parent_graph)
    napari.run()


def default_tracking():
    """Track the example data with the default settings of `track_with_motile`."""
    images, segmentation = get_ctc_hela_data()
    _track_and_show(images, segmentation)


def tracking_with_custom_edge_function():
    """Track the example data with edges derived from the centroid distance (max distance 50)."""
    from functools import partial
    from elf.tracking.motile_utils import compute_edges_from_centroid_distance

    images, segmentation = get_ctc_hela_data()
    edge_cost_function = partial(compute_edges_from_centroid_distance, max_distance=50)
    _track_and_show(images, segmentation, edge_cost_function=edge_cost_function)


def main():
    """Run the example. Call `default_tracking` instead to use the default edge computation."""
    tracking_with_custom_edge_function()


if __name__ == "__main__":
    main()