"""
Disagreement measures and disagreement based query strategies for the Committee model.
"""
from collections import Counter
import numpy as np
from scipy.stats import entropy
from sklearn.exceptions import NotFittedError
from sklearn.base import BaseEstimator
from modAL.utils.data import modALinput
from modAL.utils.selection import multi_argmax, shuffled_argmax
from modAL.models.base import BaseCommittee


def vote_entropy(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
"""
Calculates the vote entropy for the Committee. First it computes the predictions of X for each learner in the
Committee, then calculates the probability distribution of the votes. The entropy of this distribution is the vote
entropy of the Committee, which is returned.
Args:
committee: The :class:`modAL.models.BaseCommittee` instance for which the vote entropy is to be calculated.
X: The data for which the vote entropy is to be calculated.
        **predict_proba_kwargs: Keyword arguments to be passed to the :meth:`vote` method of the Committee.
Returns:
Vote entropy of the Committee for the samples in X.
"""
n_learners = len(committee)
try:
votes = committee.vote(X, **predict_proba_kwargs)
except NotFittedError:
return np.zeros(shape=(X.shape[0],))
p_vote = np.zeros(shape=(X.shape[0], len(committee.classes_)))
for vote_idx, vote in enumerate(votes):
vote_counter = Counter(vote)
for class_idx, class_label in enumerate(committee.classes_):
p_vote[vote_idx, class_idx] = vote_counter[class_label]/n_learners
entr = entropy(p_vote, axis=1)
return entr
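

# --- Illustrative sketch (not part of the original modAL module) -------------
# A minimal, self-contained demonstration of the vote-entropy computation above,
# assuming a hypothetical committee of three learners voting on four samples.
# The Committee object and its `vote` method are bypassed; only the
# counting-and-entropy step is reproduced on a hard-coded `votes` array.
def _demo_vote_entropy() -> np.ndarray:
    # votes[i, j] is the class label predicted by learner j for sample i
    votes = np.array([[0, 0, 0],   # unanimous vote -> entropy 0
                      [0, 0, 1],   # partial disagreement
                      [0, 1, 2],   # full disagreement -> maximal entropy
                      [1, 1, 1]])  # unanimous vote -> entropy 0
    classes = np.array([0, 1, 2])
    n_learners = votes.shape[1]
    p_vote = np.zeros(shape=(votes.shape[0], len(classes)))
    for vote_idx, vote in enumerate(votes):
        vote_counter = Counter(vote)
        for class_idx, class_label in enumerate(classes):
            p_vote[vote_idx, class_idx] = vote_counter[class_label] / n_learners
    return entropy(p_vote, axis=1)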


def consensus_entropy(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
"""
    Calculates the consensus entropy for the Committee. First it computes the class probabilities of X for each learner
    in the Committee, then calculates the consensus probability distribution by averaging the individual class
    probabilities of the learners. The entropy of the consensus probability distribution is the consensus entropy of
    the Committee, which is returned.
Args:
committee: The :class:`modAL.models.BaseCommittee` instance for which the consensus entropy is to be calculated.
X: The data for which the consensus entropy is to be calculated.
**predict_proba_kwargs: Keyword arguments for the :meth:`predict_proba` of the Committee.
Returns:
Consensus entropy of the Committee for the samples in X.
"""
try:
proba = committee.predict_proba(X, **predict_proba_kwargs)
except NotFittedError:
return np.zeros(shape=(X.shape[0],))
entr = np.transpose(entropy(np.transpose(proba)))
return entr
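

# --- Illustrative sketch (not part of the original modAL module) -------------
# A minimal demonstration of consensus entropy, assuming the per-learner class
# probabilities have already been averaged into one consensus distribution per
# sample (which is what Committee.predict_proba returns). The `proba` array is
# hypothetical toy data.
def _demo_consensus_entropy() -> np.ndarray:
    # proba[i, k] is the committee-averaged probability of class k for sample i
    proba = np.array([[0.9, 0.1],    # confident consensus -> low entropy
                      [0.5, 0.5],    # maximally uncertain consensus
                      [0.7, 0.3]])
    return np.transpose(entropy(np.transpose(proba)))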


def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba_kwargs) -> np.ndarray:
"""
    Calculates the max disagreement for the Committee. First it computes the class probabilities of X for each learner
    in the Committee, then calculates the consensus probability distribution by averaging the individual class
    probabilities of the learners. Each learner's class probabilities are then compared to the consensus distribution
    in the sense of Kullback-Leibler divergence. The max disagreement for a given sample is the maximum of the KL
    divergences of the learners from the consensus probability.
Args:
committee: The :class:`modAL.models.BaseCommittee` instance for which the max disagreement is to be calculated.
X: The data for which the max disagreement is to be calculated.
        **predict_proba_kwargs: Keyword arguments to be passed to the :meth:`vote_proba` method of the Committee.
Returns:
Max disagreement of the Committee for the samples in X.
"""
try:
p_vote = committee.vote_proba(X, **predict_proba_kwargs)
except NotFittedError:
return np.zeros(shape=(X.shape[0],))
p_consensus = np.mean(p_vote, axis=1)
learner_KL_div = np.zeros(shape=(X.shape[0], len(committee)))
for learner_idx, _ in enumerate(committee):
learner_KL_div[:, learner_idx] = entropy(np.transpose(p_vote[:, learner_idx, :]), qk=np.transpose(p_consensus))
return np.max(learner_KL_div, axis=1)
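

# --- Illustrative sketch (not part of the original modAL module) -------------
# A minimal demonstration of KL-based maximum disagreement, assuming a toy
# `p_vote` array with the shape returned by Committee.vote_proba:
# (n_samples, n_learners, n_classes). All numbers are hypothetical.
def _demo_KL_max_disagreement() -> np.ndarray:
    p_vote = np.array([
        [[0.8, 0.2], [0.7, 0.3]],    # learners roughly agree -> small KL divergence
        [[0.9, 0.1], [0.2, 0.8]],    # learners disagree -> large KL divergence
    ])
    p_consensus = np.mean(p_vote, axis=1)
    n_samples, n_learners, _ = p_vote.shape
    learner_KL_div = np.zeros(shape=(n_samples, n_learners))
    for learner_idx in range(n_learners):
        learner_KL_div[:, learner_idx] = entropy(
            np.transpose(p_vote[:, learner_idx, :]), qk=np.transpose(p_consensus)
        )
    return np.max(learner_KL_div, axis=1)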


def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
n_instances: int = 1, random_tie_break=False,
**disagreement_measure_kwargs) -> np.ndarray:
"""
Vote entropy sampling strategy.
Args:
committee: The committee for which the labels are to be queried.
X: The pool of samples to query from.
n_instances: Number of samples to be queried.
random_tie_break: If True, shuffles utility scores to randomize the order. This
can be used to break the tie when the highest utility score is not unique.
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
measure function.
Returns:
        The indices of the instances from X chosen to be labelled.
"""
disagreement = vote_entropy(committee, X, **disagreement_measure_kwargs)
if not random_tie_break:
return multi_argmax(disagreement, n_instances=n_instances)
return shuffled_argmax(disagreement, n_instances=n_instances)


def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
n_instances: int = 1, random_tie_break=False,
**disagreement_measure_kwargs) -> np.ndarray:
"""
Consensus entropy sampling strategy.
Args:
committee: The committee for which the labels are to be queried.
X: The pool of samples to query from.
n_instances: Number of samples to be queried.
random_tie_break: If True, shuffles utility scores to randomize the order. This
can be used to break the tie when the highest utility score is not unique.
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
measure function.
Returns:
        The indices of the instances from X chosen to be labelled.
"""
disagreement = consensus_entropy(committee, X, **disagreement_measure_kwargs)
if not random_tie_break:
return multi_argmax(disagreement, n_instances=n_instances)
return shuffled_argmax(disagreement, n_instances=n_instances)


def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
n_instances: int = 1, random_tie_break=False,
**disagreement_measure_kwargs) -> np.ndarray:
"""
Maximum disagreement sampling strategy.
Args:
committee: The committee for which the labels are to be queried.
X: The pool of samples to query from.
n_instances: Number of samples to be queried.
random_tie_break: If True, shuffles utility scores to randomize the order. This
can be used to break the tie when the highest utility score is not unique.
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
measure function.
Returns:
        The indices of the instances from X chosen to be labelled.
"""
disagreement = KL_max_disagreement(committee, X, **disagreement_measure_kwargs)
if not random_tie_break:
return multi_argmax(disagreement, n_instances=n_instances)
return shuffled_argmax(disagreement, n_instances=n_instances)
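

# --- Illustrative sketch (not part of the original modAL module) -------------
# A hedged usage example showing how the three committee-based strategies above
# are typically plugged into an active-learning loop. It assumes modAL's public
# ActiveLearner/Committee API and scikit-learn estimators; the helper name and
# toy data split are hypothetical.
def _demo_committee_query():
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from modAL.models import ActiveLearner, Committee

    X, y = load_iris(return_X_y=True)
    X_train, y_train = X[::15], y[::15]   # small seed set covering all three classes
    X_pool = X                            # query from the full data for simplicity

    committee = Committee(
        learner_list=[
            ActiveLearner(estimator=RandomForestClassifier(), X_training=X_train, y_training=y_train),
            ActiveLearner(estimator=RandomForestClassifier(), X_training=X_train, y_training=y_train),
        ],
        # any of vote_entropy_sampling, consensus_entropy_sampling or
        # max_disagreement_sampling can be used here
        query_strategy=vote_entropy_sampling,
    )
    # indices of the most disputed pool samples and the samples themselves
    query_idx, query_instances = committee.query(X_pool, n_instances=3)
    return query_idx, query_instances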


def max_std_sampling(regressor: BaseEstimator, X: modALinput,
n_instances: int = 1, random_tie_break=False,
**predict_kwargs) -> np.ndarray:
"""
Regressor standard deviation sampling strategy.
Args:
regressor: The regressor for which the labels are to be queried.
X: The pool of samples to query from.
n_instances: Number of samples to be queried.
random_tie_break: If True, shuffles utility scores to randomize the order. This
can be used to break the tie when the highest utility score is not unique.
        **predict_kwargs: Keyword arguments to be passed to :meth:`predict` of the CommitteeRegressor.
Returns:
        The indices of the instances from X chosen to be labelled.
"""
_, std = regressor.predict(X, return_std=True, **predict_kwargs)
std = std.reshape(X.shape[0], )
if not random_tie_break:
return multi_argmax(std, n_instances=n_instances)
return shuffled_argmax(std, n_instances=n_instances)
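

# --- Illustrative sketch (not part of the original modAL module) -------------
# max_std_sampling only needs an estimator whose `predict` accepts
# `return_std=True`, such as scikit-learn's GaussianProcessRegressor or modAL's
# CommitteeRegressor. The helper name and the toy 1-D regression data below are
# hypothetical.
def _demo_max_std_sampling() -> np.ndarray:
    from sklearn.gaussian_process import GaussianProcessRegressor

    rng = np.random.RandomState(0)
    X_train = rng.uniform(0.0, 10.0, size=(15, 1))
    y_train = np.sin(X_train).ravel()
    X_pool = np.linspace(0.0, 10.0, 100).reshape(-1, 1)

    regressor = GaussianProcessRegressor().fit(X_train, y_train)
    # indices of the pool points with the largest predictive standard deviation
    return max_std_sampling(regressor, X_pool, n_instances=5)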