diff --git a/tests/test_projection.py b/tests/test_projection.py index 801a2c1..4f4806a 100644 --- a/tests/test_projection.py +++ b/tests/test_projection.py @@ -15,7 +15,7 @@ import pytest import torch from geoopt.manifolds import PoincareBall -from hierarchy_transformers.models.arithmetic import project_onto_subspace, reflect_about_subspace +from hierarchy_transformers.models.hierarchy_transformer.hyperbolic import project_onto_subspace, reflect_about_subspace @pytest.fixture def manifold():