From 167a022861c049a084aa04ae8029b59772fa1bdc Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Wed, 11 Oct 2023 20:55:53 +0300 Subject: [PATCH] test: add test for AdEx solve --- pycaputo/integrate_fire/ad_ex.py | 68 +++++++++++++++++++------------- tests/test_integrate_fire.py | 41 +++++++++++++++++-- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/pycaputo/integrate_fire/ad_ex.py b/pycaputo/integrate_fire/ad_ex.py index dccb1da..15c96eb 100644 --- a/pycaputo/integrate_fire/ad_ex.py +++ b/pycaputo/integrate_fire/ad_ex.py @@ -108,7 +108,7 @@ class AdEx(NamedTuple): def from_dimensional(cls, dim: AdExDim, alpha: float | tuple[float, float]) -> AdEx: """Construct non-dimensional parameters for the AdEx model. - :arg dim: a set of dimensional parameters with the apropriate units. + :arg dim: a set of dimensional parameters with the appropriate units. :arg alpha: the order of the fractional derivatives for two model components :math:`(V, w)`. These can be the same if the two variables use the same order. @@ -320,10 +320,20 @@ class AdExModel: #: Non-dimensional parameters for the model. param: AdEx + if __debug__: + + def __post_init__(self) -> None: + if not isinstance(self.param, AdEx): + raise TypeError( + f"Invalid parameter type: '{type(self.param).__name__}'" + ) + def source(self, t: float, y: Array) -> Array: - r"""Evaluation of the right-hand side source terms at :math:`(t, \mathbf{y})`.""" + r"""Evaluation of the right-hand side source terms at + :math:`(t, \mathbf{y})`. + """ V, w = y - I, el, tau_w, a, *_ = self.param # noqa: E741 + _, _, I, el, tau_w, a, *_ = self.param # noqa: E741 return np.array( [ @@ -335,7 +345,7 @@ def source(self, t: float, y: Array) -> Array: def source_jac(self, t: float, y: Array) -> Array: r"""Evaluation of the right-hand side Jacobian at :math:`(t, \mathbf{y})`.""" V, w = y - I, el, tau_w, a, *_ = self.param # noqa: E741 + _, _, I, el, tau_w, a, *_ = self.param # noqa: E741 # J_{ij} = d f_i / d y_j return np.array( @@ -351,7 +361,7 @@ def spiked(self, t: float, y: Array) -> float: :returns: a delta from the current solution to the threshold, i.e. :math:`V - V_{peak}`. If the return value is positive, the threshold - was hit and a spike/reset should have occured. + was hit and a spike/reset should have occurred. """ V, _ = y return float(V - self.param.v_peak) @@ -402,34 +412,38 @@ def get_lambert_time_step(ad_ex: AdEx) -> float | None: return float((h / gamma(2 - alpha)) ** (1 / alpha)) -@dataclass(frozen=True) -class AdExIntegrateFireL1Method(CaputoIntegrateFireL1Method): - #: Parameters for the AdEx model. - ad_ex: AdExModel +def ad_ex_solve(ad_ex: AdExModel, t: float, y0: Array, c: Array, r: Array) -> Array: + # NOTE: small rename to match write-up + hV, hw = c + rV, rw = r + _, _, I, el, tau_w, a, *_ = ad_ex.param # noqa: E741 - def solve(self, t: float, y0: Array, c: Array, r: Array) -> Array: - # NOTE: small rename to match write-up - hV, hw = c - rV, rw = r - _, _, I, el, tau_w, a, *_ = self.ad_ex.param # noqa: E741 + # w coefficients: w = c0 V + c1 + c0 = a * hw / (tau_w + hw) + c1 = (tau_w * rw - a * hw * el) / (hw + tau_w) - # w coefficients: w = c0 V + c1 - c0 = a * hw / (tau_w + hw) - c1 = (tau_w * rw - a * hw * el) / (hw + tau_w) + # V coefficients: d0 V + d1 = d2 exp(V) + d0 = 1 + hV * (1 + c0) + d1 = -hV * (I + el - c1) + rV + d2 = hV - # V coefficients: d0 V + d1 = d2 exp(V) - d0 = 1 + hV * (1 + c0) - d1 = -hV * (I + el - c1) + rV - d2 = hV + # solve + from scipy.special import lambertw - # solve - from scipy.special import lambertw + dstar = -d2 / d0 * np.exp(d1 / d0) + Vstar = -d1 / d0 - lambertw(dstar) + wstar = c0 * Vstar + c1 - dstar = -d2 / d0 * np.exp(d1 / d0) - Vstar = -d1 / d0 - lambertw(dstar) - wstar = c0 * Vstar + c1 + return np.array([Vstar, wstar]) - return np.array([Vstar, wstar]) + +@dataclass(frozen=True) +class AdExIntegrateFireL1Method(CaputoIntegrateFireL1Method): + #: Parameters for the AdEx model. + ad_ex: AdExModel + + def solve(self, t: float, y0: Array, c: Array, r: Array) -> Array: + return ad_ex_solve(self.ad_ex, t, y0, c, r) # }}} diff --git a/tests/test_integrate_fire.py b/tests/test_integrate_fire.py index 54044bf..3772eb1 100644 --- a/tests/test_integrate_fire.py +++ b/tests/test_integrate_fire.py @@ -15,7 +15,7 @@ dirname = pathlib.Path(__file__).parent set_recommended_matplotlib() -# {{{ +# {{{ test_ad_ex_naud_parameters def test_ad_ex_naud_parameters() -> None: @@ -23,9 +23,8 @@ def test_ad_ex_naud_parameters() -> None: alpha = (0.77, 0.31) for name, dim in AD_EX_PARAMS.items(): - ad_ex = AdEx.from_dimensional(dim, alpha) - assert all(np.all(np.isfinite(v)) for v in ad_ex) + assert all(np.all(np.isfinite(np.array(v))) for v in ad_ex) assert str(dim) assert str(ad_ex) # logger.info("Parameters:\n%s\n%s", dim, ad_ex) @@ -38,6 +37,42 @@ def test_ad_ex_naud_parameters() -> None: # }}} +# {{{ test_ad_ex_solve + + +def test_ad_ex_solve() -> None: + from pycaputo.integrate_fire.ad_ex import ( + AdExModel, + ad_ex_solve, + get_ad_ex_parameters, + ) + + rng = np.random.default_rng(seed=42) + dt = 1.0e-2 + alpha = 0.9 + + from math import gamma + + param = get_ad_ex_parameters("Naud4h", alpha) + ad_ex = AdExModel(param) + + t = dt + y0 = np.array([rng.uniform(param.v_reset, param.v_peak), rng.uniform()]) + c = np.array([gamma(2 - alpha) * dt**alpha, gamma(2 - alpha) * dt**alpha]) + r = dt**alpha / gamma(1 - alpha) * y0 + + y = ad_ex_solve(ad_ex, t, y0, c, r) + error = y - c * ad_ex.source(t, y) - r + error_imag = np.linalg.norm(error.imag) + error_real = np.linalg.norm(error.real) + logger.info("Error: %s (real %.12e imag %.12e)", error.real, error_real, error_imag) + assert error_imag < 1.0e-15 + assert error_real < 7.0e-2 + + +# }}} + + if __name__ == "__main__": import sys