Skip to content

Commit

Permalink
feat(fode): improve error reporting on fails
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Aug 8, 2023
1 parent aee103d commit bfa2110
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 21 deletions.
2 changes: 2 additions & 0 deletions pycaputo/fode/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class StepFailed(Event):
t: float
#: Current iteration.
iteration: int
#: A reason on why the step failed (if available).
reason: str


@dataclass(frozen=True)
Expand Down
6 changes: 5 additions & 1 deletion pycaputo/fode/caputo.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def func(y: Array) -> Array:
return np.array(y - c * self.source(t, y) - r)

def jac(y: Array) -> Array:
assert self.source_jac is not None
return 1 - c * self.source_jac(t, y)

result = so.root_scalar(
Expand All @@ -168,7 +169,10 @@ def func(y: Array) -> Array:
return np.array(y - c * self.source(t, y) - r, dtype=y0.dtype)

def jac(y: Array) -> Array:
return np.eye(y.size, dtype=y0.dtype) - np.diag(c) @ self.source_jac(t, y)
assert self.source_jac is not None
return np.array(
np.eye(y.size, dtype=y0.dtype) - np.diag(c) @ self.source_jac(t, y)
)

result = so.root(
func,
Expand Down
42 changes: 25 additions & 17 deletions pycaputo/fode/product_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _evolve_pi(
callback: CallbackFunction | None = None,
history: History | None = None,
maxit: int | None = None,
verbose: bool = True,
raise_on_fail: bool = False,
) -> Iterator[Event]:
from pycaputo.fode.base import (
StepCompleted,
Expand Down Expand Up @@ -85,14 +85,19 @@ def _evolve_pi(
# next time step
try:
dt = predict_time_step(t, y)

if not np.isfinite(dt):
raise ValueError(f"Invalid time step at iteration {n}: {dt!r}")
except Exception as exc:
if verbose:
logger.error("Failed to predict time step.", exc_info=exc)
logger.error("Failed to predict time step.", exc_info=exc)
if raise_on_fail:
raise exc

yield StepFailed(t=t, iteration=n, reason=str(exc))

if not np.isfinite(dt):
logger.error("Predicted time step is not finite: %g", dt)
if raise_on_fail:
raise ValueError(f"Predicted time step is not finite: {dt}")

yield StepFailed(t=t, iteration=n)
yield StepFailed(t=t, iteration=n, reason="time step is not finite")

if tfinal is not None:
# NOTE: adding eps to ensure that t >= tfinal is true
Expand All @@ -103,18 +108,21 @@ def _evolve_pi(
# advance
try:
y = advance(m, history, t, y)
if not np.all(np.isfinite(y)):
if verbose:
logger.error("Failed to update solution: %s", y)

yield StepFailed(t=t, iteration=n)
else:
yield StepCompleted(t=t, iteration=n, dt=dt, y=y)
except Exception as exc:
if verbose:
logger.error("Failed to advance time step.", exc_info=exc)
logger.error("Failed to advance solution.", exc_info=exc)
if raise_on_fail:
raise exc

yield StepFailed(t=t, iteration=n, reason=str(exc))

if not np.all(np.isfinite(y)):
logger.error("Failed to update solution: %s", y)
if raise_on_fail:
raise ValueError(f"Predicted solution is not finite: {y}")

yield StepFailed(t=t, iteration=n)
yield StepFailed(t=t, iteration=n, reason="solution is not finite")
else:
yield StepCompleted(t=t, iteration=n, dt=dt, y=y)


# }}}
Expand Down
8 changes: 5 additions & 3 deletions tests/test_fode.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_caputo_fode(
with BlockTimer(name=m.name) as bt:
ts = []
ys = []
for event in evolve(m, verbose=True):
for event in evolve(m):
if isinstance(event, StepFailed):
raise ValueError("Step update failed")
elif isinstance(event, StepCompleted):
Expand Down Expand Up @@ -242,7 +242,9 @@ def test_caputo_fode(
],
)
def test_caputo_fode_system(
factory: Callable[[float, int], fode.FractionalDifferentialEquationMethod],
factory: Callable[
[tuple[float, ...], int], fode.FractionalDifferentialEquationMethod
],
*,
visualize: bool = False,
) -> None:
Expand All @@ -261,7 +263,7 @@ def test_caputo_fode_system(
with BlockTimer(name=m.name) as bt:
ts = []
ys = []
for event in evolve(m, verbose=True):
for event in evolve(m):
if isinstance(event, StepFailed):
raise ValueError("Step update failed")
elif isinstance(event, StepCompleted):
Expand Down

0 comments on commit bfa2110

Please sign in to comment.