Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] experimental - mapping tags on attributes #313

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,43 @@ def __eq__(self, other):

return deep_equals(self_params, other_params)

def __getattr__(self, attr):
"""Get attribute dunder, defaults to object tags if no attribute found.

In tag names, the following characters are replaced:

* colon by double underscore, i.e., ":": "__"
* dash by single underscore, i.e., "-": "_"
"""
# early stop for reserved attributes to avoid infinite recursion
reserved_attr = attr.endswith("_dynamic")
if reserved_attr:
return object.__getattribute__(self, attr)

# get tags and normalized keys
tag_dict = self.get_tags()

# if attribute is in tag_dict, return tag value
if attr in tag_dict:
return tag_dict[attr]

# not found, now try normalized keys

def norm_key(k):
"""Replace colon by double underscore, dash by single underscore."""
return k.replace(
":",
"__",
).replace("-", "_")

tag_dict_norm = {norm_key(k): v for k, v in tag_dict.items()}

if attr in tag_dict_norm:
return tag_dict_norm[attr]

# otherwise raise the default AttributeError
return object.__getattribute__(self, attr)

def reset(self):
"""Reset the object to a clean post-init state.

Expand Down
22 changes: 22 additions & 0 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,28 @@ def test_get_tag_raises(fixture_tag_class_object: Child):
fixture_tag_class_object.get_tag("bar")


def test_get_tag_attr(
fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]
):
"""Test get_tag mapping on get_attr.

Raises
------
AssertError if inheritance logic in get_tag is incorrect
AssertError if default override logic in get_tag is incorrect
"""
object_tags = {}
object_tags_keys = fixture_object_tags.keys()

for key in object_tags_keys:
object_tags[key] = getattr(fixture_tag_class_object, key)

msg = "Inheritance logic in BaseObject.get_tag is incorrect"

for key in object_tags_keys:
assert object_tags[key] == fixture_object_tags[key], msg


def test_set_tags(
fixture_object_instance_set_tags: Any,
fixture_object_set_tags: Dict[str, Any],
Expand Down
Loading