-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotter.py
175 lines (137 loc) · 9.59 KB
/
plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from enum import Enum
import numpy as np
from data.plotter_evaluation import find_acc_at_early_stop_indices, find_early_stop_iterations, get_means_and_y_errors
class PlotType(str, Enum):
    """ Define available types of plots.

    Each member's string value is used verbatim in axis labels and plot titles
    (see gen_labels_on_ax and gen_title_on_ax below). """
    TRAIN_LOSS = "Training-Loss"
    VAL_LOSS = "Validation-Loss"
    VAL_ACC = "Validation-Accuracy"
    TEST_ACC = "Test-Accuracy"
    EARLY_STOP_ITER = "Early-Stopping Iteration"
# generators
def gen_iteration_space(arr, plot_step):
    """ Return a 1D linear space with one entry per element of 'arr', running from
    'plot_step' to len(arr) * 'plot_step' in steps of 'plot_step'. """
    num_points = len(arr)
    return np.linspace(plot_step, num_points * plot_step, num=num_points)
def gen_labels_on_ax(ax, plot_type: PlotType, iteration=True):
    """ Write x- and y-axis labels onto 'ax'.
    The y-label is the plot-type's value; the x-label reads 'Iteration' or
    'Sparsity' depending on the 'iteration' flag. """
    x_text = 'Iteration' if iteration else 'Sparsity'
    ax.set_ylabel(f"{plot_type.value}")
    ax.set_xlabel(f"{x_text}")
def gen_title_on_ax(ax, plot_type: PlotType, early_stop=False):
    """ Write the plot title onto 'ax'.
    If 'early_stop' is set, an early-stop suffix is appended to the title. """
    suffix = ' at early-stop' if early_stop else ''
    ax.set_title(f"Average {plot_type.value}{suffix}")
def setup_early_stop_ax(ax, force_zero, log_step=7):
    """ Prepare 'ax' for early-stop plots: log-scale (base 2) on the x-axis with
    'log_step' tick positions, inverted x-direction, and grids. """
    tick_values = [2 ** (-exp) for exp in range(log_step)]
    ax.set_xscale('log', base=2)
    ax.set_xticks(tick_values)
    ax.set_xticklabels(tick_values)
    ax.invert_xaxis()  # note: this flips the plotted data as well
    # for correct scaling the grids need to be set after plotting
    setup_grids_on_ax(ax, force_zero)
def setup_grids_on_ax(ax, force_zero=False):
    """ Enable grid lines on 'ax'.
    When 'force_zero' is set, clamp the y-axis to start at zero (e.g. for loss plots). """
    ax.grid()
    if not force_zero:
        return
    ax.set_ylim(bottom=0)
def setup_labeling_on_ax(ax, plot_type: PlotType, iteration=True, early_stop=False):
    """ Apply the full labeling to 'ax': axis labels, title and legend. """
    gen_labels_on_ax(ax, plot_type, iteration)
    gen_title_on_ax(ax, plot_type, early_stop)
    ax.legend()
# subplots
def plot_average_at_early_stop_on_ax(ax, hists, sparsity_hist, net_name, random=False, color=None):
    """ Draw means with error bars for early-stopping iterations or accuracies on 'ax'.
    'hists' is expected to have shape (net_count, prune_count+1, 1) for accuracies or
    (net_count, prune_count+1) for iterations; 'sparsity_hist' has shape (prune_count+1).
    A dotted line is used for the random-reinit case, a solid line with error bars otherwise.
    Without an explicit 'color' the next color of the color cycle is used. """
    # squeeze accuracy means/errors from shape (prune_count+1, 1) down to (prune_count+1)
    avg, err_lo, err_hi = (np.squeeze(a) for a in get_means_and_y_errors(hists))
    xs = sparsity_hist[1:] if random else sparsity_hist
    line_style = ':' if random else '-'
    label = f"{net_name} reinit" if random else net_name
    # return the `ErrorbarContainer` so callers can read its color
    return ax.errorbar(x=xs, y=avg, yerr=[err_lo, err_hi], color=color, elinewidth=1, marker='x',
                       ls=line_style, label=label)
def plot_averages_on_ax(ax, hists, sparsity_hist, plot_step, random=False):
    """ Draw mean curves with error bars for 'hists' on 'ax'.
    'hists' is unpacked as shape (net_count, prune_count+1, data_length);
    'sparsity_hist' has shape (prune_count+1).
    If 'random' is False, plot a dashed baseline (unpruned) plus a solid line per
    pruning step; if 'random' is True, plot dotted lines for every pruning step. """
    _, num_levels, _ = hists.shape
    num_levels -= 1  # index 0 holds the baseline, pruned rounds start at index 1
    mean, err_neg, err_pos = get_means_and_y_errors(hists)
    xs = gen_iteration_space(mean[0], plot_step)
    if not random:
        plot_baseline_mean_on_ax(ax, xs, mean[0], err_neg[0], err_pos[0])
        plot_pruned_means_on_ax(ax, xs, mean[1:], err_neg[1:], err_pos[1:], sparsity_hist[1:], num_levels)
    else:
        plot_pruned_means_on_ax(ax, xs, mean, err_neg, err_pos, sparsity_hist, num_levels, ':')
def plot_baseline_mean_on_ax(ax, xs, ys, y_err_neg, y_err_pos):
    """ Draw the unpruned baseline on 'ax' as a dashed errorbar-line in color C0. """
    baseline_style = dict(elinewidth=1.2, ls='--', color="C0", label="Sparsity 1.0000",
                          errorevery=5, capsize=2)
    ax.errorbar(x=xs, y=ys, yerr=[y_err_neg, y_err_pos], **baseline_style)
def plot_pruned_means_on_ax(ax, xs, ys, y_err_neg, y_err_pos, sparsity_hist, prune_count, ls='-'):
    """ Draw one errorbar-line per pruning level on 'ax'.
    Each line's label carries the sparsity at its pruning level; colors start at C1.
    'ls' selects the line style (e.g. '-'=solid, ':'=dotted). """
    for level in range(prune_count):
        errors = [y_err_neg[level], y_err_pos[level]]
        ax.errorbar(x=xs, y=ys[level], yerr=errors, color=f"C{level + 1}", elinewidth=1.2, ls=ls,
                    label=f"Sparsity {sparsity_hist[level]:.4f}", errorevery=5, capsize=2)
# plots
def plot_acc_at_early_stop_on_ax(ax, loss_hists, acc_hists, sparsity_hist, net_name, plot_type: PlotType,
                                 rnd_loss_hists=None, rnd_acc_hists=None, force_zero=False, setup_ax=True, log_step=7):
    """ Draw means and error bars for accuracies at the point where early stopping would end training.
    The stop indices are found via 'loss_hists' and applied to 'acc_hists'; random histories are
    handled analogously when both are given.
    Expects 'acc_hists' and 'loss_hists' of shape (net_count, prune_count+1, data_length),
    'rnd_acc_hists' and 'rnd_loss_hists' of shape (net_count, prune_count, data_length) and
    'sparsity_hist' of shape (prune_count+1) with prune_count > 1.
    If 'setup_ax' is True, add grids and labels, invert the x-axis and apply a log-scale to it.
    'net_name' is used for legend labels; regular accuracies are solid, random ones dotted,
    both in the same color. """
    assert (sparsity_hist.shape[0] > 1) and (sparsity_hist.shape[0] == loss_hists.shape[1]), \
        f"'prune_count' (dimension of 'sparsity_hist') needs to be greater than one, but is {sparsity_hist.shape}."
    accs = find_acc_at_early_stop_indices(loss_hists, acc_hists)
    base_plot = plot_average_at_early_stop_on_ax(ax, accs, sparsity_hist, net_name, random=False)
    has_random = (rnd_loss_hists is not None) and (rnd_acc_hists is not None)
    if has_random:
        rnd_accs = find_acc_at_early_stop_indices(rnd_loss_hists, rnd_acc_hists)
        # reuse the base plot's color so both lines match
        plot_average_at_early_stop_on_ax(ax, rnd_accs, sparsity_hist, net_name, random=True,
                                         color=base_plot.lines[0].get_color())
    if setup_ax:
        setup_early_stop_ax(ax, force_zero, log_step)
        setup_labeling_on_ax(ax, plot_type, iteration=False, early_stop=True)
def plot_average_hists_on_ax(ax, hists, sparsity_hist, plot_step, plot_type: PlotType, rnd_hists=None,
                             force_zero=False):
    """ Draw average curves with error bars for 'hists' and, when given, 'rnd_hists' on 'ax'.
    Expects 'hists' of shape (net_count, prune_count+1, data_length), 'rnd_hists' of shape
    (net_count, prune_count, data_length) and 'sparsity_hist' of shape (prune_count+1).
    The x-axis shows iterations reconstructed from 'plot_step'.
    The baseline (lowest sparsity) is dashed, further levels from 'hists' are solid, and all
    levels from 'rnd_hists' are dotted. """
    plot_averages_on_ax(ax, hists, sparsity_hist, plot_step)
    if rnd_hists is not None:
        plot_averages_on_ax(ax, rnd_hists, sparsity_hist[1:], plot_step, random=True)
    # for correct scaling the grids need to be set after plotting
    setup_grids_on_ax(ax, force_zero)
    setup_labeling_on_ax(ax, plot_type, iteration=True, early_stop=False)
def plot_early_stop_iterations_on_ax(ax, loss_hists, sparsity_hist, plot_step, net_name, rnd_loss_hists=None,
                                     force_zero=False, setup_ax=True, log_step=7):
    """ Plot means and error bars for early-stopping iterations based on 'loss_hists' and 'rnd_loss_hists', if given.
    Suppose 'loss_hists' has shape (net_count, prune_count+1, data_length), 'rnd_loss_hists' has shape
    (net_count, prune_count, data_length) and 'sparsity_hist' has shape (prune_count+1) with prune_count > 1.
    If 'setup_ax' is True, add grids and labels, invert x-axis and apply log-scale to x-axis.
    Use 'net_name' to generate labels for the legend.
    Plot iterations as solid line and random iterations as dotted line in the same color. """
    assert (sparsity_hist.shape[0] > 1) and (sparsity_hist.shape[0] == loss_hists.shape[1]), \
        f"'prune_count' (dimension of 'sparsity_hist') needs to be greater than one, but is {sparsity_hist.shape}."
    early_stop_iterations = find_early_stop_iterations(loss_hists, plot_step)
    original_plot = plot_average_at_early_stop_on_ax(ax, early_stop_iterations, sparsity_hist, net_name, random=False)
    if rnd_loss_hists is not None:
        random_early_stop_iterations = find_early_stop_iterations(rnd_loss_hists, plot_step)
        # reuse the original line's color so the random (dotted) line matches it
        plot_average_at_early_stop_on_ax(ax, random_early_stop_iterations, sparsity_hist, net_name, random=True,
                                         color=original_plot.lines[0].get_color())
    if setup_ax:
        setup_early_stop_ax(ax, force_zero, log_step)
        setup_labeling_on_ax(ax, PlotType.EARLY_STOP_ITER, iteration=False, early_stop=True)