diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..79fdc9e --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,4 @@ +- VIT-L-LM: {'TRA': 0.9842945483712385, 'DET': 0.9847838957963292} +- VIT-L: {'TRA': 0.9821987632280266, 'DET': 0.9832445233866193} +- VIT-L-Specialist': {'TRA': 0.9734016404046677, 'DET': 0.975902901124926}} +- TRACKMATE: {'TRA': 0.9497270304535693, 'DET': 0.9510953226761397} diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index f3a32ea..6d6cc86 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -14,7 +14,8 @@ # ROOT = "/scratch/usr/nimanwai/micro-sam/for_tracking/for_traccuracy/" # hlrn -ROOT = "/media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local +# ROOT = "/media/anwai/ANWAI/results/micro-sam/for_traccuracy/" # local +ROOT = "./for_traccuracy" def mark_potential_split(frames, last_frame, idx): @@ -34,6 +35,8 @@ def extract_df_from_segmentation(segmentation): all_tracks = [] prev_parent_id = None + parent_to_children = {} + for idx in track_ids: frames = np.unique(np.where(segmentation == idx)[0]) @@ -46,6 +49,7 @@ def extract_df_from_segmentation(segmentation): # have fam is the end frame of the potential parent, so our frame has to be the next frame if split_frame + 1 == frames.min(): pid = prev_parent_id + parent_to_children[pid] = parent_to_children.get(pid, []) + [idx] # otherwise we just have some track that starts so it's not the child else: @@ -66,6 +70,28 @@ def extract_df_from_segmentation(segmentation): all_tracks.append(pd.DataFrame.from_dict([track_dict])) pred_tracks_df = pd.concat(all_tracks) + + # Remove false positive splits + false_parents = [pid for pid, children in parent_to_children.items() if len(children) != 2] + pred_tracks_df.loc[np.isin(pred_tracks_df.Parent_ID.values, false_parents), "Parent_ID"] = 0 + + # breakpoint() + + parent_ids, n_children = np.unique(pred_tracks_df.Parent_ID.values, return_counts=True) + assert (n_children[1:] == 2).all() + + # # For visualization. + # import napari + # from nifty.tools import takeDict + # replace = {cid: (cid if pid == 0 else pid) for cid, pid in zip(pred_tracks_df.Cell_ID, pred_tracks_df.Parent_ID)} + # replace[0] = 0 + # lineage = takeDict(replace, segmentation) + + # v = napari.Viewer() + # v.add_labels(segmentation) + # v.add_labels(lineage) + # napari.run() + return pred_tracks_df @@ -84,8 +110,6 @@ def evaluate_tracking(raw, labels, seg, segmentation_method, filter_label_ids): _check_ctc(seg_df, seg_nodes, seg) seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}") - ids, sizes = np.unique(seg_df.Parent_ID.values, return_counts=True) - print(ids, sizes) # calcuates node attributes for each detectionc gt_nodes = _get_node_attributes(labels) @@ -118,6 +142,8 @@ def evaluate_tracking(raw, labels, seg, segmentation_method, filter_label_ids): def size_filter(segmentation, min_size=100): ids, sizes = np.unique(segmentation, return_counts=True) filter_ids = ids[sizes < min_size] + # HACK FOR VIT SPECIALIST + # filter_ids = np.concatenate([filter_ids, [30, 45]]) segmentation[np.isin(segmentation, filter_ids)] = 0 return segmentation @@ -156,7 +182,7 @@ def rectify_labels(instances): structuring_element = np.zeros((3, 1, 1)) structuring_element[:, 0, 0] = 1 - + closed_mask = binary_closing(object_mask.copy(), iterations=1, structure=structuring_element) # breakpoint() closed_mask = np.logical_or(object_mask, closed_mask) @@ -193,7 +219,7 @@ def rectify_labels(instances): def main(): - segmentation_method = "vit_l_lm" + segmentation_method = "trackmate_stardist" raw, labels, segmentation, filter_ids = get_tracking_data(segmentation_method, visualize=False) evaluate_tracking(raw, labels, segmentation, segmentation_method, filter_ids)