diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py new file mode 100644 index 0000000..7d64fa0 --- /dev/null +++ b/scripts/get_tracking_results.py @@ -0,0 +1,103 @@ +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +from deepcell_tracking.utils import load_trks + + +ROOT = "/scratch/projects/nim00007/sam/for_tracking" + + +def load_tracking_segmentation(experiment): + result_dir = os.path.join(ROOT, "results") + + if experiment.startswith("vit"): + if experiment == "vit_l": + seg_path = os.path.join(result_dir, "vit_l.tif") + elif experiment == "vit_l_lm": + seg_path = os.path.join(result_dir, "vit_l_lm.tif") + elif experiment == "vit_l_specialist": + seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") + # elif experiment == "trackmate_stardist": + # seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") + + else: + raise ValueError(experiment) + + seg = imageio.imread(seg_path) + return seg + + else: # return the result directory for stardist + return os.path.join(result_dir, "trackmate_stardist", "01_RES") + + +def check_tracking_results(raw, labels, curr_lineages, chosen_frames): + seg_default = load_tracking_segmentation("vit_l") + seg_generalist = load_tracking_segmentation("vit_l_lm") + seg_specialist = load_tracking_segmentation("vit_l_specialist") + + # let's get the tracks only for the objects present per frame + for idx in np.unique(labels)[1:]: + lineage = curr_lineages[idx] + lineage["frames"] = [frame for frame in lineage["frames"] if frame in chosen_frames] + + import napari + v = napari.Viewer() + v.add_image(raw) + v.add_labels(labels) + + v.add_labels(seg_default, visible=False) + v.add_labels(seg_generalist, visible=False) + v.add_labels(seg_specialist, visible=False) + + napari.run() + + +def get_tracking_data(): + data_dir = os.path.join(ROOT, "data", "DynamicNuclearNet-tracking-v1_0") + data_source = np.load(os.path.join(data_dir, "data-source.npz"), allow_pickle=True) + + fname = "test.trks" + track_file = os.path.join(data_dir, fname) + split_name = Path(track_file).stem + + data = load_trks(track_file) + + X = data["X"] + y = data["y"] + lineages = data["lineages"] + + meta = pd.DataFrame( + data_source[split_name], + columns=["filename", "experiment", "pixel_size", "screening_passed", "time_step", "specimen"] + ) + # print(meta) + + # let's convert the data to expected shape + X = X.squeeze(-1) + y = y.squeeze(-1) + + # NOTE: chosen slice for the tracking user study. + _slice = 7 + raw, labels = X[_slice, ...], y[_slice, ...] + curr_lineages = lineages[_slice] + + # NOTE: let's get every third frame of data and see how it looks + chosen_frames = list(range(0, raw.shape[0], 3)) + raw = np.stack([raw[frame] for frame in chosen_frames]) + labels = np.stack([labels[frame] for frame in chosen_frames]) + + # let's create a value map + frmaps = {} + for i, frval in enumerate(chosen_frames): + frmaps[frval] = i + + # let's remove frames which are not a part of our chosen frames + for k, v in curr_lineages.items(): + curr_frames = v["frames"] + v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] + + return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/gt_tracks.csv b/scripts/gt_tracks.csv new file mode 100644 index 0000000..1d0065c --- /dev/null +++ b/scripts/gt_tracks.csv @@ -0,0 +1,108 @@ +,Cell_ID,Start,End,Parent_ID +0,1,0,23,0 +1,2,0,23,0 +2,3,0,23,0 +3,4,0,23,0 +4,5,0,20,0 +5,6,0,5,0 +6,7,0,23,0 +7,8,0,23,0 +8,9,0,23,0 +9,10,0,4,0 +10,11,0,22,0 +11,12,0,23,0 +12,13,0,3,0 +13,14,0,23,0 +14,15,0,3,0 +15,16,0,23,0 +16,17,0,23,0 +17,18,0,23,0 +18,19,0,3,0 +19,20,0,23,0 +20,21,0,23,0 +21,22,0,23,0 +22,23,0,22,0 +23,24,0,3,0 +24,25,0,23,0 +25,26,0,23,0 +26,27,0,4,0 +27,28,0,5,0 +28,29,0,23,0 +29,30,0,23,0 +30,31,0,20,0 +31,32,0,11,0 +32,33,0,23,0 +33,34,0,23,0 +34,35,0,7,0 +35,36,0,23,0 +36,37,0,18,0 +37,38,0,23,0 +38,39,0,23,0 +39,40,0,18,0 +40,41,0,1,0 +41,42,0,23,0 +42,43,0,2,0 +43,44,0,23,0 +44,45,0,23,0 +45,46,0,23,0 +46,47,0,23,0 +47,48,0,2,0 +48,49,0,13,0 +49,50,0,13,0 +50,51,0,23,0 +51,52,0,23,0 +52,53,0,3,0 +53,54,0,23,0 +54,55,0,23,0 +55,56,0,1,0 +56,57,0,3,0 +57,58,0,12,0 +58,59,0,23,0 +59,60,0,0,0 +60,61,1,23,60 +61,62,1,23,0 +62,63,2,23,41 +63,64,2,23,41 +64,65,2,23,56 +65,66,2,23,56 +66,67,3,23,48 +67,68,3,23,48 +68,69,3,23,43 +69,70,3,23,43 +70,109,4,23,0 +71,71,4,23,13 +72,72,4,23,13 +73,73,4,23,57 +74,74,4,23,24 +75,75,4,23,24 +76,76,4,23,53 +77,77,4,23,53 +78,78,4,4,0 +79,79,5,23,10 +80,80,5,23,10 +81,81,5,21,27 +82,82,5,9,27 +83,83,6,23,6 +84,84,6,23,6 +85,85,6,23,28 +86,86,6,23,28 +87,87,7,13,0 +88,88,7,21,0 +89,89,8,18,0 +90,90,8,23,35 +91,91,8,23,35 +92,92,12,14,0 +93,93,12,23,32 +94,94,12,23,32 +95,95,13,15,58 +96,96,13,23,58 +97,97,14,14,50 +98,98,14,23,50 +99,99,14,23,49 +100,100,14,23,49 +101,101,18,22,0 +102,102,19,23,0 +103,103,23,23,23 +104,104,23,23,23 +105,105,23,23,11 +106,106,23,23,11 diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py new file mode 100644 index 0000000..f3a32ea --- /dev/null +++ b/scripts/test_ctc_metric.py @@ -0,0 +1,203 @@ +import os +import numpy as np +import pandas as pd + +import h5py + +from traccuracy import run_metrics +from traccuracy.matchers import CTCMatcher +from traccuracy._tracking_graph import TrackingGraph +from traccuracy.metrics import CTCMetrics, DivisionMetrics +from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data + +from scipy.ndimage import binary_closing + + +# ROOT = "/scratch/usr/nimanwai/micro-sam/for_tracking/for_traccuracy/" # hlrn +ROOT = "/media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local + + +def mark_potential_split(frames, last_frame, idx): + if frames.max() == last_frame: # object is tracked until the last frame + split_frame = None # they can't split in this case + prev_parent_id = None + else: # object either goes out of frame or splits + split_frame = frames.max() # let's assume that it splits, we will know if it does or not + prev_parent_id = idx + return split_frame, prev_parent_id + + +def extract_df_from_segmentation(segmentation): + track_ids = np.unique(segmentation)[1:] + last_frame = segmentation.shape[0] - 1 + + all_tracks = [] + prev_parent_id = None + + for idx in track_ids: + frames = np.unique(np.where(segmentation == idx)[0]) + + if frames.min() == 0: # object starts at first frame + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) + + else: + if split_frame is not None: # takes the parent information from above + # have fam is the end frame of the potential parent, so our frame has to be the next frame + if split_frame + 1 == frames.min(): + pid = prev_parent_id + + # otherwise we just have some track that starts so it's not the child + else: + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) + + else: + pid = 0 # assumes that it was an object that started at a random frame + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) + + track_dict = { + "Cell_ID": idx, + "Start": frames.min(), + "End": frames.max(), + "Parent_ID": pid, + } + + all_tracks.append(pd.DataFrame.from_dict([track_dict])) + + pred_tracks_df = pd.concat(all_tracks) + return pred_tracks_df + + +def evaluate_tracking(raw, labels, seg, segmentation_method, filter_label_ids): + if os.path.isdir(seg): # for trackmate stardist + seg_T = load_ctc_data( + data_dir=seg, + track_path=os.path.join(seg, 'res_track.txt'), + name=f'DynamicNuclearNet-{segmentation_method}' + ) + + else: # for micro-sam + seg_nodes = _get_node_attributes(seg) + seg_df = extract_df_from_segmentation(seg) + seg_G = ctc_to_graph(seg_df, seg_nodes) + _check_ctc(seg_df, seg_nodes, seg) + seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}") + + ids, sizes = np.unique(seg_df.Parent_ID.values, return_counts=True) + print(ids, sizes) + # calcuates node attributes for each detectionc + gt_nodes = _get_node_attributes(labels) + + # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy + # it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi" + gt_df = pd.read_csv(os.path.join(ROOT, "gt_tracks.csv")) + mask = np.ones(len(gt_df), dtype="bool") + # breakpoint() + mask[np.isin(gt_df.Cell_ID, filter_label_ids)] = False + gt_df = gt_df[mask] + # breakpoint() + + # creates graphs from ctc-type info (isbi-type? probably means the same thing) + gt_G = ctc_to_graph(gt_df, gt_nodes) + + # OPTIONAL: This tests if inputs (images, dfs and node attributes) to create tracking graphs are as expected + _check_ctc(gt_df, gt_nodes, labels) + + gt_T = TrackingGraph(gt_G, segmentation=labels, name="DynamicNuclearNet-GT") + + ctc_results = run_metrics( + gt_data=gt_T, + pred_data=seg_T, + matcher=CTCMatcher(), + metrics=[CTCMetrics(), DivisionMetrics(max_frame_buffer=0)], + ) + print(ctc_results) + + +def size_filter(segmentation, min_size=100): + ids, sizes = np.unique(segmentation, return_counts=True) + filter_ids = ids[sizes < min_size] + segmentation[np.isin(segmentation, filter_ids)] = 0 + return segmentation + + +def get_tracking_data(segmentation_method, visualize=False): + _path = os.path.join(ROOT, "tracking_micro_sam.h5") + + with h5py.File(_path, "r") as f: + raw = f["raw"][:] + labels = f["labels"][:] + + if segmentation_method.startswith("vit"): + segmentation = f[f"segmentations/{segmentation_method}"][:] + segmentation = size_filter(segmentation) + else: + segmentation = os.path.join(ROOT, "trackmate_stardist", "01_RES") + + # test case + def check_consecutive(instances): + instance_ids = np.unique(instances)[1:] + + id_list = [] + for idx in instance_ids: + frames = np.unique(np.where(instances == idx)[0]) + consistent_instance = (sorted(frames) == list(range(min(frames), max(frames) + 1))) + if not consistent_instance: + id_list.append(idx) + + return id_list + + def rectify_labels(instances): + id_list = check_consecutive(instances) + print("Closing instances", id_list) + for idx in id_list: + object_mask = (instances == idx) + + structuring_element = np.zeros((3, 1, 1)) + structuring_element[:, 0, 0] = 1 + + closed_mask = binary_closing(object_mask.copy(), iterations=1, structure=structuring_element) + # breakpoint() + closed_mask = np.logical_or(object_mask, closed_mask) + # breakpoint() + + instances[closed_mask] = idx + + # import napari + # v = napari.Viewer() + # v.add_image(closed_mask.astype("uint8") * 255, name="After Closing", blending="additive", colormap="blue") + # v.add_image(object_mask.astype("uint8") * 255, name="Original") + # v.add_labels(instances, visible=False) + # napari.run() + + return instances + + filter_ids = check_consecutive(labels) + labels[np.isin(labels, filter_ids)] = 0 + + if not os.path.isdir(segmentation): + segmentation = rectify_labels(segmentation) + + if visualize: + import napari + + v = napari.Viewer() + v.add_image(raw) + if not os.path.isdir(segmentation): + v.add_labels(segmentation, visible=False) + + napari.run() + + return raw, labels, segmentation, filter_ids + + +def main(): + segmentation_method = "vit_l_lm" + + raw, labels, segmentation, filter_ids = get_tracking_data(segmentation_method, visualize=False) + evaluate_tracking(raw, labels, segmentation, segmentation_method, filter_ids) + + +if __name__ == "__main__": + main()