Hello,
I ran the sample code, and it already fails when IncrementalPFI is initialized. The passed loss function (Accuracy() in the example) is checked by the method validate_loss_function. Inside _get_loss_function_from_river_metric, the metric's update method is called with a dict as y_pred, but in the current river version (0.21) metrics.Accuracy.update no longer expects a dict; it expects a single predicted value.
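To illustrate the mechanism, here is a minimal, self-contained sketch. `ToyConfusionMatrix` is a hypothetical stand-in written for this example, not river's actual class; it only mimics the failing line `self.data[y_true][y_pred] += w` from the last frame of the traceback, where the predicted value is used as a dictionary key. A plain label is hashable and works, while the dict `{0: 0}` that the ixai validator passes is not.

```python
from collections import defaultdict


class ToyConfusionMatrix:
    """Hypothetical stand-in for a confusion matrix (not river's actual class).

    Counts live in a nested dict keyed first by the true label and then by
    the predicted label, so both labels must be hashable.
    """

    def __init__(self):
        self.data = defaultdict(lambda: defaultdict(float))

    def update(self, y_true, y_pred, w=1.0):
        # Same pattern as the failing line: y_pred is used as a dict key.
        self.data[y_true][y_pred] += w


cm = ToyConfusionMatrix()
cm.update(y_true=0, y_pred=0)  # a plain label is hashable, so this works

error_message = None
try:
    # A dict prediction, like the one passed by the ixai validator.
    cm.update(y_true=0, y_pred={0: 0})
except TypeError as err:
    error_message = str(err)

print(error_message)  # reports that the dict is unhashable
```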
Here is the example code:
from river.metrics import Accuracy
from river.forest import ARFClassifier
from river.datasets.synth import Agrawal
from ixai.explainer import IncrementalPFI

stream = Agrawal(classification_function=2)
feature_names = list([x_0 for x_0, _ in stream.take(1)][0].keys())

model = ARFClassifier(n_models=10, max_depth=10, leaf_prediction='mc')

incremental_pfi = IncrementalPFI(
    model_function=model.predict_one,
    loss_function=Accuracy(),
    feature_names=feature_names,
    smoothing_alpha=0.001,
    n_inner_samples=5
)

training_metric = Accuracy()
for (n, (x, y)) in enumerate(stream, start=1):
    training_metric.update(y, model.predict_one(x))  # inference
    incremental_pfi.explain_one(x, y)  # explaining
    model.learn_one(x, y)  # learning
    if n % 1000 == 0:
        print(f"{n}: Accuracy: {training_metric.get():.3f}, PFI: {incremental_pfi.importance_values}")
The error is:
Traceback (most recent call last):
  File "/media/user/main.py", line 10, in <module>
    incremental_pfi = IncrementalPFI(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/explainer/pfi.py", line 68, in __init__
    super(IncrementalPFI, self).__init__(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/explainer/base.py", line 71, in __init__
    self._loss_function = validate_loss_function(loss_function)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/utils/validators/loss.py", line 30, in validate_loss_function
    validated_loss_function = _get_loss_function_from_river_metric(river_metric=loss_function)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/utils/validators/loss.py", line 18, in _get_loss_function_from_river_metric
    _ = river_metric.update(y_true=0, y_pred={0: 0}).revert(y_true=0, y_pred={0: 0})
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/base.py", line 93, in update
    self.cm.update(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/confusion.py", line 67, in update
    self._update(y_true, y_pred, w)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/confusion.py", line 75, in _update
    self.data[y_true][y_pred] += w
TypeError: unhashable type: 'dict'
Oh no... I haven't looked into the code for some time. That's the problem with implementing against a changing API. Is there a quick workaround for you?
I expect these problems to come up more and more. However, at the moment I am not able to maintain this to keep it compatible with the current version of river.
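One possible stopgap, sketched under assumptions and not tested against ixai: wrap the metric so that dict predictions are reduced to a single label before they reach a label-based metric, and have update/revert return the wrapper so the chained `.update(...).revert(...)` call from the traceback works regardless of what the wrapped metric returns. `LabelFromDictMetric` and `ToyAccuracy` are hypothetical names written for this sketch; `ToyAccuracy` stands in for river's Accuracy so the example runs without river installed. The reduction assumes the dict maps candidate labels to scores (the convention of river's predict_proba_one). Whether ixai's validate_loss_function accepts an object that is not a river metric is not checked here.

```python
class ToyAccuracy:
    """Hypothetical label-based accuracy, standing in for river's Accuracy."""

    def __init__(self):
        self.correct = 0.0
        self.total = 0.0

    def update(self, y_true, y_pred, w=1.0):
        self.correct += w * (y_true == y_pred)
        self.total += w

    def revert(self, y_true, y_pred, w=1.0):
        self.correct -= w * (y_true == y_pred)
        self.total -= w

    def get(self):
        return self.correct / self.total if self.total else 0.0


class LabelFromDictMetric:
    """Hypothetical adapter (not part of ixai or river).

    Accepts dict predictions, reduces them to the key with the highest
    score, and forwards a single label to the wrapped metric.
    """

    def __init__(self, metric):
        self.metric = metric

    @staticmethod
    def _to_label(y_pred):
        if isinstance(y_pred, dict):
            # Assumes a {label: score} mapping; an empty dict yields None.
            return max(y_pred, key=y_pred.get) if y_pred else None
        return y_pred

    def update(self, y_true, y_pred, *args, **kwargs):
        self.metric.update(y_true, self._to_label(y_pred), *args, **kwargs)
        return self  # returning self allows chaining

    def revert(self, y_true, y_pred, *args, **kwargs):
        self.metric.revert(y_true, self._to_label(y_pred), *args, **kwargs)
        return self

    def get(self):
        return self.metric.get()


metric = LabelFromDictMetric(ToyAccuracy())
# The same chained call the validator makes, now with a dict prediction.
metric.update(y_true=0, y_pred={0: 0}).revert(y_true=0, y_pred={0: 0})
metric.update(y_true=1, y_pred={0: 0.2, 1: 0.8})  # predicts 1: correct
metric.update(y_true=0, y_pred={0: 0.3, 1: 0.7})  # predicts 1: wrong
print(metric.get())  # 0.5
```

Another option, also untested here, is to pin river to an older release that ixai was developed against.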