Partials - Dynamic Config Dataclasses for arbitrary callables (#156)
* Partials feature POC

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Functools black magic, partials are pickleable

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add postponed annotation version of test

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Apply pre-commit hooks to partial.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix example, rename typevars

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Add comments in the partials_example.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix the partials_example.py file

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Add `nested_partial` helper function

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Tweak the partials_example.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix issue with using functools.partial[T] in py37

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Adding some more tests for Partial

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Simplify `partial.py` a bit

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Add test from PR suggestion, add `sp.config_for`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix missing ``` in docstring

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove torch.optim.SGD, fix an old BUG comment

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Improve docstring of `config_for`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add `adjust_default` in __all__

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix import issue in test_partial_postponed.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove kw_only which appeared in py>=3.9

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Update regression files (idk why though?!)

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Actually use a frozen instance as default in test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add `frozen` argument that gets passed through

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix doctest

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
lebrice authored Apr 19, 2023
1 parent 65b07f3 commit 68e16b2
Showing 14 changed files with 647 additions and 7 deletions.
2 changes: 2 additions & 0 deletions examples/partials/README.md
@@ -0,0 +1,2 @@
# Partials - Configuring arbitrary classes / callables

86 changes: 86 additions & 0 deletions examples/partials/partials_example.py
@@ -0,0 +1,86 @@
from __future__ import annotations

from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing.helpers import subgroups
from simple_parsing.helpers.partial import Partial, config_for


# Suppose we want to choose between the Adam and SGD optimizers from PyTorch:
# (NOTE: We don't import pytorch here, so we just create the types to illustrate)
class Optimizer:
    def __init__(self, params):
        ...


class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps


class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        weight_decay: float | None = None,
        momentum: float = 0.9,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.eps = eps


# Dynamically create a dataclass that will be used for the above type:
# NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a
# required argument.
# AdamConfig = Partial[Adam]  # would treat 'params' as a required argument.
# SGDConfig = Partial[SGD]  # same here
AdamConfig: type[Partial[Adam]] = config_for(Adam, ignore_args="params")
SGDConfig: type[Partial[SGD]] = config_for(SGD, ignore_args="params")


@dataclass
class Config:

    # Which optimizer to use.
    optimizer: Partial[Optimizer] = subgroups(
        {
            "sgd": SGDConfig,
            "adam": AdamConfig,
        },
        default_factory=AdamConfig,
    )


parser = ArgumentParser()
parser.add_arguments(Config, "config")
args = parser.parse_args()


config: Config = args.config
print(config)
expected = "Config(optimizer=AdamConfig(lr=0.0003, beta1=0.9, beta2=0.999, eps=1e-08))"

my_model_parameters = [123]  # nn.Sequential(...).parameters()

optimizer = config.optimizer(params=my_model_parameters)
print(vars(optimizer))
expected += """
{'params': [123], 'lr': 0.0003, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
"""
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
 docstring-parser~=0.15
-typing_extensions>=4.3.0
+typing_extensions>=4.5.0
4 changes: 4 additions & 0 deletions simple_parsing/__init__.py
@@ -6,8 +6,10 @@
 from .decorators import main
 from .help_formatter import SimpleHelpFormatter
 from .helpers import (
+    Partial,
     Serializable,
     choice,
+    config_for,
     field,
     flag,
     list_field,
@@ -31,6 +33,7 @@
     "ArgumentGenerationMode",
     "ArgumentParser",
     "choice",
+    "config_for",
     "ConflictResolution",
     "DashVariant",
     "field",
@@ -44,6 +47,7 @@
     "parse_known_args",
     "parse",
     "ParsingError",
+    "Partial",
     "replace",
     "Serializable",
     "SimpleHelpFormatter",
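
With `Partial` and `config_for` exported from the package root, the commit message's `sp.config_for` spelling works directly. A minimal sketch (the `Adam` stub is a stand-in mirroring the example file above, not part of the diff):

```python
from __future__ import annotations

import simple_parsing as sp


# Stand-in for torch.optim.Adam, as in examples/partials/partials_example.py.
class Adam:
    def __init__(self, params, lr: float = 3e-4):
        self.params = params
        self.lr = lr


# The helpers are now importable from the top-level package.
AdamConfig: type[sp.Partial[Adam]] = sp.config_for(Adam, ignore_args="params")
print(AdamConfig())  # -> AdamConfig(lr=0.0003)
```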
1 change: 1 addition & 0 deletions simple_parsing/helpers/__init__.py
@@ -2,6 +2,7 @@
 from .fields import *
 from .flatten import FlattenedAccess
 from .hparams import HyperParameters
+from .partial import Partial, config_for
 from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode

 try:
48 changes: 48 additions & 0 deletions simple_parsing/helpers/nested_partial.py
@@ -0,0 +1,48 @@
import functools
from typing import Any, Generic, TypeVar

_T = TypeVar("_T")


class npartial(functools.partial, Generic[_T]):
    """Partial that also invokes partials in args and kwargs before feeding them to the function.

    Useful for creating nested partials, e.g.:

    >>> from dataclasses import dataclass, field
    >>> @dataclass
    ... class Value:
    ...     v: int = 0
    >>> @dataclass
    ... class ValueWrapper:
    ...     value: Value
    ...
    >>> from functools import partial
    >>> @dataclass
    ... class WithRegularPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=partial(ValueWrapper, value=Value(v=123)),
    ...     )

    Here's the problem: This here is BAD! They both share the same instance of Value!

    >>> WithRegularPartial().wrapped.value is WithRegularPartial().wrapped.value
    True

    >>> @dataclass
    ... class WithNPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=npartial(ValueWrapper, value=npartial(Value, v=123)),
    ...     )
    >>> WithNPartial().wrapped.value is WithNPartial().wrapped.value
    False

    This is fine now!
    """

    def __call__(self, *args: Any, **keywords: Any) -> _T:
        keywords = {**self.keywords, **keywords}
        args = self.args + args
        args = tuple(arg() if isinstance(arg, npartial) else arg for arg in args)
        keywords = {k: v() if isinstance(v, npartial) else v for k, v in keywords.items()}
        return self.func(*args, **keywords)
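
Outside the doctest, a small usage sketch of `npartial` (reusing the `Value` and `ValueWrapper` dataclasses from the docstring; the import path is the new module added above):

```python
from dataclasses import dataclass

from simple_parsing.helpers.nested_partial import npartial


@dataclass
class Value:
    v: int = 0


@dataclass
class ValueWrapper:
    value: Value


# Nested npartials are re-invoked on every call, so each ValueWrapper gets
# its own fresh Value (equal in value, distinct in identity).
make_wrapper = npartial(ValueWrapper, value=npartial(Value, v=123))
a, b = make_wrapper(), make_wrapper()
assert a.value == b.value
assert a.value is not b.value
```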
