Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski committed Aug 16, 2023
1 parent ab13146 commit 326ce7f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 34 deletions.
22 changes: 8 additions & 14 deletions src/jacobi/_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def jacobi(
fn : Callable
Function with the signature `fn(x, *args)`, where `x` is a number or a sequence
of numbers and `*args` are optional auxiliary arguments. The function must
return a number or a sequence of numbers (ideally as a numpy array). The length
of `x` can differ from the output sequence. Derivatives are only computed with
respect to `x`, the auxiliary arguments are ignored.
return a number or a regular shape of numbers (ideally as a numpy array). The
length of `x` can differ from the output sequence. Derivatives are only
computed with respect to `x`, the auxiliary arguments are ignored.
x : Number or array of numbers
The derivative is computed with respect to `x`. If `x` is an array, the Jacobi
matrix is computed with respect to each element of `x`.
Expand Down Expand Up @@ -270,19 +270,13 @@ def _first(method, f0, fn, x, i, h, args):


def _wrap_function_if_needed(fn, fval):
if not isinstance(fval, (float, np.ndarray)):
if not isinstance(fval, float):
try:
fval = np.asarray(fval)
fval = np.asarray(fval, dtype=float)
except ValueError as e:
raise ValueError(
"function return value cannot be converted into numpy array"
"function return value cannot be converted into "
"1D numpy array of floats"
) from e
fn_orig = fn
fn = lambda *args: np.asarray(fn_orig(*args)) # noqa
if isinstance(fval, np.ndarray) and fval.dtype.kind != "f":
msg = (
f"invalid dtype for function return value. "
f"Must be float, found '{fval.dtype}'"
)
raise ValueError(msg)
return lambda *args: np.asarray(fn(*args)), fval
return fn, fval
14 changes: 4 additions & 10 deletions test/test_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,9 @@ def fn(x):


@pytest.mark.parametrize("method", (None, -1, 0, 1))
def test_bad_return_value_1(method):
with pytest.raises(
ValueError, match="function return value cannot be converted into numpy array"
):
jacobi(lambda x: [1, [1, 2]], (1, 2), method=method)


@pytest.mark.parametrize("method", (None, -1, 0, 1))
@pytest.mark.parametrize("fn", (lambda x: "s", lambda x: ("a", "b")))
@pytest.mark.parametrize(
"fn", (lambda x: [1, [1, 2]], lambda x: "s", lambda x: ("a", "b"))
)
def test_bad_return_value_2(method, fn):
with pytest.raises(ValueError, match="invalid dtype"):
with pytest.raises(ValueError, match="function return value cannot be converted"):
jacobi(fn, (1, 2), method=method)
14 changes: 4 additions & 10 deletions test/test_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,9 @@ def fn(x):


@pytest.mark.parametrize("method", (None, -1, 0, 1))
def test_bad_return_value_1(method):
with pytest.raises(
ValueError, match="function return value cannot be converted into numpy array"
):
propagate(lambda x: [1, [1, 2]], (1, 2), ((1, 0), (0, 1)), method=method)


@pytest.mark.parametrize("method", (None, -1, 0, 1))
@pytest.mark.parametrize("fn", (lambda x: "s", lambda x: ("a", "b")))
@pytest.mark.parametrize(
"fn", (lambda x: [1, [1, 2]], lambda x: "s", lambda x: ("a", "b"))
)
def test_bad_return_value_2(method, fn):
with pytest.raises(ValueError, match="invalid dtype"):
with pytest.raises(ValueError, match="function return value cannot be converted"):
propagate(fn, (1, 2), ((1, 0), (0, 1)), method=method)

0 comments on commit 326ce7f

Please sign in to comment.