From 0ce3567458e82a083c01e3dc31895af64fa303cc Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Mon, 6 May 2024 14:16:19 +0200 Subject: [PATCH 1/6] Add tracking metrics - initial try-out --- scripts/test_ctc_metric.py | 124 +++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 scripts/test_ctc_metric.py diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py new file mode 100644 index 0000000..fdc4181 --- /dev/null +++ b/scripts/test_ctc_metric.py @@ -0,0 +1,124 @@ +import os +from glob import glob +from pathlib import Path + +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +import napari + +from deepcell_tracking.utils import load_trks +# from deepcell_tracking.isbi_utils import trk_to_isbi + +from traccuracy.loaders._ctc import _get_node_attributes + + +def load_tracking_segmentation(experiment): + ROOT = r"/home/anwai/results/tracking/MicroSAM testing/" + TRACKMATE_ROOT = r"/home/anwai/results/tracking/trackmate_stardist/microSAM revision every 3rd fr" + + if experiment == "vit_l": + seg_path = glob(os.path.join(ROOT, r"round 2 vit_l", "*.tif"))[0] + elif experiment == "vit_l_lm": + seg_path = glob(os.path.join(ROOT, "vit_l_finetuned", "*.tif"))[0] + elif experiment == "vit_l_specialist": + seg_path = glob(os.path.join(ROOT, "vit_l_specialist", "*.tif"))[0] + elif experiment == "trackmate_stardist": + seg_path = glob(os.path.join(TRACKMATE_ROOT, "*.tif"))[0] + else: + raise ValueError(experiment) + + return imageio.imread(seg_path) + + +def check_tracking_results(raw, labels, curr_lineages, chosen_frames): + seg_default = load_tracking_segmentation("vit_l") + seg_generalist = load_tracking_segmentation("vit_l_lm") + seg_specialist = load_tracking_segmentation("vit_l_specialist") + + # let's get the tracks only for the objects present per frame + for idx in np.unique(labels)[1:]: + lineage = curr_lineages[idx] + lineage["frames"] = [frame for frame in lineage["frames"] if frame in chosen_frames] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(labels) + + v.add_labels(seg_default, visible=False) + v.add_labels(seg_generalist, visible=False) + v.add_labels(seg_specialist, visible=False) + + napari.run() + + +def get_tracking_data(): + data_dir = "/home/anwai/data/dynamicnuclearnet/DynamicNuclearNet-tracking-v1_0/" + data_source = np.load(os.path.join(data_dir, "data-source.npz"), allow_pickle=True) + + fname = "test.trks" + track_file = os.path.join(data_dir, fname) + split_name = Path(track_file).stem + + data = load_trks(track_file) + + X = data["X"] + y = data["y"] + lineages = data["lineages"] + + meta = pd.DataFrame( + data_source[split_name], + columns=["filename", "experiment", "pixel_size", "screening_passed", "time_step", "specimen"] + ) + print(meta) + + # let's convert the data to expected shape + X = X.squeeze(-1) + y = y.squeeze(-1) + + # NOTE: chosen slice for the tracking user study. + _slice = 7 + raw, labels = X[_slice, ...], y[_slice, ...] + curr_lineages = lineages[_slice] + + # NOTE: let's get every third frame of data and see how it looks + chosen_frames = list(range(0, raw.shape[0], 3)) + raw = np.stack([raw[frame] for frame in chosen_frames]) + labels = np.stack([labels[frame] for frame in chosen_frames]) + + # let's create a value map + frmaps = {} + for i, frval in enumerate(chosen_frames): + frmaps[frval] = i + + # let's remove frames which are not a part of our chosen frames + for k, v in curr_lineages.items(): + curr_frames = v["frames"] + v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] + + return raw, labels, curr_lineages, chosen_frames + + +def evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method): + seg = load_tracking_segmentation(segmentation_method) + + gt_df = _get_node_attributes(labels) + seg_df = _get_node_attributes(seg) + + output = trk_to_isbi(curr_lineages, path=None) + + breakpoint() + + +def main(): + raw, labels, curr_lineages, chosen_frames = get_tracking_data() + + # check_tracking_results(raw, labels, curr_lineages, chosen_frames) + + segmentation_method = "trackmate_stardist" + evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method) + + +if __name__ == "__main__": + main() From 68e57485daa028ac7289a36966aaf79df998c73d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 May 2024 16:32:48 +0200 Subject: [PATCH 2/6] Add draft for testing traccuracy implementation --- scripts/get_tracking_results.py | 96 ++++++++++++++++++++++++ scripts/test_ctc_metric.py | 128 ++++++++++---------------------- 2 files changed, 135 insertions(+), 89 deletions(-) create mode 100644 scripts/get_tracking_results.py diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py new file mode 100644 index 0000000..4d4b817 --- /dev/null +++ b/scripts/get_tracking_results.py @@ -0,0 +1,96 @@ +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +from deepcell_tracking.utils import load_trks + + +ROOT = "/scratch/projects/nim00007/sam/for_tracking" + + +def load_tracking_segmentation(experiment): + result_dir = os.path.join(ROOT, "results") + if experiment == "vit_l": + seg_path = os.path.join(result_dir, "vit_l.tif") + elif experiment == "vit_l_lm": + seg_path = os.path.join(result_dir, "vit_l_lm.tif") + elif experiment == "vit_l_specialist": + seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") + elif experiment == "trackmate_stardist": + seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") + else: + raise ValueError(experiment) + + return imageio.imread(seg_path) + + +def check_tracking_results(raw, labels, curr_lineages, chosen_frames): + seg_default = load_tracking_segmentation("vit_l") + seg_generalist = load_tracking_segmentation("vit_l_lm") + seg_specialist = load_tracking_segmentation("vit_l_specialist") + + # let's get the tracks only for the objects present per frame + for idx in np.unique(labels)[1:]: + lineage = curr_lineages[idx] + lineage["frames"] = [frame for frame in lineage["frames"] if frame in chosen_frames] + + import napari + v = napari.Viewer() + v.add_image(raw) + v.add_labels(labels) + + v.add_labels(seg_default, visible=False) + v.add_labels(seg_generalist, visible=False) + v.add_labels(seg_specialist, visible=False) + + napari.run() + + +def get_tracking_data(): + data_dir = os.path.join(ROOT, "data", "DynamicNuclearNet-tracking-v1_0") + data_source = np.load(os.path.join(data_dir, "data-source.npz"), allow_pickle=True) + + fname = "test.trks" + track_file = os.path.join(data_dir, fname) + split_name = Path(track_file).stem + + data = load_trks(track_file) + + X = data["X"] + y = data["y"] + lineages = data["lineages"] + + meta = pd.DataFrame( + data_source[split_name], + columns=["filename", "experiment", "pixel_size", "screening_passed", "time_step", "specimen"] + ) + # print(meta) + + # let's convert the data to expected shape + X = X.squeeze(-1) + y = y.squeeze(-1) + + # NOTE: chosen slice for the tracking user study. + _slice = 7 + raw, labels = X[_slice, ...], y[_slice, ...] + curr_lineages = lineages[_slice] + + # NOTE: let's get every third frame of data and see how it looks + chosen_frames = list(range(0, raw.shape[0], 3)) + raw = np.stack([raw[frame] for frame in chosen_frames]) + labels = np.stack([labels[frame] for frame in chosen_frames]) + + # let's create a value map + frmaps = {} + for i, frval in enumerate(chosen_frames): + frmaps[frval] = i + + # let's remove frames which are not a part of our chosen frames + for k, v in curr_lineages.items(): + curr_frames = v["frames"] + v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] + + return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index fdc4181..68c6227 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -1,122 +1,72 @@ -import os -from glob import glob -from pathlib import Path - import numpy as np -import pandas as pd -import imageio.v3 as imageio - -import napari -from deepcell_tracking.utils import load_trks -# from deepcell_tracking.isbi_utils import trk_to_isbi +from deepcell_tracking.isbi_utils import trk_to_isbi from traccuracy.loaders._ctc import _get_node_attributes +from get_tracking_results import get_tracking_data, load_tracking_segmentation -def load_tracking_segmentation(experiment): - ROOT = r"/home/anwai/results/tracking/MicroSAM testing/" - TRACKMATE_ROOT = r"/home/anwai/results/tracking/trackmate_stardist/microSAM revision every 3rd fr" - - if experiment == "vit_l": - seg_path = glob(os.path.join(ROOT, r"round 2 vit_l", "*.tif"))[0] - elif experiment == "vit_l_lm": - seg_path = glob(os.path.join(ROOT, "vit_l_finetuned", "*.tif"))[0] - elif experiment == "vit_l_specialist": - seg_path = glob(os.path.join(ROOT, "vit_l_specialist", "*.tif"))[0] - elif experiment == "trackmate_stardist": - seg_path = glob(os.path.join(TRACKMATE_ROOT, "*.tif"))[0] - else: - raise ValueError(experiment) - - return imageio.imread(seg_path) - - -def check_tracking_results(raw, labels, curr_lineages, chosen_frames): - seg_default = load_tracking_segmentation("vit_l") - seg_generalist = load_tracking_segmentation("vit_l_lm") - seg_specialist = load_tracking_segmentation("vit_l_specialist") - - # let's get the tracks only for the objects present per frame - for idx in np.unique(labels)[1:]: - lineage = curr_lineages[idx] - lineage["frames"] = [frame for frame in lineage["frames"] if frame in chosen_frames] - v = napari.Viewer() - v.add_image(raw) - v.add_labels(labels) +def extract_df_from_segmentation(segmentation): + track_ids = np.unique(segmentation)[1:] + last_frame = segmentation.shape[0] - 1 - v.add_labels(seg_default, visible=False) - v.add_labels(seg_generalist, visible=False) - v.add_labels(seg_specialist, visible=False) + all_tracks = [] + splits = 0 + for idx in track_ids: - napari.run() + frames = np.unique(np.where(segmentation == idx)[0]) + if frames.min() == 0: # object starts at first frame + if frames.max() == last_frame: # object is tracked until the last frame + pid = 0 + have_fam = None # they can't split in this case + else: # object either goes out of frame or splits + pid = 0 + have_fam = frames.max() # let's assume that it splits, we will know if it does or not -def get_tracking_data(): - data_dir = "/home/anwai/data/dynamicnuclearnet/DynamicNuclearNet-tracking-v1_0/" - data_source = np.load(os.path.join(data_dir, "data-source.npz"), allow_pickle=True) + else: + if have_fam is not None: # takes the parent information from above + pid = have_fam + splits += 1 - fname = "test.trks" - track_file = os.path.join(data_dir, fname) - split_name = Path(track_file).stem + if splits > 2: # assumes every mother cell splits into two daughter cells + print("The mother cell has made enough daughter splits, hence this is a new object.") + splits = 0 + # pid = 0 # this is the case where an objects appears at nth frame and has no parent id + else: + pid = 0 # assumes that it was an object that started at a random frame - data = load_trks(track_file) + track_dict = { + "Cell_ID": idx, + "Start": frames.min(), + "End": frames.max(), + "Parent_ID": pid, + } - X = data["X"] - y = data["y"] - lineages = data["lineages"] + print(track_dict) + all_tracks.append(track_dict) - meta = pd.DataFrame( - data_source[split_name], - columns=["filename", "experiment", "pixel_size", "screening_passed", "time_step", "specimen"] - ) - print(meta) - - # let's convert the data to expected shape - X = X.squeeze(-1) - y = y.squeeze(-1) - - # NOTE: chosen slice for the tracking user study. - _slice = 7 - raw, labels = X[_slice, ...], y[_slice, ...] - curr_lineages = lineages[_slice] - - # NOTE: let's get every third frame of data and see how it looks - chosen_frames = list(range(0, raw.shape[0], 3)) - raw = np.stack([raw[frame] for frame in chosen_frames]) - labels = np.stack([labels[frame] for frame in chosen_frames]) - - # let's create a value map - frmaps = {} - for i, frval in enumerate(chosen_frames): - frmaps[frval] = i - - # let's remove frames which are not a part of our chosen frames - for k, v in curr_lineages.items(): - curr_frames = v["frames"] - v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] - - return raw, labels, curr_lineages, chosen_frames + breakpoint() def evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method): seg = load_tracking_segmentation(segmentation_method) + # calcuates node attributes for each detection gt_df = _get_node_attributes(labels) seg_df = _get_node_attributes(seg) + # converts inputs to isbi-track format - the version expected as inputs in traccuracy output = trk_to_isbi(curr_lineages, path=None) - breakpoint() + df = extract_df_from_segmentation(seg) def main(): raw, labels, curr_lineages, chosen_frames = get_tracking_data() - # check_tracking_results(raw, labels, curr_lineages, chosen_frames) - - segmentation_method = "trackmate_stardist" + segmentation_method = "vit_l_specialist" evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method) From 51352f1adda032d80b063fa7287619ecee5856d2 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 May 2024 19:25:53 +0200 Subject: [PATCH 3/6] Add metric evaluation --- scripts/get_tracking_results.py | 53 ++++++++++++---- scripts/test_ctc_metric.py | 108 ++++++++++++++++++++++++-------- 2 files changed, 123 insertions(+), 38 deletions(-) diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py index 4d4b817..17d11d2 100644 --- a/scripts/get_tracking_results.py +++ b/scripts/get_tracking_results.py @@ -13,18 +13,42 @@ def load_tracking_segmentation(experiment): result_dir = os.path.join(ROOT, "results") - if experiment == "vit_l": - seg_path = os.path.join(result_dir, "vit_l.tif") - elif experiment == "vit_l_lm": - seg_path = os.path.join(result_dir, "vit_l_lm.tif") - elif experiment == "vit_l_specialist": - seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") - elif experiment == "trackmate_stardist": - seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") - else: - raise ValueError(experiment) - return imageio.imread(seg_path) + if experiment.startswith("vit"): + if experiment == "vit_l": + seg_path = os.path.join(result_dir, "vit_l.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [8, 44, 57, 102, 50] + + elif experiment == "vit_l_lm": + seg_path = os.path.join(result_dir, "vit_l_lm.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [] + + elif experiment == "vit_l_specialist": + seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [88, 45, 30, 46] + + # elif experiment == "trackmate_stardist": + # seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") + # seg = imageio.imread(seg_path) + + else: + raise ValueError(experiment) + + # HACK: + # we remove some labels as they have a weird lineage, is creating issues for creating the graph + # (e.g. frames where the object exists: 1, 2, 4, 5, 6) + seg[np.isin(seg, ignore_labels)] = 0 + + return seg + + else: # return the result directory for stardist + return os.path.join(result_dir, "trackmate_stardist", "01_RES") def check_tracking_results(raw, labels, curr_lineages, chosen_frames): @@ -93,4 +117,11 @@ def get_tracking_data(): curr_frames = v["frames"] v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] + # HACK: + # we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph + ignore_labels = [62, 87, 92, 99, 58] + labels[np.isin(labels, ignore_labels)] = 0 + for _label in ignore_labels: + curr_lineages.pop(_label) + return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index 68c6227..dc06637 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -1,41 +1,56 @@ +import os import numpy as np +import pandas as pd from deepcell_tracking.isbi_utils import trk_to_isbi -from traccuracy.loaders._ctc import _get_node_attributes +from traccuracy import run_metrics +from traccuracy._tracking_graph import TrackingGraph +from traccuracy.matchers import CTCMatcher, IOUMatcher +from traccuracy.metrics import CTCMetrics, DivisionMetrics +from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data from get_tracking_results import get_tracking_data, load_tracking_segmentation +def mark_potential_split(frames, last_frame, idx): + if frames.max() == last_frame: # object is tracked until the last frame + split_frame = None # they can't split in this case + prev_parent_id = None + else: # object either goes out of frame or splits + split_frame = frames.max() # let's assume that it splits, we will know if it does or not + prev_parent_id = idx + return split_frame, prev_parent_id + + def extract_df_from_segmentation(segmentation): track_ids = np.unique(segmentation)[1:] last_frame = segmentation.shape[0] - 1 all_tracks = [] - splits = 0 - for idx in track_ids: + prev_parent_id = None + for idx in track_ids: frames = np.unique(np.where(segmentation == idx)[0]) if frames.min() == 0: # object starts at first frame - if frames.max() == last_frame: # object is tracked until the last frame - pid = 0 - have_fam = None # they can't split in this case - else: # object either goes out of frame or splits - pid = 0 - have_fam = frames.max() # let's assume that it splits, we will know if it does or not + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) else: - if have_fam is not None: # takes the parent information from above - pid = have_fam - splits += 1 - - if splits > 2: # assumes every mother cell splits into two daughter cells - print("The mother cell has made enough daughter splits, hence this is a new object.") - splits = 0 - # pid = 0 # this is the case where an objects appears at nth frame and has no parent id + if split_frame is not None: # takes the parent information from above + # have fam is the end frame of the potential parent, so our frame has to be the next frame + if split_frame + 1 == frames.min(): + pid = prev_parent_id + + # otherwise we just have some track that starts so it's not the child + else: + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) + else: pid = 0 # assumes that it was an object that started at a random frame + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) track_dict = { "Cell_ID": idx, @@ -44,30 +59,69 @@ def extract_df_from_segmentation(segmentation): "Parent_ID": pid, } - print(track_dict) - all_tracks.append(track_dict) + all_tracks.append(pd.DataFrame.from_dict([track_dict])) - breakpoint() + pred_tracks_df = pd.concat(all_tracks) + return pred_tracks_df -def evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method): +def evaluate_tracking(labels, curr_lineages, segmentation_method): seg = load_tracking_segmentation(segmentation_method) + if os.path.isdir(seg): # for trackmate stardist + seg_T = load_ctc_data( + data_dir=seg, + track_path=os.path.join(seg, 'res_track.txt'), + name=f'DynamicNuclearNet-{segmentation_method}' + ) + + else: # for micro-sam + seg_nodes = _get_node_attributes(seg) + seg_df = extract_df_from_segmentation(seg) + seg_G = ctc_to_graph(seg_df, seg_nodes) + _check_ctc(seg_df, seg_nodes, seg) + seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}") + + breakpoint() + # calcuates node attributes for each detection - gt_df = _get_node_attributes(labels) - seg_df = _get_node_attributes(seg) + gt_nodes = _get_node_attributes(labels) + + # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy + gt_df = trk_to_isbi(curr_lineages, path=None) + + # creates graphs from ctc-type info (isbi-type? probably means the same thing) + gt_G = ctc_to_graph(gt_df, gt_nodes) + + # OPTIONAL: This tests if inputs (images, dfs and node attributes) to create tracking graphs are as expected + _check_ctc(gt_df, gt_nodes, labels) + + gt_T = TrackingGraph(gt_G, segmentation=labels, name="DynamicNuclearNet-GT") + + ctc_results = run_metrics( + gt_data=gt_T, + pred_data=seg_T, + matcher=CTCMatcher(), + metrics=[CTCMetrics(), DivisionMetrics(max_frame_buffer=0)], + ) + print(ctc_results) - # converts inputs to isbi-track format - the version expected as inputs in traccuracy - output = trk_to_isbi(curr_lineages, path=None) + breakpoint() - df = extract_df_from_segmentation(seg) + iou_results = run_metrics( + gt_data=gt_T, + pred_data=seg_T, + matcher=IOUMatcher(iou_threshold=0.1), + metrics=[DivisionMetrics(max_frame_buffer=0)], + ) + print(iou_results) def main(): raw, labels, curr_lineages, chosen_frames = get_tracking_data() segmentation_method = "vit_l_specialist" - evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method) + evaluate_tracking(labels, curr_lineages, segmentation_method) if __name__ == "__main__": From f5bdc60ed61675aa3d69751e95f445602605fa5c Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 7 May 2024 15:42:40 +0200 Subject: [PATCH 4/6] Update scripts to access stored gt tracks --- scripts/get_tracking_results.py | 26 +------- scripts/gt_tracks.csv | 108 ++++++++++++++++++++++++++++++++ scripts/test_ctc_metric.py | 46 +++++++------- 3 files changed, 133 insertions(+), 47 deletions(-) create mode 100644 scripts/gt_tracks.csv diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py index 17d11d2..7d64fa0 100644 --- a/scripts/get_tracking_results.py +++ b/scripts/get_tracking_results.py @@ -17,34 +17,17 @@ def load_tracking_segmentation(experiment): if experiment.startswith("vit"): if experiment == "vit_l": seg_path = os.path.join(result_dir, "vit_l.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [8, 44, 57, 102, 50] - elif experiment == "vit_l_lm": seg_path = os.path.join(result_dir, "vit_l_lm.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [] - elif experiment == "vit_l_specialist": seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [88, 45, 30, 46] - # elif experiment == "trackmate_stardist": # seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") - # seg = imageio.imread(seg_path) else: raise ValueError(experiment) - # HACK: - # we remove some labels as they have a weird lineage, is creating issues for creating the graph - # (e.g. frames where the object exists: 1, 2, 4, 5, 6) - seg[np.isin(seg, ignore_labels)] = 0 - + seg = imageio.imread(seg_path) return seg else: # return the result directory for stardist @@ -117,11 +100,4 @@ def get_tracking_data(): curr_frames = v["frames"] v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] - # HACK: - # we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph - ignore_labels = [62, 87, 92, 99, 58] - labels[np.isin(labels, ignore_labels)] = 0 - for _label in ignore_labels: - curr_lineages.pop(_label) - return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/gt_tracks.csv b/scripts/gt_tracks.csv new file mode 100644 index 0000000..1d0065c --- /dev/null +++ b/scripts/gt_tracks.csv @@ -0,0 +1,108 @@ +,Cell_ID,Start,End,Parent_ID +0,1,0,23,0 +1,2,0,23,0 +2,3,0,23,0 +3,4,0,23,0 +4,5,0,20,0 +5,6,0,5,0 +6,7,0,23,0 +7,8,0,23,0 +8,9,0,23,0 +9,10,0,4,0 +10,11,0,22,0 +11,12,0,23,0 +12,13,0,3,0 +13,14,0,23,0 +14,15,0,3,0 +15,16,0,23,0 +16,17,0,23,0 +17,18,0,23,0 +18,19,0,3,0 +19,20,0,23,0 +20,21,0,23,0 +21,22,0,23,0 +22,23,0,22,0 +23,24,0,3,0 +24,25,0,23,0 +25,26,0,23,0 +26,27,0,4,0 +27,28,0,5,0 +28,29,0,23,0 +29,30,0,23,0 +30,31,0,20,0 +31,32,0,11,0 +32,33,0,23,0 +33,34,0,23,0 +34,35,0,7,0 +35,36,0,23,0 +36,37,0,18,0 +37,38,0,23,0 +38,39,0,23,0 +39,40,0,18,0 +40,41,0,1,0 +41,42,0,23,0 +42,43,0,2,0 +43,44,0,23,0 +44,45,0,23,0 +45,46,0,23,0 +46,47,0,23,0 +47,48,0,2,0 +48,49,0,13,0 +49,50,0,13,0 +50,51,0,23,0 +51,52,0,23,0 +52,53,0,3,0 +53,54,0,23,0 +54,55,0,23,0 +55,56,0,1,0 +56,57,0,3,0 +57,58,0,12,0 +58,59,0,23,0 +59,60,0,0,0 +60,61,1,23,60 +61,62,1,23,0 +62,63,2,23,41 +63,64,2,23,41 +64,65,2,23,56 +65,66,2,23,56 +66,67,3,23,48 +67,68,3,23,48 +68,69,3,23,43 +69,70,3,23,43 +70,109,4,23,0 +71,71,4,23,13 +72,72,4,23,13 +73,73,4,23,57 +74,74,4,23,24 +75,75,4,23,24 +76,76,4,23,53 +77,77,4,23,53 +78,78,4,4,0 +79,79,5,23,10 +80,80,5,23,10 +81,81,5,21,27 +82,82,5,9,27 +83,83,6,23,6 +84,84,6,23,6 +85,85,6,23,28 +86,86,6,23,28 +87,87,7,13,0 +88,88,7,21,0 +89,89,8,18,0 +90,90,8,23,35 +91,91,8,23,35 +92,92,12,14,0 +93,93,12,23,32 +94,94,12,23,32 +95,95,13,15,58 +96,96,13,23,58 +97,97,14,14,50 +98,98,14,23,50 +99,99,14,23,49 +100,100,14,23,49 +101,101,18,22,0 +102,102,19,23,0 +103,103,23,23,23 +104,104,23,23,23 +105,105,23,23,11 +106,106,23,23,11 diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index dc06637..fc6e9cf 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -2,16 +2,12 @@ import numpy as np import pandas as pd -from deepcell_tracking.isbi_utils import trk_to_isbi - from traccuracy import run_metrics +from traccuracy.matchers import CTCMatcher from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers import CTCMatcher, IOUMatcher from traccuracy.metrics import CTCMetrics, DivisionMetrics from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data -from get_tracking_results import get_tracking_data, load_tracking_segmentation - def mark_potential_split(frames, last_frame, idx): if frames.max() == last_frame: # object is tracked until the last frame @@ -65,12 +61,10 @@ def extract_df_from_segmentation(segmentation): return pred_tracks_df -def evaluate_tracking(labels, curr_lineages, segmentation_method): - seg = load_tracking_segmentation(segmentation_method) - +def evaluate_tracking(raw, labels, seg, segmentation_method): if os.path.isdir(seg): # for trackmate stardist seg_T = load_ctc_data( - data_dir=seg, + data_dir=seg, track_path=os.path.join(seg, 'res_track.txt'), name=f'DynamicNuclearNet-{segmentation_method}' ) @@ -84,11 +78,12 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method): breakpoint() - # calcuates node attributes for each detection + # calcuates node attributes for each detectionc gt_nodes = _get_node_attributes(labels) # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy - gt_df = trk_to_isbi(curr_lineages, path=None) + # it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi" + gt_df = pd.read_csv("./gt_tracks.csv") # creates graphs from ctc-type info (isbi-type? probably means the same thing) gt_G = ctc_to_graph(gt_df, gt_nodes) @@ -106,22 +101,29 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method): ) print(ctc_results) - breakpoint() - iou_results = run_metrics( - gt_data=gt_T, - pred_data=seg_T, - matcher=IOUMatcher(iou_threshold=0.1), - metrics=[DivisionMetrics(max_frame_buffer=0)], - ) - print(iou_results) +def get_tracking_data(segmentation_method): + import h5py + + with h5py.File("./tracking_micro_sam.h5", "r") as f: + raw = f["raw"][:] + labels = f["labels"][:] + + if segmentation_method.startswith("vit"): + segmentation = f[f"segmentations/{segmentation_method}"][:] + else: + ROOT = "/scratch/projects/nim00007/sam/for_tracking" + result_dir = os.path.join(ROOT, "results") + segmentation = os.path.join(result_dir, "trackmate_stardist", "01_RES") + + return raw, labels, segmentation def main(): - raw, labels, curr_lineages, chosen_frames = get_tracking_data() + segmentation_method = "trackmate_stardist" - segmentation_method = "vit_l_specialist" - evaluate_tracking(labels, curr_lineages, segmentation_method) + raw, labels, segmentation = get_tracking_data(segmentation_method) + evaluate_tracking(raw, labels, segmentation, segmentation_method) if __name__ == "__main__": From 60e2a097270f3cbf824e549b4ad90ff5757f8419 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 7 May 2024 15:50:47 +0200 Subject: [PATCH 5/6] Refactor paths --- scripts/test_ctc_metric.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index fc6e9cf..4846d99 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -2,6 +2,8 @@ import numpy as np import pandas as pd +import h5py + from traccuracy import run_metrics from traccuracy.matchers import CTCMatcher from traccuracy._tracking_graph import TrackingGraph @@ -9,6 +11,10 @@ from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data +# ROOT = "/scratch/usr/nimanwai/micro-sam/for_tracking/for_traccuracy/" # hlrn +ROOT = "media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local + + def mark_potential_split(frames, last_frame, idx): if frames.max() == last_frame: # object is tracked until the last frame split_frame = None # they can't split in this case @@ -83,7 +89,7 @@ def evaluate_tracking(raw, labels, seg, segmentation_method): # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy # it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi" - gt_df = pd.read_csv("./gt_tracks.csv") + gt_df = pd.read_csv(os.path.join(ROOT, "gt_tracks.csv")) # creates graphs from ctc-type info (isbi-type? probably means the same thing) gt_G = ctc_to_graph(gt_df, gt_nodes) @@ -103,18 +109,16 @@ def evaluate_tracking(raw, labels, seg, segmentation_method): def get_tracking_data(segmentation_method): - import h5py + _path = os.path.join(ROOT, "tracking_micro_sam.h5") - with h5py.File("./tracking_micro_sam.h5", "r") as f: + with h5py.File(_path, "r") as f: raw = f["raw"][:] labels = f["labels"][:] if segmentation_method.startswith("vit"): segmentation = f[f"segmentations/{segmentation_method}"][:] else: - ROOT = "/scratch/projects/nim00007/sam/for_tracking" - result_dir = os.path.join(ROOT, "results") - segmentation = os.path.join(result_dir, "trackmate_stardist", "01_RES") + segmentation = os.path.join(ROOT, "trackmate_stardist", "01_RES") return raw, labels, segmentation From 892ad6336cd419cb2e5a81906126b397e18f124e Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 7 May 2024 18:55:09 +0200 Subject: [PATCH 6/6] Add minor refactors - to the closing logic --- scripts/test_ctc_metric.py | 87 ++++++++++++++++++++++++++++++++++---- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index 4846d99..f3a32ea 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -10,9 +10,11 @@ from traccuracy.metrics import CTCMetrics, DivisionMetrics from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data +from scipy.ndimage import binary_closing + # ROOT = "/scratch/usr/nimanwai/micro-sam/for_tracking/for_traccuracy/" # hlrn -ROOT = "media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local +ROOT = "/media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local def mark_potential_split(frames, last_frame, idx): @@ -67,7 +69,7 @@ def extract_df_from_segmentation(segmentation): return pred_tracks_df -def evaluate_tracking(raw, labels, seg, segmentation_method): +def evaluate_tracking(raw, labels, seg, segmentation_method, filter_label_ids): if os.path.isdir(seg): # for trackmate stardist seg_T = load_ctc_data( data_dir=seg, @@ -82,14 +84,19 @@ def evaluate_tracking(raw, labels, seg, segmentation_method): _check_ctc(seg_df, seg_nodes, seg) seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}") - breakpoint() - + ids, sizes = np.unique(seg_df.Parent_ID.values, return_counts=True) + print(ids, sizes) # calcuates node attributes for each detectionc gt_nodes = _get_node_attributes(labels) # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy # it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi" gt_df = pd.read_csv(os.path.join(ROOT, "gt_tracks.csv")) + mask = np.ones(len(gt_df), dtype="bool") + # breakpoint() + mask[np.isin(gt_df.Cell_ID, filter_label_ids)] = False + gt_df = gt_df[mask] + # breakpoint() # creates graphs from ctc-type info (isbi-type? probably means the same thing) gt_G = ctc_to_graph(gt_df, gt_nodes) @@ -108,7 +115,14 @@ def evaluate_tracking(raw, labels, seg, segmentation_method): print(ctc_results) -def get_tracking_data(segmentation_method): +def size_filter(segmentation, min_size=100): + ids, sizes = np.unique(segmentation, return_counts=True) + filter_ids = ids[sizes < min_size] + segmentation[np.isin(segmentation, filter_ids)] = 0 + return segmentation + + +def get_tracking_data(segmentation_method, visualize=False): _path = os.path.join(ROOT, "tracking_micro_sam.h5") with h5py.File(_path, "r") as f: @@ -117,17 +131,72 @@ def get_tracking_data(segmentation_method): if segmentation_method.startswith("vit"): segmentation = f[f"segmentations/{segmentation_method}"][:] + segmentation = size_filter(segmentation) else: segmentation = os.path.join(ROOT, "trackmate_stardist", "01_RES") - return raw, labels, segmentation + # test case + def check_consecutive(instances): + instance_ids = np.unique(instances)[1:] + + id_list = [] + for idx in instance_ids: + frames = np.unique(np.where(instances == idx)[0]) + consistent_instance = (sorted(frames) == list(range(min(frames), max(frames) + 1))) + if not consistent_instance: + id_list.append(idx) + + return id_list + + def rectify_labels(instances): + id_list = check_consecutive(instances) + print("Closing instances", id_list) + for idx in id_list: + object_mask = (instances == idx) + + structuring_element = np.zeros((3, 1, 1)) + structuring_element[:, 0, 0] = 1 + + closed_mask = binary_closing(object_mask.copy(), iterations=1, structure=structuring_element) + # breakpoint() + closed_mask = np.logical_or(object_mask, closed_mask) + # breakpoint() + + instances[closed_mask] = idx + + # import napari + # v = napari.Viewer() + # v.add_image(closed_mask.astype("uint8") * 255, name="After Closing", blending="additive", colormap="blue") + # v.add_image(object_mask.astype("uint8") * 255, name="Original") + # v.add_labels(instances, visible=False) + # napari.run() + + return instances + + filter_ids = check_consecutive(labels) + labels[np.isin(labels, filter_ids)] = 0 + + if not os.path.isdir(segmentation): + segmentation = rectify_labels(segmentation) + + if visualize: + import napari + + v = napari.Viewer() + v.add_image(raw) + if not os.path.isdir(segmentation): + v.add_labels(segmentation, visible=False) + + napari.run() + + return raw, labels, segmentation, filter_ids def main(): - segmentation_method = "trackmate_stardist" + segmentation_method = "vit_l_lm" - raw, labels, segmentation = get_tracking_data(segmentation_method) - evaluate_tracking(raw, labels, segmentation, segmentation_method) + raw, labels, segmentation, filter_ids = get_tracking_data(segmentation_method, visualize=False) + evaluate_tracking(raw, labels, segmentation, segmentation_method, filter_ids) if __name__ == "__main__":