diff --git a/param/_utils.py b/param/_utils.py index 4fdc71c9..857b3fa9 100644 --- a/param/_utils.py +++ b/param/_utils.py @@ -1,22 +1,28 @@ +from __future__ import annotations + import asyncio import collections import contextvars import datetime as dt -import inspect import functools +import inspect import numbers import os import re import sys import traceback import warnings - -from collections import defaultdict, OrderedDict +from collections import OrderedDict, abc, defaultdict from contextlib import contextmanager from numbers import Real from textwrap import dedent from threading import get_ident -from collections import abc +from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar + +if TYPE_CHECKING: + _P = ParamSpec("_P") + _R = TypeVar("_R") + CallableT = TypeVar("CallableT", bound=Callable) DEFAULT_SIGNATURE = inspect.Signature([ inspect.Parameter('self', inspect.Parameter.POSITIONAL_OR_KEYWORD), @@ -282,12 +288,14 @@ def flatten(line): yield element -def accept_arguments(f): +def accept_arguments( + f: Callable[Concatenate[CallableT, _P], _R] +) -> Callable[_P, Callable[[CallableT], _R]]: """ Decorator for decorators that accept arguments """ @functools.wraps(f) - def _f(*args, **kwargs): + def _f(*args: _P.args, **kwargs: _P.kwargs) -> Callable[[CallableT], _R]: return lambda actual_f: f(actual_f, *args, **kwargs) return _f diff --git a/param/depends.py b/param/depends.py index 8b4c172d..344b2b4d 100644 --- a/param/depends.py +++ b/param/depends.py @@ -1,16 +1,46 @@ +from __future__ import annotations + import inspect from collections import defaultdict from functools import wraps +from typing import TYPE_CHECKING, TypeVar, Callable, Protocol, TypedDict, overload from .parameterized import ( Parameter, Parameterized, ParameterizedMetaclass, transform_reference, ) from ._utils import accept_arguments, iscoroutinefunction +if TYPE_CHECKING: + CallableT = TypeVar("CallableT", bound=Callable) + Dependency = Parameter | str + + class DependencyInfo(TypedDict): + dependencies: tuple[Dependency, ...] + kw: dict[str, Dependency] + watch: bool + on_init: bool + + class DependsFunc(Protocol[CallableT]): + _dinfo: DependencyInfo + __call__: CallableT + +@overload +def depends( + *dependencies: str, watch: bool = ..., on_init: bool = ... +) -> Callable[[CallableT], DependsFunc[CallableT]]: + ... + +@overload +def depends( + *dependencies: Parameter, watch: bool = ..., on_init: bool = ..., **kw: Parameter +) -> Callable[[CallableT], DependsFunc[CallableT]]: + ... @accept_arguments -def depends(func, *dependencies, watch=False, on_init=False, **kw): +def depends( + func: CallableT, /, *dependencies: Dependency, watch: bool = False, on_init: bool = False, **kw: Parameter +) -> Callable[[CallableT], DependsFunc[CallableT]]: """Annotates a function or Parameterized method to express its dependencies. The specified dependencies can be either be Parameter instances or if a @@ -117,6 +147,6 @@ def cb(*events): _dinfo.update({'dependencies': dependencies, 'kw': kw, 'watch': watch, 'on_init': on_init}) - _depends._dinfo = _dinfo + _depends._dinfo = _dinfo # type: ignore[attr-defined] return _depends