Skip to content

Commit

Permalink
test: add test for AdEx solve
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 11, 2023
1 parent b556e08 commit 167a022
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 30 deletions.
68 changes: 41 additions & 27 deletions pycaputo/integrate_fire/ad_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class AdEx(NamedTuple):
def from_dimensional(cls, dim: AdExDim, alpha: float | tuple[float, float]) -> AdEx:
"""Construct non-dimensional parameters for the AdEx model.
:arg dim: a set of dimensional parameters with the apropriate units.
:arg dim: a set of dimensional parameters with the appropriate units.
:arg alpha: the order of the fractional derivatives for two model
components :math:`(V, w)`. These can be the same if the two variables
use the same order.
Expand Down Expand Up @@ -320,10 +320,20 @@ class AdExModel:
#: Non-dimensional parameters for the model.
param: AdEx

if __debug__:

def __post_init__(self) -> None:
if not isinstance(self.param, AdEx):
raise TypeError(
f"Invalid parameter type: '{type(self.param).__name__}'"
)

def source(self, t: float, y: Array) -> Array:
r"""Evaluation of the right-hand side source terms at :math:`(t, \mathbf{y})`."""
r"""Evaluation of the right-hand side source terms at
:math:`(t, \mathbf{y})`.
"""
V, w = y
I, el, tau_w, a, *_ = self.param # noqa: E741
_, _, I, el, tau_w, a, *_ = self.param # noqa: E741

return np.array(
[
Expand All @@ -335,7 +345,7 @@ def source(self, t: float, y: Array) -> Array:
def source_jac(self, t: float, y: Array) -> Array:
r"""Evaluation of the right-hand side Jacobian at :math:`(t, \mathbf{y})`."""
V, w = y
I, el, tau_w, a, *_ = self.param # noqa: E741
_, _, I, el, tau_w, a, *_ = self.param # noqa: E741

# J_{ij} = d f_i / d y_j
return np.array(
Expand All @@ -351,7 +361,7 @@ def spiked(self, t: float, y: Array) -> float:
:returns: a delta from the current solution to the threshold, i.e.
:math:`V - V_{peak}`. If the return value is positive, the threshold
was hit and a spike/reset should have occured.
was hit and a spike/reset should have occurred.
"""
V, _ = y
return float(V - self.param.v_peak)
Expand Down Expand Up @@ -402,34 +412,38 @@ def get_lambert_time_step(ad_ex: AdEx) -> float | None:
return float((h / gamma(2 - alpha)) ** (1 / alpha))


@dataclass(frozen=True)
class AdExIntegrateFireL1Method(CaputoIntegrateFireL1Method):
#: Parameters for the AdEx model.
ad_ex: AdExModel
def ad_ex_solve(ad_ex: AdExModel, t: float, y0: Array, c: Array, r: Array) -> Array:
# NOTE: small rename to match write-up
hV, hw = c
rV, rw = r
_, _, I, el, tau_w, a, *_ = ad_ex.param # noqa: E741

def solve(self, t: float, y0: Array, c: Array, r: Array) -> Array:
# NOTE: small rename to match write-up
hV, hw = c
rV, rw = r
_, _, I, el, tau_w, a, *_ = self.ad_ex.param # noqa: E741
# w coefficients: w = c0 V + c1
c0 = a * hw / (tau_w + hw)
c1 = (tau_w * rw - a * hw * el) / (hw + tau_w)

# w coefficients: w = c0 V + c1
c0 = a * hw / (tau_w + hw)
c1 = (tau_w * rw - a * hw * el) / (hw + tau_w)
# V coefficients: d0 V + d1 = d2 exp(V)
d0 = 1 + hV * (1 + c0)
d1 = -hV * (I + el - c1) + rV
d2 = hV

# V coefficients: d0 V + d1 = d2 exp(V)
d0 = 1 + hV * (1 + c0)
d1 = -hV * (I + el - c1) + rV
d2 = hV
# solve
from scipy.special import lambertw

# solve
from scipy.special import lambertw
dstar = -d2 / d0 * np.exp(d1 / d0)
Vstar = -d1 / d0 - lambertw(dstar)
wstar = c0 * Vstar + c1

dstar = -d2 / d0 * np.exp(d1 / d0)
Vstar = -d1 / d0 - lambertw(dstar)
wstar = c0 * Vstar + c1
return np.array([Vstar, wstar])

return np.array([Vstar, wstar])

@dataclass(frozen=True)
class AdExIntegrateFireL1Method(CaputoIntegrateFireL1Method):
#: Parameters for the AdEx model.
ad_ex: AdExModel

def solve(self, t: float, y0: Array, c: Array, r: Array) -> Array:
return ad_ex_solve(self.ad_ex, t, y0, c, r)


# }}}
41 changes: 38 additions & 3 deletions tests/test_integrate_fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
dirname = pathlib.Path(__file__).parent
set_recommended_matplotlib()

# {{{
# {{{ test_ad_ex_naud_parameters


def test_ad_ex_naud_parameters() -> None:
from pycaputo.integrate_fire.ad_ex import AD_EX_PARAMS, AdEx, get_lambert_time_step

alpha = (0.77, 0.31)
for name, dim in AD_EX_PARAMS.items():

ad_ex = AdEx.from_dimensional(dim, alpha)
assert all(np.all(np.isfinite(v)) for v in ad_ex)
assert all(np.all(np.isfinite(np.array(v))) for v in ad_ex)
assert str(dim)
assert str(ad_ex)
# logger.info("Parameters:\n%s\n%s", dim, ad_ex)
Expand All @@ -38,6 +37,42 @@ def test_ad_ex_naud_parameters() -> None:
# }}}


# {{{ test_ad_ex_solve


def test_ad_ex_solve() -> None:
from pycaputo.integrate_fire.ad_ex import (
AdExModel,
ad_ex_solve,
get_ad_ex_parameters,
)

rng = np.random.default_rng(seed=42)
dt = 1.0e-2
alpha = 0.9

from math import gamma

param = get_ad_ex_parameters("Naud4h", alpha)
ad_ex = AdExModel(param)

t = dt
y0 = np.array([rng.uniform(param.v_reset, param.v_peak), rng.uniform()])
c = np.array([gamma(2 - alpha) * dt**alpha, gamma(2 - alpha) * dt**alpha])
r = dt**alpha / gamma(1 - alpha) * y0

y = ad_ex_solve(ad_ex, t, y0, c, r)
error = y - c * ad_ex.source(t, y) - r
error_imag = np.linalg.norm(error.imag)
error_real = np.linalg.norm(error.real)
logger.info("Error: %s (real %.12e imag %.12e)", error.real, error_real, error_imag)
assert error_imag < 1.0e-15
assert error_real < 7.0e-2


# }}}


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 167a022

Please sign in to comment.