From 65e102401add3a31d12ea1c5f760e3c682afc9ae Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 24 May 2022 15:18:49 +0200 Subject: [PATCH 1/2] feat: JIT-compile TF functions --- src/tensorwaves/function/sympy/__init__.py | 4 +++- tests/function/test_sympy.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index 9f068865..5d7eb7df 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -229,18 +229,20 @@ def tensorflow_lambdify() -> Callable: try: # pylint: disable=import-error # pyright: reportMissingImports=false + import tensorflow as tf import tensorflow.experimental.numpy as tnp except ImportError: # pragma: no cover raise_missing_module_error("tensorflow", extras_require="tf") from ._printer import TensorflowPrinter - return _sympy_lambdify( + func = _sympy_lambdify( expression, symbols, modules=tnp, printer=TensorflowPrinter(), use_cse=use_cse, ) + return tf.function(func, jit_compile=True) modules = get_backend_modules(backend) if isinstance(backend, str): diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index 4ac0d387..ae21f091 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -117,6 +117,8 @@ def test_fast_lambdify(backend: str, max_complexity: int, use_cse: bool): repr_start = "" else: repr_start = " Date: Thu, 7 Mar 2024 22:11:11 +0100 Subject: [PATCH 2/2] FIX: update TF test names --- tests/function/test_sympy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index 81df1e9c..15bfd817 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -116,7 +116,7 @@ def test_fast_lambdify(backend: str, max_complexity: int, use_cse: bool): # cspell:ignore lambdifygenerated repr_start = "