Skip to content

Commit

Permalink
Some bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
keisen committed Mar 25, 2024
1 parent ca28edf commit 569bbe3
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/tf_keras_vis/utils/model_modifiers_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import tensorflow as tf
from packaging.version import parse as version

from tf_keras_vis.activation_maximization import ActivationMaximization
from tf_keras_vis.saliency import Saliency
Expand Down Expand Up @@ -37,12 +38,18 @@ class TestExtractIntermediateLayer():
])
@pytest.mark.usefixtures("mixed_precision")
def test__call__(self, model, layer, expected_error):
assert model.outputs[0].shape == (None, 2)
if version(tf.version.VERSION) < version("2.15"):
assert model.outputs[0].shape.as_list() == [None, 2]
else:
assert model.outputs[0].shape == (None, 2)
with assert_raises(expected_error):
instance = ActivationMaximization(model,
model_modifier=ExtractIntermediateLayer(layer))
assert instance.model != model
assert instance.model.outputs[0].shape == (None, 6, 6, 6)
if version(tf.version.VERSION) < version("2.15"):
assert instance.model.outputs[0].shape.as_list() == [None, 6, 6, 6]
else:
assert instance.model.outputs[0].shape == (None, 6, 6, 6)
instance([CategoricalScore(0)])


Expand Down

0 comments on commit 569bbe3

Please sign in to comment.