From 326ce7f3706d2feadc6f3361848646417f6afd8c Mon Sep 17 00:00:00 2001 From: Hans Dembinski Date: Wed, 16 Aug 2023 14:43:40 +0200 Subject: [PATCH] fix --- src/jacobi/_jacobi.py | 22 ++++++++-------------- test/test_jacobi.py | 14 ++++---------- test/test_propagate.py | 14 ++++---------- 3 files changed, 16 insertions(+), 34 deletions(-) diff --git a/src/jacobi/_jacobi.py b/src/jacobi/_jacobi.py index a5b7db0..16e26e4 100644 --- a/src/jacobi/_jacobi.py +++ b/src/jacobi/_jacobi.py @@ -24,9 +24,9 @@ def jacobi( fn : Callable Function with the signature `fn(x, *args)`, where `x` is a number or a sequence of numbers and `*args` are optional auxiliary arguments. The function must - return a number or a sequence of numbers (ideally as a numpy array). The length - of `x` can differ from the output sequence. Derivatives are only computed with - respect to `x`, the auxiliary arguments are ignored. + return a number or a regular shape of numbers (ideally as a numpy array). The + length of `x` can differ from the output sequence. Derivatives are only + computed with respect to `x`, the auxiliary arguments are ignored. x : Number or array of numbers The derivative is computed with respect to `x`. If `x` is an array, the Jacobi matrix is computed with respect to each element of `x`. @@ -270,19 +270,13 @@ def _first(method, f0, fn, x, i, h, args): def _wrap_function_if_needed(fn, fval): - if not isinstance(fval, (float, np.ndarray)): + if not isinstance(fval, float): try: - fval = np.asarray(fval) + fval = np.asarray(fval, dtype=float) except ValueError as e: raise ValueError( - "function return value cannot be converted into numpy array" + "function return value cannot be converted into " + "1D numpy array of floats" ) from e - fn_orig = fn - fn = lambda *args: np.asarray(fn_orig(*args)) # noqa - if isinstance(fval, np.ndarray) and fval.dtype.kind != "f": - msg = ( - f"invalid dtype for function return value. " - f"Must be float, found '{fval.dtype}'" - ) - raise ValueError(msg) + return lambda *args: np.asarray(fn(*args)), fval return fn, fval diff --git a/test/test_jacobi.py b/test/test_jacobi.py index 0ebaebc..52d44b9 100644 --- a/test/test_jacobi.py +++ b/test/test_jacobi.py @@ -251,15 +251,9 @@ def fn(x): @pytest.mark.parametrize("method", (None, -1, 0, 1)) -def test_bad_return_value_1(method): - with pytest.raises( - ValueError, match="function return value cannot be converted into numpy array" - ): - jacobi(lambda x: [1, [1, 2]], (1, 2), method=method) - - -@pytest.mark.parametrize("method", (None, -1, 0, 1)) -@pytest.mark.parametrize("fn", (lambda x: "s", lambda x: ("a", "b"))) +@pytest.mark.parametrize( + "fn", (lambda x: [1, [1, 2]], lambda x: "s", lambda x: ("a", "b")) +) def test_bad_return_value_2(method, fn): - with pytest.raises(ValueError, match="invalid dtype"): + with pytest.raises(ValueError, match="function return value cannot be converted"): jacobi(fn, (1, 2), method=method) diff --git a/test/test_propagate.py b/test/test_propagate.py index 7dbf909..79336a6 100644 --- a/test/test_propagate.py +++ b/test/test_propagate.py @@ -327,15 +327,9 @@ def fn(x): @pytest.mark.parametrize("method", (None, -1, 0, 1)) -def test_bad_return_value_1(method): - with pytest.raises( - ValueError, match="function return value cannot be converted into numpy array" - ): - propagate(lambda x: [1, [1, 2]], (1, 2), ((1, 0), (0, 1)), method=method) - - -@pytest.mark.parametrize("method", (None, -1, 0, 1)) -@pytest.mark.parametrize("fn", (lambda x: "s", lambda x: ("a", "b"))) +@pytest.mark.parametrize( + "fn", (lambda x: [1, [1, 2]], lambda x: "s", lambda x: ("a", "b")) +) def test_bad_return_value_2(method, fn): - with pytest.raises(ValueError, match="invalid dtype"): + with pytest.raises(ValueError, match="function return value cannot be converted"): propagate(fn, (1, 2), ((1, 0), (0, 1)), method=method)