Skip to content

Commit

Permalink
feat: rely on singledispatch for the registry
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 17, 2023
1 parent 762af51 commit 25e6cf8
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 87 deletions.
7 changes: 0 additions & 7 deletions docs/differentiation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ registering it with the :func:`~functools.singledispatch` mechanism as
:language: python
:linenos:

Finally, we can register it for :func:`pycaputo.diff` using

.. literalinclude:: ../examples/example-custom-diff.py
:lines: 40-40
:language: python
:linenos:

The complete example can be found in
:download:`examples/example-custom-diff.py <../examples/example-custom-diff.py>`.

Expand Down
7 changes: 0 additions & 7 deletions docs/quadrature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ registering it with the :func:`~functools.singledispatch` mechanism as
:language: python
:linenos:

Finally, we can register it for :func:`pycaputo.quad` using

.. literalinclude:: ../examples/example-custom-quad.py
:lines: 37-37
:language: python
:linenos:

The complete example can be found in
:download:`examples/example-custom-quad.py <../examples/example-custom-quad.py>`.

Expand Down
4 changes: 2 additions & 2 deletions examples/example-custom-diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from pycaputo.derivatives import RiemannLiouvilleDerivative
from pycaputo.differentiation import DerivativeMethod, diff, register_method
from pycaputo.differentiation import DerivativeMethod, diff, make_method_from_name
from pycaputo.grid import Points
from pycaputo.utils import Array, ArrayOrScalarFunction

Expand Down Expand Up @@ -37,4 +37,4 @@ def _diff_rl(
return np.zeros_like(fx)


register_method("RLdiff", RiemannLiouvilleDerivativeMethod)
m = make_method_from_name("RiemannLiouvilleDerivativeMethod", 0.5)
4 changes: 2 additions & 2 deletions examples/example-custom-quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pycaputo.derivatives import HadamardDerivative
from pycaputo.grid import Points
from pycaputo.quadrature import QuadratureMethod, quad, register_method
from pycaputo.quadrature import QuadratureMethod, make_method_from_name, quad
from pycaputo.utils import Array, ArrayOrScalarFunction


Expand All @@ -34,4 +34,4 @@ def _quad_hadamard(
return np.zeros_like(fx)


register_method("Hadamard", HadamardQuadratureMethod)
m = make_method_from_name("HadamardQuadratureMethod", -1.5)
41 changes: 7 additions & 34 deletions pycaputo/differentiation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,6 @@
)
from pycaputo.grid import Points

REGISTERED_METHODS: dict[str, type[DerivativeMethod]] = {
"CaputoL1Method": CaputoL1Method,
"CaputoL2CMethod": CaputoL2CMethod,
"CaputoL2Method": CaputoL2Method,
"CaputoModifiedL1Method": CaputoModifiedL1Method,
"CaputoSpectralMethod": CaputoSpectralMethod,
}


def register_method(
name: str,
method: type[DerivativeMethod],
*,
force: bool = False,
) -> None:
"""Register a new derivative approximation method.
:arg name: a canonical name for the method.
:arg method: a class that will be used to construct the method.
:arg force: if *True*, any existing methods will be overwritten.
"""

if not force and name in REGISTERED_METHODS:
raise ValueError(
f"A method by the name '{name}' is already registered. Use 'force=True' to"
" overwrite it."
)

REGISTERED_METHODS[name] = method


def make_method_from_name(
name: str,
Expand All @@ -55,17 +25,21 @@ def make_method_from_name(
:arg d: a fractional operator that should be discretized by the method. If
the method does not support this operator, it can fail.
"""
if name not in REGISTERED_METHODS:

methods: dict[str, type[DerivativeMethod]] = {
cls.__name__: cls for cls in diff.registry
}
if name not in methods:
raise ValueError(
"Unknown differentiation method '{}'. Known methods are '{}'".format(
name, "', '".join(REGISTERED_METHODS)
name, "', '".join(methods)
)
)

if not isinstance(d, FractionalOperator):
d = CaputoDerivative(order=d, side=Side.Left)

return REGISTERED_METHODS[name](d)
return methods[name](d)


def guess_method_for_order(
Expand Down Expand Up @@ -114,7 +88,6 @@ def guess_method_for_order(
__all__ = (
"DerivativeMethod",
"diff",
"register_method",
"guess_method_for_order",
"make_method_from_name",
"CaputoDerivativeMethod",
Expand Down
41 changes: 6 additions & 35 deletions pycaputo/quadrature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,6 @@
RiemannLiouvilleTrapezoidalMethod,
)

REGISTERED_METHODS: dict[str, type[QuadratureMethod]] = {
"RiemannLiouvilleConvolutionMethod": RiemannLiouvilleConvolutionMethod,
"RiemannLiouvilleCubicHermiteMethod": RiemannLiouvilleCubicHermiteMethod,
"RiemannLiouvilleRectangularMethod": RiemannLiouvilleRectangularMethod,
"RiemannLiouvilleSimpsonMethod": RiemannLiouvilleSimpsonMethod,
"RiemannLiouvilleSpectralMethod": RiemannLiouvilleSpectralMethod,
"RiemannLiouvilleTrapezoidalMethod": RiemannLiouvilleTrapezoidalMethod,
}


def register_method(
name: str,
method: type[QuadratureMethod],
*,
force: bool = False,
) -> None:
"""Register a new integral approximation method.
:arg name: a canonical name for the method.
:arg method: a class that will be used to construct the method.
:arg force: if *True*, any existing methods will be overwritten.
"""

if not force and name in REGISTERED_METHODS:
raise ValueError(
f"A method by the name '{name}' is already registered. Use 'force=True' to"
" overwrite it."
)

REGISTERED_METHODS[name] = method


def make_method_from_name(
name: str,
Expand All @@ -58,17 +27,20 @@ def make_method_from_name(
the method does not support this operator, it can fail.
"""

if name not in REGISTERED_METHODS:
methods: dict[str, type[QuadratureMethod]] = {
cls.__name__: cls for cls in quad.registry
}
if name not in methods:
raise ValueError(
"Unknown quadrature method '{}'. Known methods are '{}'".format(
name, "', '".join(REGISTERED_METHODS)
name, "', '".join(methods)
)
)

if not isinstance(d, FractionalOperator):
d = RiemannLiouvilleDerivative(order=d, side=Side.Left)

return REGISTERED_METHODS[name](d)
return methods[name](d)


def guess_method_for_order(
Expand Down Expand Up @@ -113,7 +85,6 @@ def guess_method_for_order(
__all__ = (
"QuadratureMethod",
"quad",
"register_method",
"make_method_from_name",
"guess_method_for_order",
"RiemannLiouvilleConvolutionMethod",
Expand Down

0 comments on commit 25e6cf8

Please sign in to comment.