Skip to content

Commit

Permalink
add sample_weight support to FScore metrics (#1816)
Browse files Browse the repository at this point in the history
* add sample_weight support to FScore metrics, generate test data with https://colab.research.google.com/drive/1ymB5iOj9YCeBQ-g-8eMGpQhiHr7QUBM4
  • Loading branch information
jharmsen authored May 14, 2020
1 parent 379fe62 commit 89aed9c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
25 changes: 14 additions & 11 deletions tensorflow_addons/metrics/f_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ def _zero_wt_init(name):
self.false_negatives = _zero_wt_init("false_negatives")
self.weights_intermediate = _zero_wt_init("weights_intermediate")

# TODO: Add sample_weight support, currently it is
# ignored during calculations.
def update_state(self, y_true, y_pred, sample_weight=None):
if self.threshold is None:
threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True)
Expand All @@ -131,17 +129,22 @@ def update_state(self, y_true, y_pred, sample_weight=None):
else:
y_pred = y_pred > self.threshold

y_true = tf.cast(y_true, tf.int32)
y_pred = tf.cast(y_pred, tf.int32)
y_true = tf.cast(y_true, self.dtype)
y_pred = tf.cast(y_pred, self.dtype)

def _count_non_zero(val):
non_zeros = tf.math.count_nonzero(val, axis=self.axis)
return tf.cast(non_zeros, self.dtype)
def _weighted_sum(val, sample_weight):
if sample_weight is not None:
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1))
return tf.reduce_sum(val, axis=self.axis)

self.true_positives.assign_add(_count_non_zero(y_pred * y_true))
self.false_positives.assign_add(_count_non_zero(y_pred * (y_true - 1)))
self.false_negatives.assign_add(_count_non_zero((y_pred - 1) * y_true))
self.weights_intermediate.assign_add(_count_non_zero(y_true))
self.true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight))
self.false_positives.assign_add(
_weighted_sum(y_pred * (1 - y_true), sample_weight)
)
self.false_negatives.assign_add(
_weighted_sum((1 - y_pred) * y_true, sample_weight)
)
self.weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight))

def result(self):
precision = tf.math.divide_no_nan(
Expand Down
90 changes: 82 additions & 8 deletions tensorflow_addons/metrics/tests/f_scores_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ def test_config_fbeta():
assert fbeta_obj2.dtype == tf.float32


def _test_tf(avg, beta, act, pred, threshold):
def _test_tf(avg, beta, act, pred, sample_weights, threshold):
act = tf.constant(act, tf.float32)
pred = tf.constant(pred, tf.float32)

fbeta = FBetaScore(3, avg, beta, threshold)
fbeta.update_state(act, pred)
fbeta.update_state(act, pred, sample_weights)
return fbeta.result().numpy()


def _test_fbeta_score(actuals, preds, avg, beta_val, result, threshold):
tf_score = _test_tf(avg, beta_val, actuals, preds, threshold)
def _test_fbeta_score(actuals, preds, sample_weights, avg, beta_val, result, threshold):
tf_score = _test_tf(avg, beta_val, actuals, preds, sample_weights, threshold)
np.testing.assert_allclose(tf_score, result, atol=1e-7, rtol=1e-6)


Expand All @@ -58,7 +58,7 @@ def test_fbeta_perfect_score():

for avg_val in ["micro", "macro", "weighted"]:
for beta in [0.5, 1.0, 2.0]:
_test_fbeta_score(actuals, preds, avg_val, beta, 1.0, 0.66)
_test_fbeta_score(actuals, preds, None, avg_val, beta, 1.0, 0.66)


def test_fbeta_worst_score():
Expand All @@ -67,7 +67,7 @@ def test_fbeta_worst_score():

for avg_val in ["micro", "macro", "weighted"]:
for beta in [0.5, 1.0, 2.0]:
_test_fbeta_score(actuals, preds, avg_val, beta, 0.0, 0.66)
_test_fbeta_score(actuals, preds, None, avg_val, beta, 0.0, 0.66)


@pytest.mark.parametrize(
Expand All @@ -90,7 +90,7 @@ def test_fbeta_worst_score():
def test_fbeta_random_score(avg_val, beta, result):
preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]]
actuals = [[0, 0, 1], [1, 1, 0], [1, 1, 1]]
_test_fbeta_score(actuals, preds, avg_val, beta, result, 0.66)
_test_fbeta_score(actuals, preds, None, avg_val, beta, result, 0.66)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -120,7 +120,61 @@ def test_fbeta_random_score_none(avg_val, beta, result):
[0, 0, 1],
]
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
_test_fbeta_score(actuals, preds, avg_val, beta, result, None)
_test_fbeta_score(actuals, preds, None, avg_val, beta, result, None)


@pytest.mark.parametrize(
"avg_val, beta, sample_weights, result",
[
(None, 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.909091, 0.555556, 1.0]),
(None, 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
(None, 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.9375, 0.714286, 1.0]),
(None, 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.8, 0.666667, 1.0]),
(None, 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
(None, 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.857143, 0.8, 1.0]),
(None, 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.714286, 0.833333, 1.0]),
(None, 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),
(None, 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.789474, 0.909091, 1.0]),
("micro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
("micro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("micro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
("micro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
("micro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("micro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
("micro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),
("micro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("micro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),
("macro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.821549),
("macro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
("macro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.883929),
("macro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.822222),
("macro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
("macro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.885714),
("macro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.849206),
("macro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),
("macro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.899522),
("weighted", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.880471),
("weighted", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("weighted", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.917857),
("weighted", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.844444),
("weighted", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("weighted", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.902857),
("weighted", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.829365),
("weighted", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),
("weighted", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.897608),
],
)
def test_fbeta_weighted_random_score_none(avg_val, beta, sample_weights, result):
preds = [
[0.9, 0.1, 0],
[0.2, 0.6, 0.2],
[0, 0, 1],
[0.4, 0.3, 0.3],
[0, 0.9, 0.1],
[0, 0, 1],
]
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
_test_fbeta_score(actuals, preds, sample_weights, avg_val, beta, result, None)


def test_keras_model():
Expand All @@ -147,6 +201,26 @@ def test_eq():
np.testing.assert_allclose(fbeta.result().numpy(), f1.result().numpy())


def test_sample_eq():
f1 = F1Score(3)
f1_weighted = F1Score(3)

preds = [
[0.9, 0.1, 0],
[0.2, 0.6, 0.2],
[0, 0, 1],
[0.4, 0.3, 0.3],
[0, 0.9, 0.1],
[0, 0, 1],
]
actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]
sample_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

f1.update_state(actuals, preds)
f1_weighted(actuals, preds, sample_weights)
np.testing.assert_allclose(f1.result().numpy(), f1_weighted.result().numpy())


def test_keras_model_f1():
f1 = F1Score(5)
utils._get_model(f1, 5)
Expand Down

0 comments on commit 89aed9c

Please sign in to comment.