Skip to content

Commit

Permalink
Update _safe_updates to use global dict. Adds tests for plots
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Aug 24, 2023
1 parent 54e201d commit 3b54c72
Show file tree
Hide file tree
Showing 14 changed files with 226 additions and 37 deletions.
66 changes: 43 additions & 23 deletions src/datagnosis/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
from datagnosis.utils.constants import DEVICE
from datagnosis.utils.reproducibility import clear_cache, enable_reproducible_results

# Every keyword argument name that Plugin.fit forwards to a plugin's
# _safe_update call. Each entry starts out as None.
# NOTE(review): this dict is module-level and its entries are assigned in place
# (by Plugin.__init__ and Plugin.fit in this file), so values set for one call
# or instance appear to persist into later ones -- confirm this is intended.
UPDATE_PARAMS: Dict[str, Any] = dict.fromkeys(
    (
        "y_pred",
        "y_batch",
        "sample_ids",
        "net",
        "device",
        "logits",
        "targets",
        "probs",
        "indices",
        "data_uncert_class",
    )
)


# Base class for Hardness Classification Methods (HCMs)
class Plugin(metaclass=ABCMeta):
Expand Down Expand Up @@ -89,6 +103,8 @@ def __init__(
log.debug("Fixing seed for reproducibility.")
enable_reproducible_results(0)
log.debug(f"Initialized parent plugin for {self.name()}")
# Update UPDATE_PARAMS
UPDATE_PARAMS["device"] = self.device

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -194,11 +210,15 @@ def fit(
indices,
batch_idx,
) = utils.load_update_values_from_cache(update_values_cache_file)

if self.update_point == "mid-epoch":
log.debug("Updating scores mid-epoch")
self._safe_update(
y_pred=outputs, y_batch=true_label, sample_ids=indices
)
# Update UPDATE_PARAMS before calling _safe_update() for mid-epoch
UPDATE_PARAMS["y_pred"] = outputs
UPDATE_PARAMS["y_batch"] = true_label
UPDATE_PARAMS["sample_ids"] = indices
UPDATE_PARAMS["net"] = self.model
self._safe_update(**UPDATE_PARAMS)
else:
if (
use_caches_if_exist
Expand All @@ -220,9 +240,12 @@ def fit(

if self.update_point == "mid-epoch":
log.debug("Updating scores mid-epoch")
self._safe_update(
y_pred=outputs, y_batch=true_label, sample_ids=indices
)
# Update UPDATE_PARAMS before calling _safe_update() for mid-epoch
UPDATE_PARAMS["y_pred"] = outputs
UPDATE_PARAMS["y_batch"] = true_label
UPDATE_PARAMS["sample_ids"] = indices
UPDATE_PARAMS["net"] = self.model
self._safe_update(**UPDATE_PARAMS)
loss = self.criterion(outputs, true_label)
loss.backward()
self.optimizer.step()
Expand Down Expand Up @@ -286,25 +309,23 @@ def fit(

if self.update_point == "per-epoch":
log.debug(f"Updating plugin after epoch {epoch+1}")
self._safe_update(
net=self.model,
device=self.device,
logits=logits,
targets=targets,
probs=probs,
indices=indices,
)
# Update UPDATE_PARAMS before calling _safe_update() for per-epoch
UPDATE_PARAMS["net"] = self.model
UPDATE_PARAMS["logits"] = logits
UPDATE_PARAMS["targets"] = targets
UPDATE_PARAMS["probs"] = probs
UPDATE_PARAMS["indices"] = indices
self._safe_update(**UPDATE_PARAMS)

if self.update_point == "post-epoch":
log.debug("Updating plugin after training")
self._safe_update(
net=self.model,
data_uncert_class=self.data_uncert_class,
device=self.device,
logits=logits,
targets=targets,
probs=probs,
)
# Update UPDATE_PARAMS before calling _safe_update() for post-epoch
UPDATE_PARAMS["net"] = self.model
UPDATE_PARAMS["logits"] = logits
UPDATE_PARAMS["targets"] = targets
UPDATE_PARAMS["probs"] = probs
UPDATE_PARAMS["data_uncert_class"] = self.data_uncert_class
self._safe_update(**UPDATE_PARAMS)
self.has_been_fit = True
self.compute_scores()
log.debug("Plugin fit complete and scores computed.")
Expand Down Expand Up @@ -447,7 +468,6 @@ def _extract_datapoints_by_top_n(
log.info(extracted)
extracted = sorted(extracted)
log.info(extracted)
print("top_n", self.dataloader_unshuffled.dataset[extracted])
return (
self.dataloader_unshuffled.dataset[extracted], # pyright: ignore
extraction_scores[extracted],
Expand Down
26 changes: 13 additions & 13 deletions tests/plugins/core/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,19 @@ def test_mock_plugin_extract_all_methods() -> None:
[0.10, 0.16, 0.15, 0.14, 0.13, 0.12, 0.11, 0.14, 0.12, 0.06]
), np.asarray([0.01, 0.16, 0.15, 0.14, 0.13, 0.12, 0.11, 0.14, 0.12, 0.12])

# # extract datapoints
# # extracted have format: ((features, Labels, Indices), scores)
# extract_indices = [0, 1, 5]
# extracted = plugin.extract_datapoints(method="index", indices=extract_indices)
# assert isinstance(extracted, tuple)
# assert isinstance(extracted[0][0], torch.Tensor)
# assert isinstance(extracted[0][1], torch.Tensor)
# assert isinstance(extracted[0][2], List)
# assert isinstance(extracted[1], np.ndarray)
# assert extracted[0][0].shape[0] == len(extract_indices)
# assert extracted[0][1].shape[0] == len(extract_indices)
# assert len(extracted[0][2]) == len(extract_indices)
# assert extracted[1].shape[0] == len(extract_indices)
# extract datapoints
# extracted have format: ((features, Labels, Indices), scores)
extract_indices = [0, 1, 5]
extracted = plugin.extract_datapoints(method="index", indices=extract_indices)
assert isinstance(extracted, tuple)
assert isinstance(extracted[0][0], torch.Tensor)
assert isinstance(extracted[0][1], torch.Tensor)
assert isinstance(extracted[0][2], List)
assert isinstance(extracted[1], np.ndarray)
assert extracted[0][0].shape[0] == len(extract_indices)
assert extracted[0][1].shape[0] == len(extract_indices)
assert len(extracted[0][2]) == len(extract_indices)
assert extracted[1].shape[0] == len(extract_indices)

extracted = plugin.extract_datapoints(method="top_n", n=3, sort_by_index=True)
assert isinstance(extracted, tuple)
Expand Down
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_aum.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert len(scores) == len(y)
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_conf_agree.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_confident_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_data_iq.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert scores[0].dtype in [np.float32, np.float64]
assert scores[1].dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores[0]])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_data_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert scores[0].dtype in [np.float32, np.float64]
assert scores[1].dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores[0]])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_el2n.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_forgetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_grand.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_large_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_prototypicality.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
14 changes: 14 additions & 0 deletions tests/plugins/generic/test_vog.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]
assert all([0.0 <= score <= 1.0 for score in scores])


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on the iris data, then call plot_scores on it.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    features, labels = load_iris(return_X_y=True, as_frame=True)
    handler = DataHandler(features, labels, batch_size=32)  # pyright: ignore
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)
17 changes: 16 additions & 1 deletion tests/plugins/images/test_allsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,24 @@ def test_plugin_scores(test_plugin: Plugin) -> None:
datahandler=datahander,
use_caches_if_exist=False,
workspace="test_workspace",
epochs=2,
)
scores = test_plugin.scores
assert len(scores) == len(y)
assert isinstance(scores, np.ndarray)
assert scores.dtype in [np.float32, np.float64]


@pytest.mark.parametrize(
    "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args_mnist_lenet)
)
def test_plugin_plots(test_plugin: Plugin) -> None:
    """Fit the plugin on a 100-sample slice of MNIST, then call plot_scores.

    Smoke test: there are no assertions, so it passes as long as neither
    fit nor plot_scores raises.
    """
    images, labels, _, _ = load_mnist()
    # Only the first 100 entries of each array are used.
    head = slice(100)
    handler = DataHandler(  # pyright: ignore
        images[head], labels[head], batch_size=32
    )
    # Caches are explicitly disabled for this fit.
    test_plugin.fit(
        datahandler=handler,
        use_caches_if_exist=False,
        workspace="test_workspace",
    )
    test_plugin.plot_scores(show=False)

0 comments on commit 3b54c72

Please sign in to comment.