diff --git a/examples/partials/README.md b/examples/partials/README.md new file mode 100644 index 00000000..560767d1 --- /dev/null +++ b/examples/partials/README.md @@ -0,0 +1,2 @@ +# Partials - Configuring arbitrary classes / callables + diff --git a/examples/partials/partials_example.py b/examples/partials/partials_example.py new file mode 100644 index 00000000..34d30a23 --- /dev/null +++ b/examples/partials/partials_example.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from simple_parsing import ArgumentParser +from simple_parsing.helpers import subgroups +from simple_parsing.helpers.partial import Partial, config_for + + +# Suppose we want to choose between the Adam and SGD optimizers from PyTorch: +# (NOTE: We don't import pytorch here, so we just create the types to illustrate) +class Optimizer: + def __init__(self, params): + ... + + +class Adam(Optimizer): + def __init__( + self, + params, + lr: float = 3e-4, + beta1: float = 0.9, + beta2: float = 0.999, + eps: float = 1e-08, + ): + self.params = params + self.lr = lr + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + + +class SGD(Optimizer): + def __init__( + self, + params, + lr: float = 3e-4, + weight_decay: float | None = None, + momentum: float = 0.9, + eps: float = 1e-08, + ): + self.params = params + self.lr = lr + self.weight_decay = weight_decay + self.momentum = momentum + self.eps = eps + + +# Dynamically create a dataclass that will be used for the above type: +# NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a +# required argument. +# AdamConfig = Partial[Adam] # would treat 'params' as a required argument. +# SGDConfig = Partial[SGD] # same here +AdamConfig: type[Partial[Adam]] = config_for(Adam, ignore_args="params") +SGDConfig: type[Partial[SGD]] = config_for(SGD, ignore_args="params") + + +@dataclass +class Config: + + # Which optimizer to use. 
+ optimizer: Partial[Optimizer] = subgroups( + { + "sgd": SGDConfig, + "adam": AdamConfig, + }, + default_factory=AdamConfig, + ) + + +parser = ArgumentParser() +parser.add_arguments(Config, "config") +args = parser.parse_args() + + +config: Config = args.config +print(config) +expected = "Config(optimizer=AdamConfig(lr=0.0003, beta1=0.9, beta2=0.999, eps=1e-08))" + +my_model_parameters = [123] # nn.Sequential(...).parameters() + +optimizer = config.optimizer(params=my_model_parameters) +print(vars(optimizer)) +expected += """ +{'params': [123], 'lr': 0.0003, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08} +""" diff --git a/requirements.txt b/requirements.txt index 04655753..9fa34b4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ docstring-parser~=0.15 -typing_extensions>=4.3.0 +typing_extensions>=4.5.0 diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index 1fb03db3..79aa1a34 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -6,8 +6,10 @@ from .decorators import main from .help_formatter import SimpleHelpFormatter from .helpers import ( + Partial, Serializable, choice, + config_for, field, flag, list_field, @@ -31,6 +33,7 @@ "ArgumentGenerationMode", "ArgumentParser", "choice", + "config_for", "ConflictResolution", "DashVariant", "field", @@ -44,6 +47,7 @@ "parse_known_args", "parse", "ParsingError", + "Partial", "replace", "Serializable", "SimpleHelpFormatter", diff --git a/simple_parsing/helpers/__init__.py b/simple_parsing/helpers/__init__.py index f5b44c63..a9717151 100644 --- a/simple_parsing/helpers/__init__.py +++ b/simple_parsing/helpers/__init__.py @@ -2,6 +2,7 @@ from .fields import * from .flatten import FlattenedAccess from .hparams import HyperParameters +from .partial import Partial, config_for from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode try: diff --git a/simple_parsing/helpers/nested_partial.py b/simple_parsing/helpers/nested_partial.py new file 
import functools
from typing import Any, Generic, TypeVar

_T = TypeVar("_T")


class npartial(functools.partial, Generic[_T]):
    """A `functools.partial` that first calls any `npartial` found among its arguments.

    Every positional or keyword argument that is itself an `npartial` is invoked (with no
    arguments) each time this object is called, and its result is what gets passed to the wrapped
    function. Arguments of any other type, including plain `functools.partial` objects, are
    passed through untouched.

    This makes it possible to describe nested default factories, e.g.:


    >>> from dataclasses import dataclass, field
    >>> @dataclass
    ... class Value:
    ...     v: int = 0
    >>> @dataclass
    ... class ValueWrapper:
    ...     value: Value
    ...
    >>> from functools import partial
    >>> @dataclass
    ... class WithRegularPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=partial(ValueWrapper, value=Value(v=123)),
    ...     )

    With a regular partial, the inner `Value` is created once and shared by every instance:

    >>> WithRegularPartial().wrapped.value is WithRegularPartial().wrapped.value
    True
    >>> @dataclass
    ... class WithNPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=npartial(ValueWrapper, value=npartial(Value, v=123)),
    ...     )
    >>> WithNPartial().wrapped.value is WithNPartial().wrapped.value
    False

    With nested `npartial`s, a fresh `Value` is built on every call.
    """

    def __call__(self, *args: Any, **keywords: Any) -> _T:
        def _resolve(value: Any) -> Any:
            # Only npartial instances are evaluated; everything else is forwarded as-is.
            return value() if isinstance(value, npartial) else value

        # Stored arguments come first, call-time arguments are appended / override by key.
        positional = tuple(_resolve(arg) for arg in (*self.args, *args))
        merged = {**self.keywords, **keywords}
        named = {key: _resolve(value) for key, value in merged.items()}
        return self.func(*positional, **named)
""" +from __future__ import annotations + +import dataclasses +import functools +import inspect +import typing +from dataclasses import make_dataclass +from functools import lru_cache, singledispatch, wraps +from logging import getLogger as get_logger +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Sequence, + _ProtocolMeta, + cast, + get_type_hints, +) + +from typing_extensions import ParamSpec, TypeVar + +import simple_parsing + +__all__ = ["Partial", "adjust_default", "config_for", "infer_type_annotation_from_default"] + +C = TypeVar("C", bound=Callable) +_P = ParamSpec("_P") +_T = TypeVar("_T", bound=Any) +_C = TypeVar("_C", bound=Callable[..., Any]) + +logger = get_logger(__name__) + + +@singledispatch +def adjust_default(default: Any) -> Any: + """Used the adjust the default value of a parameter that we extract from the signature. + + IF in some libraries, the signature has a special default value, that we shouldn't use as the + default, e.g. "MyLibrary.REQUIRED" or something, then a handler can be registered here to + convert it to something else. 
+ + For example, here's a fix for the `lr` param of the `torch.optim.SGD` optimizer, which has a + weird annotation of `_RequiredParameter`: + + ```python + from torch.optim.optimizer import _RequiredParameter + + @adjust_default.register(_RequiredParameter) + def _(default: Any) -> Any: + return dataclasses.MISSING + ``` + """ + return default + + +_P = ParamSpec("_P") +_OutT = TypeVar("_OutT") + + +def _cache_when_possible(fn: Callable[_P, _OutT]) -> Callable[_P, _OutT]: + """Makes `fn` behave like `functools.cache(fn)` when args are all hashable, else no change.""" + cached_fn = lru_cache(maxsize=None)(fn) + + def _all_hashable(args: tuple, kwargs: dict) -> bool: + return all(isinstance(arg, Hashable) for arg in args) and all( + isinstance(arg, Hashable) for arg in kwargs.values() + ) + + @wraps(fn) + def _switch(*args: _P.args, **kwargs: _P.kwargs) -> _OutT: + if _all_hashable(args, kwargs): + hashable_kwargs = typing.cast(Dict[str, Hashable], kwargs) + return cached_fn(*args, **hashable_kwargs) + return fn(*args, **kwargs) + + return _switch + + +@_cache_when_possible +def config_for( + cls: type[_T] | Callable[_P, _T], + ignore_args: str | Sequence[str] = (), + frozen: bool = True, + **defaults, +) -> type[Partial[_T]]: + """Create a dataclass that contains the arguments for the constructor of `cls`. + + Example: + + >>> import dataclasses + >>> import simple_parsing as sp + >>> class Adam: # i.e. `torch.optim.Adam`, which we don't have installed in this example. + ... def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)): + ... self.params = params + ... self.lr = lr + ... self.betas = betas + ... def __repr__(self) -> str: + ... return f"Adam(params={self.params}, lr={self.lr}, betas={self.betas})" + ... 
+ >>> AdamConfig = sp.config_for(Adam, ignore_args="params") + >>> parser = sp.ArgumentParser() + >>> _ = parser.add_arguments(AdamConfig, dest="optimizer") + + + >>> args = parser.parse_args(["--lr", "0.1", "--betas", "0.1", "0.2"]) + >>> args.optimizer + AdamConfig(lr=0.1, betas=(0.1, 0.2)) + + The return dataclass is a subclass of `functools.partial` that returns the `Adam` object: + + >>> isinstance(args.optimizer, functools.partial) + True + >>> dataclasses.is_dataclass(args.optimizer) + True + >>> args.optimizer(params=[1, 2, 3]) + Adam(params=[1, 2, 3], lr=0.1, betas=(0.1, 0.2)) + + >>> parser.print_help() # doctest: +SKIP + usage: pytest [-h] [--lr float] [--betas float float] + + options: + -h, --help show this help message and exit + + AdamConfig ['optimizer']: + Auto-Generated configuration dataclass for simple_parsing.helpers.partial.Adam + + --lr float + --betas float float + """ + if isinstance(ignore_args, str): + ignore_args = (ignore_args,) + else: + ignore_args = tuple(ignore_args) + + assert isinstance(defaults, dict) + + signature = inspect.signature(cls) + + fields: list[tuple[str, type, dataclasses.Field]] = [] + + class_annotations = get_type_hints(cls) + + class_docstring_help = _parse_args_from_docstring(cls.__doc__ or "") + if inspect.isclass(cls): + class_constructor_help = _parse_args_from_docstring(cls.__init__.__doc__ or "") + else: + class_constructor_help = {} + + for name, parameter in signature.parameters.items(): + default = defaults.get(name, parameter.default) + if default is parameter.empty: + default = dataclasses.MISSING + default = adjust_default(default) + + if name in ignore_args: + logger.debug(f"Ignoring argument {name}") + continue + + if parameter.annotation is not inspect.Parameter.empty: + field_type = parameter.annotation + elif name in class_annotations: + field_type = class_annotations[name] + elif default is not dataclasses.MISSING: + # Infer the type from the default value. 
+ field_type = infer_type_annotation_from_default(default) + else: + logger.warning( + f"Don't know what the type of field '{name}' of class {cls} is! " + f"Ignoring this argument." + ) + continue + + class_help_entries = {v for k, v in class_docstring_help.items() if k.startswith(name)} + init_help_entries = {v for k, v in class_constructor_help.items() if k.startswith(name)} + help_entries = init_help_entries or class_help_entries + if help_entries: + help_str = help_entries.pop() + else: + help_str = "" + + if default is dataclasses.MISSING: + field = simple_parsing.field(help=help_str, required=True) + # insert since fields without defaults need to go first. + fields.insert(0, (name, field_type, field)) + logger.debug(f"Adding required field: {fields[0]}") + else: + field = simple_parsing.field(default=default, help=help_str) + fields.append((name, field_type, field)) + logger.debug(f"Adding optional field: {fields[-1]}") + + cls_name = _get_generated_config_class_name(cls) + config_class = make_dataclass(cls_name=cls_name, bases=(Partial,), fields=fields, frozen=frozen) + config_class._target_ = cls + config_class.__doc__ = ( + f"Auto-Generated configuration dataclass for {cls.__module__}.{cls.__qualname__}\n" + + (cls.__doc__ or "") + ) + + return config_class + + +@singledispatch +def infer_type_annotation_from_default(default: Any) -> Any | type: + """Used when there is a default value, but no type annotation, to infer the type of field to + create on the config dataclass. + """ + if isinstance(default, (int, str, float, bool)): + return type(default) + if isinstance(default, tuple): + return typing.Tuple[tuple(infer_type_annotation_from_default(d) for d in default)] + if isinstance(default, list): + if not default: + return list + # Assuming that all items have the same type. 
def _parse_args_from_docstring(docstring: str) -> dict[str, str]:
    """Extract the per-argument descriptions from a Google-style docstring.

    Taken from `pytorch_lightning.utilities.argparse`.

    Looks for a line starting with "Args:", "Arguments:" or "Parameters:". Entries are expected
    exactly 4 spaces deeper than that header, in the form `name: description`. Lines indented
    further than that are treated as continuations of the previous entry. Parsing stops at the
    first non-empty line that is indented less than the entries.

    Args:
        docstring: The docstring to parse. May be empty.

    Returns:
        A dict mapping the text before the first colon of each entry to its description, with
        continuation lines joined by single spaces. Empty if there is no argument section.
    """
    arg_block_indent = None
    # Key of the entry that continuation lines should be appended to, if any.
    current_arg = None
    parsed = {}
    for line in docstring.split("\n"):
        stripped = line.lstrip()
        if not stripped:
            continue
        line_indent = len(line) - len(stripped)
        if stripped.startswith(("Args:", "Arguments:", "Parameters:")):
            arg_block_indent = line_indent + 4
        elif arg_block_indent is None:
            continue
        elif line_indent < arg_block_indent:
            break
        elif line_indent == arg_block_indent:
            name, separator, description = stripped.partition(":")
            if not separator:
                # Not a `name: description` entry (free-form text inside the section). Skip it,
                # and don't glue its continuation lines onto the previous entry.
                current_arg = None
                continue
            current_arg = name
            parsed[current_arg] = description.lstrip()
        elif current_arg is not None:
            # Deeper-indented line: continuation of the current entry's description.
            parsed[current_arg] += f" {stripped}"
    return parsed
class Partial(functools.partial, Generic[_T], metaclass=_Partial):
    """A `functools.partial` whose target is stored on the class as `_target_`.

    Calling an instance reads the values of its dataclass fields (via `dataclasses.fields`), so
    instances are expected to be dataclass instances, then calls `type(self)._target_` with those
    values as keyword arguments.
    """

    def __new__(cls, __func: Callable[_P, _T] | None = None, *args: _P.args, **kwargs: _P.kwargs):
        # Fall back to the target stored on the class when no function is given.
        _func = __func or cls._target_
        assert _func is not None
        return super().__new__(cls, _func, *args, **kwargs)

    def __call__(self: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs) -> _T:
        """Call the target with the field values, overridden by any call-time `kwargs`."""
        constructor_kwargs = {
            field.name: getattr(self, field.name) for field in dataclasses.fields(self)
        }
        constructor_kwargs.update(kwargs)
        # TODO: Use `nested_partial` as a base class? (to instantiate all the partials inside as
        # well?)
        self = cast(Partial, self)
        return type(self)._target_(*args, **constructor_kwargs)

    def __getattr__(self, name: str):
        """Fall back to the stored partial keywords for attributes not found normally.

        Raises:
            AttributeError: If `name` is not one of the stored keywords. Without this, every
                unknown attribute would silently evaluate to `None` and `hasattr` would always
                be true.
        """
        if name in self.keywords:
            return self.keywords[name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
_type_: test_encoding.B + _type_: test.helpers.test_encoding.B a: 123 b: hey diff --git a/test/helpers/test_encoding/test_encoding_with_dc_types__yaml_obj1_.yaml b/test/helpers/test_encoding/test_encoding_with_dc_types__yaml_obj1_.yaml index d2b25fef..3e91ca4d 100644 --- a/test/helpers/test_encoding/test_encoding_with_dc_types__yaml_obj1_.yaml +++ b/test/helpers/test_encoding/test_encoding_with_dc_types__yaml_obj1_.yaml @@ -1,6 +1,6 @@ -_type_: test_encoding.Container +_type_: test.helpers.test_encoding.Container item: - _type_: test_encoding.BB + _type_: test.helpers.test_encoding.BB a: 123 b: hey extra_field: 111 diff --git a/test/helpers/test_partial.py b/test/helpers/test_partial.py new file mode 100644 index 00000000..50dfc7d8 --- /dev/null +++ b/test/helpers/test_partial.py @@ -0,0 +1,122 @@ +import functools +from collections.abc import Hashable +from dataclasses import dataclass, fields, is_dataclass + +import simple_parsing as sp +from simple_parsing import ArgumentParser +from simple_parsing.helpers.partial import Partial + +from ..testutils import TestSetup + + +@dataclass +class Foo: + a: int = 1 + b: int = 2 + + +def some_function(v1: int = 123, v2: int = 456): + """Gives back the mean of two numbers.""" + return (v1 + v2) / 2 + + +def test_partial_class_attribute(): + @dataclass + class Bob(TestSetup): + foo_factory: Partial[Foo] + + parser = ArgumentParser() + parser.add_arguments(Bob, dest="bob") + args = parser.parse_args("--a 456".split()) + bob = args.bob + foo_factory: Partial[Foo] = bob.foo_factory + assert foo_factory.a == 456 + assert foo_factory.b == 2 + assert str(foo_factory) == "FooConfig(a=456, b=2)" + + foo = foo_factory() + assert foo == Foo(a=456, b=2) + assert is_dataclass(foo_factory) + assert isinstance(foo_factory, functools.partial) + + +def test_partial_function_attribute(): + @dataclass + class Bob(TestSetup): + some_fn: Partial[some_function] # type: ignore + + bob = Bob.setup("--v2 781") + assert str(bob.some_fn) == 
"some_function_config(v1=123, v2=781)" + assert bob.some_fn() == some_function(v2=781) + assert bob.some_fn(v1=3, v2=7) == some_function(3, 7) + + +def test_dynamic_classes_are_cached(): + assert Partial[Foo] is Partial[Foo] + + +def test_pickling(): + # TODO: Test that we can pickle / unpickle these dynamic classes objects. + + import pickle + + dynamic_class = Partial[some_function] + + serialized = pickle.dumps(dynamic_class) + + deserialized = pickle.loads(serialized) + assert deserialized is dynamic_class + + +def some_function_with_required_arg(required_arg, v1: int = 123, v2: int = 456): + """Gives back the mean of two numbers.""" + return required_arg, (v1 + v2) / 2 + + +@dataclass +class FooWithRequiredArg(TestSetup): + some_fn: Partial[some_function_with_required_arg] + + +def test_partial_for_fn_with_required_args(): + bob = FooWithRequiredArg.setup("--v1 1 --v2 2") + assert is_dataclass(bob.some_fn) + assert isinstance(bob.some_fn, functools.partial) + + assert "required_arg" not in [f.name for f in fields(bob.some_fn)] + assert bob.some_fn(123) == (123, 1.5) + + +def test_getattr(): + bob = FooWithRequiredArg.setup("--v1 1 --v2 2") + some_fn_partial = bob.some_fn + assert some_fn_partial.v1 == 1 + assert some_fn_partial.v2 == 2 + + +def test_works_with_frozen_instances_as_default(): + @dataclass + class A: + x: int + y: bool = True + + AConfig = sp.config_for(A, ignore_args="x", frozen=True) + + a1_config = AConfig(y=False) + a2_config = AConfig(y=True) + + assert isinstance(a1_config, functools.partial) + assert isinstance(a1_config, Hashable) + + @dataclass(frozen=True) + class ParentConfig: + a: Partial[A] = sp.subgroups( + { + "a1": a1_config, + "a2": a2_config, + }, + default=a2_config, + ) + + b = sp.parse(ParentConfig, args="--a a2") + assert b.a(x=1) == A(x=1, y=a2_config.y) diff --git a/test/helpers/test_partial_postponed.py b/test/helpers/test_partial_postponed.py new file mode 100644 index 00000000..937716db --- /dev/null +++ 
b/test/helpers/test_partial_postponed.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pickle +from dataclasses import dataclass, is_dataclass +from test.testutils import TestSetup + +from simple_parsing import ArgumentParser +from simple_parsing.helpers.partial import Partial + + +@dataclass +class Foo: + a: int = 1 + b: int = 2 + + +def some_function(v1: int = 123, v2: int = 456): + """Gives back the mean of two numbers.""" + return (v1 + v2) / 2 + + +def test_partial_class_attribute(): + @dataclass + class Bob(TestSetup): + foo_factory: Partial[Foo] + + parser = ArgumentParser() + parser.add_arguments(Bob, dest="bob") + args = parser.parse_args("--a 456".split()) + bob = args.bob + foo_factory: Partial[Foo] = bob.foo_factory + assert is_dataclass(foo_factory) + assert foo_factory.a == 456 + assert foo_factory.b == 2 + assert str(foo_factory) == "FooConfig(a=456, b=2)" + + +def test_partial_function_attribute(): + @dataclass + class Bob(TestSetup): + some_fn: Partial[some_function] + + bob = Bob.setup("--v2 781") + assert str(bob.some_fn) == "some_function_config(v1=123, v2=781)" + assert bob.some_fn() == some_function(v2=781) + assert bob.some_fn(v1=3, v2=7) == some_function(3, 7) + + +def test_dynamic_classes_are_cached(): + assert Partial[Foo] is Partial[Foo] + + +# bob = Bob(foo_factory=Foo, some_fn=some_function) + + +def test_pickling(): + # TODO: Test that we can pickle / unpickle these dynamic classes objects. + dynamic_class = Partial[some_function] + + serialized = pickle.dumps(dynamic_class) + + deserialized = pickle.loads(serialized) + assert deserialized is dynamic_class