diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 3b0ca27c..acf35e2b 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -104,6 +104,43 @@ def __eq__(self, other): return deep_equals(self_params, other_params) + def __getattr__(self, attr): + """Get attribute dunder, defaults to object tags if no attribute found. + + In tag names, the following characters are replaced: + + * colon by double underscore, i.e., ":": "__" + * dash by single underscore, i.e., "-": "_" + """ + # early stop for reserved attributes to avoid infinite recursion + reserved_attr = attr.endswith("_dynamic") + if reserved_attr: + return object.__getattribute__(self, attr) + + # get tags and normalized keys + tag_dict = self.get_tags() + + # if attribute is in tag_dict, return tag value + if attr in tag_dict: + return tag_dict[attr] + + # not found, now try normalized keys + + def norm_key(k): + """Replace colon by double underscore, dash by single underscore.""" + return k.replace( + ":", + "__", + ).replace("-", "_") + + tag_dict_norm = {norm_key(k): v for k, v in tag_dict.items()} + + if attr in tag_dict_norm: + return tag_dict_norm[attr] + + # otherwise raise the default AttributeError + return object.__getattribute__(self, attr) + def reset(self): """Reset the object to a clean post-init state. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index d3ed95b8..fe810c95 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -365,6 +365,28 @@ def test_get_tag_raises(fixture_tag_class_object: Child): fixture_tag_class_object.get_tag("bar") +def test_get_tag_attr( + fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any] +): + """Test get_tag mapping on get_attr. + + Raises + ------ + AssertError if inheritance logic in get_tag is incorrect + AssertError if default override logic in get_tag is incorrect + """ + object_tags = {} + object_tags_keys = fixture_object_tags.keys() + + for key in object_tags_keys: + object_tags[key] = getattr(fixture_tag_class_object, key) + + msg = "Inheritance logic in BaseObject.get_tag is incorrect" + + for key in object_tags_keys: + assert object_tags[key] == fixture_object_tags[key], msg + + def test_set_tags( fixture_object_instance_set_tags: Any, fixture_object_set_tags: Dict[str, Any],