Skip to content

Commit

Permalink
Annotate depends and accept_arguments decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
gandhis1 committed Aug 10, 2024
1 parent 31f5717 commit 5d6bb59
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
12 changes: 10 additions & 2 deletions param/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import traceback
import warnings
from typing import ParamSpec, TypeVar, Callable, TYPE_CHECKING, Concatenate

from collections import defaultdict, OrderedDict
from contextlib import contextmanager
Expand All @@ -18,6 +19,11 @@
from threading import get_ident
from collections import abc

if TYPE_CHECKING:
_P1 = ParamSpec("P1")
_P2 = ParamSpec("P2")
_R = TypeVar("_R")

DEFAULT_SIGNATURE = inspect.Signature([
inspect.Parameter('self', inspect.Parameter.POSITIONAL_OR_KEYWORD),
inspect.Parameter('params', inspect.Parameter.VAR_KEYWORD),
Expand Down Expand Up @@ -282,12 +288,14 @@ def flatten(line):
yield element


def accept_arguments(f):
def accept_arguments(
f: Callable[Concatenate[Callable[_P1, _R], _P2], _R]
) -> Callable[_P2, Callable[[Callable[_P1, _R]], Callable[_P1, _R]]]:
"""
Decorator for decorators that accept arguments
"""
@functools.wraps(f)
def _f(*args, **kwargs):
def _f(*args: _P2.args, **kwargs: _P2.kwargs) -> Callable[[Callable[_P1, _R]], Callable[_P1, _R]]:
return lambda actual_f: f(actual_f, *args, **kwargs)
return _f

Expand Down
34 changes: 32 additions & 2 deletions param/depends.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
from __future__ import annotations

import inspect

from collections import defaultdict
from functools import wraps
from typing import TYPE_CHECKING, TypeVar, Callable, Protocol, TypedDict, overload

from .parameterized import (
Parameter, Parameterized, ParameterizedMetaclass, transform_reference,
)
from ._utils import accept_arguments, iscoroutinefunction

if TYPE_CHECKING:
CallableT = TypeVar("CallableT", bound=Callable)
Dependency = Parameter | str

class DependencyInfo(TypedDict):
dependencies: tuple[Dependency, ...]
kw: dict[str, Dependency]
watch: bool
on_init: bool

class DependsFunc(Protocol[CallableT]):
_dinfo: DependencyInfo
__call__: CallableT

@overload
def depends(
*dependencies: str, watch: bool = ..., on_init: bool = ...
) -> Callable[[CallableT], DependsFunc[CallableT]]:
...

@overload
def depends(
*dependencies: Parameter, watch: bool = ..., on_init: bool = ..., **kw: Parameter
) -> Callable[[CallableT], DependsFunc[CallableT]]:
...

@accept_arguments
def depends(func, *dependencies, watch=False, on_init=False, **kw):
def depends(
func: CallableT, *dependencies: Dependency, watch: bool = False, on_init: bool = False, **kw: Parameter
) -> Callable[[CallableT], DependsFunc[CallableT]]:
"""Annotates a function or Parameterized method to express its dependencies.
The specified dependencies can be either be Parameter instances or if a
Expand Down Expand Up @@ -117,6 +147,6 @@ def cb(*events):
_dinfo.update({'dependencies': dependencies,
'kw': kw, 'watch': watch, 'on_init': on_init})

_depends._dinfo = _dinfo
_depends._dinfo = _dinfo # type: ignore[attr-defined]

return _depends

0 comments on commit 5d6bb59

Please sign in to comment.