diff --git a/torch_em/util/util.py b/torch_em/util/util.py index 4ecbf31c..8b1f8e0b 100644 --- a/torch_em/util/util.py +++ b/torch_em/util/util.py @@ -5,6 +5,8 @@ import numpy as np import torch import torch_em +import matplotlib.pyplot as plt +from matplotlib import colors # this is a fairly brittle way to check if a module is compiled. # would be good to find a better solution, ideall something like @@ -246,3 +248,13 @@ def model_is_equal(model1, model2): if p1.data.ne(p2.data).sum() > 0: return False return True + + + +def get_random_colors(labels): + """Function to generate a random color map for a label image + """ + n_labels = len(np.unique(labels)) - 1 + cmap = [[0, 0, 0]] + np.random.rand(n_labels, 3).tolist() + cmap = colors.ListedColormap(cmap) + return cmap