From 46b61acba3a4a4214138f4f17de82c6b63c736c3 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 13 Apr 2023 17:22:34 +0800 Subject: [PATCH 1/4] feat: make public --- src/adam_op/adam_op_impl_cpu.cpp | 35 ++++++++----------- tests/test_implicit.py | 52 +++++++++++++++++++++++++++++ torchopt/diff/implicit/__init__.py | 4 +-- torchopt/diff/implicit/decorator.py | 51 +++++++++++++++++++++++++++- 4 files changed, 118 insertions(+), 24 deletions(-) diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 1135206d..b9c14e49 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -40,9 +40,8 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -96,9 +95,8 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -130,9 +128,8 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -168,9 +165,8 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -214,9 +210,8 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -251,9 +246,8 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ 
dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -292,9 +286,8 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads( \ - std::min(n / MIN_NUMEL_USE_OMP, \ - static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads(std::min( \ + n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid]; diff --git a/tests/test_implicit.py b/tests/test_implicit.py index db19f829..45d88f74 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -480,6 +480,58 @@ def outer_level(p, xs, ys): helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor) +def test_rr_root_vjp( + dtype=[torch.float64, torch.float32], + +): + helpers.seed_everything(42) + np_dtype = helpers.dtype_torch2numpy(dtype) + input_size = 10 + + init_params_torch = torch.randn(input_size, dtype=dtype) + l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + + def ridge_objective(params, l2reg, data): + """Ridge objective function.""" + X_tr, y_tr = data + residuals = X_tr @ params - y_tr + regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params)) + return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss + + def ridge_solver_cg(params, l2reg, data): + """Solve ridge regression by conjugate gradient.""" + X_tr, y_tr = data + + def matvec(u): + return X_tr.T @ (X_tr @ u) + + solve = torchopt.linear_solve.solve_cg( + ridge=len(y_tr) * l2reg.item(), + init=params, + maxiter=20, + ) + + return solve(matvec=matvec, b=X_tr.T @ y_tr) + + for xs, ys, xq, yq in loader: + xs = xs.to(dtype=dtype) + ys = ys.to(dtype=dtype) + xq = xq.to(dtype=dtype) + yq = yq.to(dtype=dtype) + + optimality_fun = grad(ridge_objective, argnums=0, create_graph=True) + solution = ridge_solver_cg(init_params_torch, l2reg_torch, (xs, ys)) + + + def vjp(g): + return vjp(optimality_fun, solution, (lam, X, y), g)[0] # vjp w.r.t. 
lam + + I = torch.eye(len(sol)) + J = torch.stack([vjp(I[:, i]) for i in range(I.shape[1])]).T + J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) + helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) + # self.assertArraysAllClose(J, J_num, atol=5e-2) + @pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 4e50b615..13e02380 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -15,8 +15,8 @@ """Implicit Meta-Gradient.""" from torchopt.diff.implicit import nn -from torchopt.diff.implicit.decorator import custom_root +from torchopt.diff.implicit.decorator import custom_root, root_vjp from torchopt.diff.implicit.nn import ImplicitMetaGradientModule -__all__ = ['custom_root', 'ImplicitMetaGradientModule'] +__all__ = ['custom_root', 'ImplicitMetaGradientModule', 'root_vjp'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 031aa11f..d03fcf2f 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -36,7 +36,7 @@ ) -__all__ = ['custom_root'] +__all__ = ['custom_root', 'root_vjp'] Args = Tuple[Any, ...] @@ -419,6 +419,55 @@ def wrapped_solver_fn( return wrapped_solver_fn +# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches +def root_vjp( + optimality_fn: Callable[..., TensorOrTensors], + solution: TupleOfTensors, + args: Args, + grad_outputs: TupleOfTensors, + output_is_tensor: bool, + argnums: tuple[int, ...], + solve: Callable[..., TensorOrTensors] = None, +) -> TupleOfOptionalTensors: + """Return vector-Jacobian product of a root. + + The root is the `solution` of ``optimality_fn(solution, *args) == 0``. + + Args: + optimality_fun (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at ``solution``. + solution (tuple of Tensors): solution / root of `optimality_fun`. + args (Args): tuple containing the arguments with respect to which we wish to + differentiate ``solution`` against. + grad_outputs (tuple of Tensors): The "vector" in the vector-Jacobian product. + Usually gradients w.r.t. each output. None values can be specified for scalar + Tensors or ones that don't require grad. If a None value would be acceptable + for all grad_tensors, then this argument is optional. Default: None. + output_is_tensor (bool): Whether the output of ``optimality_fn`` is a single tensor. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The + ``argnums`` can be an integer or a tuple of integers. + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) + Returns: + tuple of the same length as ``len(args)`` containing the vjps w.r.t. + each argument. Each ``vjps[i]`` has the same pytree structure as + ``args[i]``. 
+ """ + if solve is None: + solve = linear_solve.solve_normal_cg() + + return functools.partial( + _root_vjp, + optimality_fn=optimality_fn, + solution=solution, + args=args, + grad_outputs=grad_outputs, + output_is_tensor=output_is_tensor, + argnums=argnums, + solve=solve, + ) + + def custom_root( optimality_fn: Callable[..., TensorOrTensors], argnums: int | tuple[int, ...], From b9f14e2ae9f5fe78f0cf8f3ca6d5f97239130693 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 13 Apr 2023 18:23:11 +0800 Subject: [PATCH 2/4] feat: make public --- Makefile | 2 +- docs/source/spelling_wordlist.txt | 1 + tests/test_implicit.py | 40 ++++++++++++++++++++--------- torchopt/diff/implicit/decorator.py | 11 ++++---- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/Makefile b/Makefile index 248940df..dfc9ac46 100644 --- a/Makefile +++ b/Makefile @@ -113,7 +113,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest -k "test_rr_root_vjp" --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 49fdbb69..ea458c22 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -175,3 +175,4 @@ ctx Duchi invertible AdaGrad +vjp diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 45d88f74..295bc1f0 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -480,17 +480,21 @@ def outer_level(p, xs, ys): helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor) -def test_rr_root_vjp( +@helpers.parametrize( dtype=[torch.float64, torch.float32], - +) +def test_rr_root_vjp( + dtype: torch.dtype, ): helpers.seed_everything(42) - np_dtype = helpers.dtype_torch2numpy(dtype) + helpers.dtype_torch2numpy(dtype) input_size = 10 init_params_torch = torch.randn(input_size, dtype=dtype) l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True) + loader = get_rr_dataset_torch() + def ridge_objective(params, l2reg, data): """Ridge objective function.""" X_tr, y_tr = data @@ -513,24 +517,36 @@ def matvec(u): return solve(matvec=matvec, b=X_tr.T @ y_tr) + def ridge_solver_jac(params, l2reg, data, eps=1e-8): + return ( + ridge_solver_cg(params, l2reg + eps, data) - ridge_solver_cg(params, l2reg - eps, data) + ) / (2 * eps) + for xs, ys, xq, yq in loader: xs = xs.to(dtype=dtype) ys = ys.to(dtype=dtype) xq = xq.to(dtype=dtype) yq = yq.to(dtype=dtype) - optimality_fun = grad(ridge_objective, argnums=0, create_graph=True) + optimality_fn = functorch.grad(ridge_objective, argnums=0) solution = ridge_solver_cg(init_params_torch, l2reg_torch, (xs, ys)) + # def vjp(g): + # return vjp(optimality_fun, solution, (lam, X, y), g)[0] # vjp w.r.t. lam + + # torch.eye(len(sol)) + # J = torch.stack([vjp(I[:, i]) for i in range(I.shape[1])]).T + J = torchopt.diff.implicit.root_vjp( + optimality_fn=optimality_fn, + solution=solution, + args=(l2reg_torch, (xs, ys)), + grad_outputs=1.0, + output_is_tensor=True, + argnums=1, + ) + J_num = ridge_solver_jac(init_params_torch, l2reg_torch, (xs, ys), eps=1e-4) + helpers.assert_all_close(J, J_num) - def vjp(g): - return vjp(optimality_fun, solution, (lam, X, y), g)[0] # vjp w.r.t. 
lam - - I = torch.eye(len(sol)) - J = torch.stack([vjp(I[:, i]) for i in range(I.shape[1])]).T - J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) - helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) - # self.assertArraysAllClose(J, J_num, atol=5e-2) @pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index d03fcf2f..1e33d7a9 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -343,9 +343,8 @@ def backward( # pylint: disable=too-many-locals raise TypeError( f'keyword arguments to solver_fn could not be resolved to positional ' f'arguments based on the signature {reference_signature}. This can ' - f'happen under custom_root if optimality_fn takes catch-all **kwargs, or ' - f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, ' - f'both of which are currently unsupported.', + f'happen under custom_root if optimality_fn takes catch-all **kwargs, ' + f'which are currently unsupported.', ) # Compute VJPs w.r.t. args. @@ -448,16 +447,16 @@ def root_vjp( ``argnums`` can be an integer or a tuple of integers. solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. (default: :func:`linear_solve.solve_normal_cg`) + Returns: - tuple of the same length as ``len(args)`` containing the vjps w.r.t. + tuple of the same length as ``len(args)`` containing the vector-Jacobian products w.r.t. each argument. Each ``vjps[i]`` has the same pytree structure as ``args[i]``. """ if solve is None: solve = linear_solve.solve_normal_cg() - return functools.partial( - _root_vjp, + return _root_vjp( optimality_fn=optimality_fn, solution=solution, args=args, From 5cdd0f043cf8b125ad196795363800d6aec4dac6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 15 Apr 2023 09:25:10 +0800 Subject: [PATCH 3/4] feat: make public --- Makefile | 2 +- docs/source/api/api.rst | 5 ++ docs/source/implicit_diff/implicit_diff.rst | 9 +++ tests/test_implicit.py | 31 +++++---- torchopt/diff/implicit/__init__.py | 3 +- torchopt/diff/implicit/decorator.py | 51 +------------- torchopt/diff/implicit/utils.py | 77 +++++++++++++++++++++ 7 files changed, 111 insertions(+), 67 deletions(-) create mode 100644 torchopt/diff/implicit/utils.py diff --git a/Makefile b/Makefile index dfc9ac46..248940df 100644 --- a/Makefile +++ b/Makefile @@ -113,7 +113,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest -k "test_rr_root_vjp" --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index d00e2333..bad262b3 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -158,12 +158,17 @@ Implicit Differentiation custom_root nn.ImplicitMetaGradientModule + root_vjp Custom Solvers ~~~~~~~~~~~~~~ .. autofunction:: custom_root +VJPs of Root +~~~~~~~~~~~~ + +.. 
autofunction:: root_vjp Implicit Meta-Gradient Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst index 5544c25f..885c28d1 100644 --- a/docs/source/implicit_diff/implicit_diff.rst +++ b/docs/source/implicit_diff/implicit_diff.rst @@ -64,6 +64,15 @@ This can be implemented with: new_theta = fixed_point_function(phi, theta) return torchopt.pytree.tree_sub(new_theta, theta) +VJPs of Root +------------ + +.. autosummary:: + + torchopt.diff.implicit.root_vjp + +We also provide lower-level routines for computing the VJPs of roots of functions. The VJPs of roots are useful for computing the VJPs of the inner-level optimal solutions in the context of implicit differentiation. + Custom Solvers -------------- diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 295bc1f0..c2709006 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -528,23 +528,24 @@ def ridge_solver_jac(params, l2reg, data, eps=1e-8): xq = xq.to(dtype=dtype) yq = yq.to(dtype=dtype) - optimality_fn = functorch.grad(ridge_objective, argnums=0) + optimality_fn = functorch.grad(ridge_objective) solution = ridge_solver_cg(init_params_torch, l2reg_torch, (xs, ys)) - # def vjp(g): - # return vjp(optimality_fun, solution, (lam, X, y), g)[0] # vjp w.r.t. lam - - # torch.eye(len(sol)) - # J = torch.stack([vjp(I[:, i]) for i in range(I.shape[1])]).T - J = torchopt.diff.implicit.root_vjp( - optimality_fn=optimality_fn, - solution=solution, - args=(l2reg_torch, (xs, ys)), - grad_outputs=1.0, - output_is_tensor=True, - argnums=1, - ) - J_num = ridge_solver_jac(init_params_torch, l2reg_torch, (xs, ys), eps=1e-4) + def vjp(g): + return torchopt.diff.implicit.root_vjp( + optimality_fn=optimality_fn, # noqa: B023 + solution=solution.view(1, -1), # noqa: B023 + args=(l2reg_torch, (xs, ys)), # noqa: B023 + grad_outputs=g, + output_is_tensor=True, + argnums=(1,), + solve=torchopt.linear_solve.solve_cg(), + ) + + I = torch.eye(len(solution)) # noqa: E741 + # J = functorch.vmap(vjp)(I) + J = torch.stack([vjp(I[:, i].view(1, -1))[1] for i in range(I.shape[1])]) + J_num = ridge_solver_jac(init_params_torch, l2reg_torch, (xs, ys), eps=5e-3) helpers.assert_all_close(J, J_num) diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 13e02380..e7219d6a 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -15,8 +15,9 @@ """Implicit Meta-Gradient.""" from torchopt.diff.implicit import nn -from torchopt.diff.implicit.decorator import custom_root, root_vjp +from torchopt.diff.implicit.decorator import custom_root from torchopt.diff.implicit.nn import ImplicitMetaGradientModule +from torchopt.diff.implicit.utils import root_vjp __all__ = ['custom_root', 'ImplicitMetaGradientModule', 'root_vjp'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 1e33d7a9..34074762 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -36,7 +36,7 @@ ) -__all__ = ['custom_root', 'root_vjp'] +__all__ = ['custom_root'] Args = Tuple[Any, ...] 
@@ -418,55 +418,6 @@ def wrapped_solver_fn( return wrapped_solver_fn -# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches -def root_vjp( - optimality_fn: Callable[..., TensorOrTensors], - solution: TupleOfTensors, - args: Args, - grad_outputs: TupleOfTensors, - output_is_tensor: bool, - argnums: tuple[int, ...], - solve: Callable[..., TensorOrTensors] = None, -) -> TupleOfOptionalTensors: - """Return vector-Jacobian product of a root. - - The root is the `solution` of ``optimality_fn(solution, *args) == 0``. - - Args: - optimality_fun (callable): An equation function, ``optimality_fn(params, *args)``. The - invariant is ``optimality_fn(solution, *args) == 0`` at ``solution``. - solution (tuple of Tensors): solution / root of `optimality_fun`. - args (Args): tuple containing the arguments with respect to which we wish to - differentiate ``solution`` against. - grad_outputs (tuple of Tensors): The "vector" in the vector-Jacobian product. - Usually gradients w.r.t. each output. None values can be specified for scalar - Tensors or ones that don't require grad. If a None value would be acceptable - for all grad_tensors, then this argument is optional. Default: None. - output_is_tensor (bool): Whether the output of ``optimality_fn`` is a single tensor. - argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The - ``argnums`` can be an integer or a tuple of integers. - solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. - (default: :func:`linear_solve.solve_normal_cg`) - - Returns: - tuple of the same length as ``len(args)`` containing the vector-Jacobian products w.r.t. - each argument. Each ``vjps[i]`` has the same pytree structure as - ``args[i]``. - """ - if solve is None: - solve = linear_solve.solve_normal_cg() - - return _root_vjp( - optimality_fn=optimality_fn, - solution=solution, - args=args, - grad_outputs=grad_outputs, - output_is_tensor=output_is_tensor, - argnums=argnums, - solve=solve, - ) - - def custom_root( optimality_fn: Callable[..., TensorOrTensors], argnums: int | tuple[int, ...], diff --git a/torchopt/diff/implicit/utils.py b/torchopt/diff/implicit/utils.py new file mode 100644 index 00000000..0d32fd59 --- /dev/null +++ b/torchopt/diff/implicit/utils.py @@ -0,0 +1,77 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implicit Meta-Gradient.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import Callable + +from torchopt import linear_solve +from torchopt.diff.implicit.decorator import Args, _root_vjp +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors + + +__all__ = ['root_vjp'] + + +# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches +def root_vjp( + optimality_fn: Callable[..., TensorOrTensors], + solution: TupleOfTensors, + args: Args, + grad_outputs: TupleOfTensors, + output_is_tensor: bool, + argnums: tuple[int, ...], + solve: Callable[..., TensorOrTensors] | None = None, +) -> TupleOfOptionalTensors: + """Return vector-Jacobian product of a root. + + The root is the `solution` of ``optimality_fn(solution, *args) == 0``. + + Args: + optimality_fun (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at ``solution``. + solution (tuple of Tensors): solution / root of `optimality_fun`. + args (Args): tuple containing the arguments with respect to which we wish to + differentiate ``solution`` against. + grad_outputs (tuple of Tensors): The "vector" in the vector-Jacobian product. + Usually gradients w.r.t. each output. None values can be specified for scalar + Tensors or ones that don't require grad. If a None value would be acceptable + for all grad_tensors, then this argument is optional. Default: None. + output_is_tensor (bool): Whether the output of ``optimality_fn`` is a single tensor. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The + ``argnums`` can be an integer or a tuple of integers. + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) + + Returns: + tuple of the same length as ``len(args)`` containing the vector-Jacobian products w.r.t. + each argument. Each ``vjps[i]`` has the same pytree structure as + ``args[i]``. + """ + if solve is None: + solve = linear_solve.solve_normal_cg() + + return _root_vjp( + optimality_fn=optimality_fn, + solution=solution, + args=args, + grad_outputs=grad_outputs, + output_is_tensor=output_is_tensor, + argnums=argnums, + solve=solve, + ) From 2fd5166f33afc3dc2bbc02740635bc800be302ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Apr 2023 01:27:24 +0000 Subject: [PATCH 4/4] fix: [pre-commit.ci] auto fixes [...] 
--- src/adam_op/adam_op_impl_cpu.cpp | 35 +++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index b9c14e49..1135206d 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -40,8 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -95,8 +96,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -128,8 +130,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -165,8 +168,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -210,8 +214,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -246,8 +251,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > 
MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -286,8 +292,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid];
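
For readers following this series, the sketch below shows how the newly public `torchopt.diff.implicit.root_vjp` is meant to be called. It is adapted from the `test_rr_root_vjp` test added above (ridge regression solved by conjugate gradient, with stationarity of the ridge objective as the optimality condition). The synthetic data, the shapes, and the choice to read the `l2reg` entry out of the returned tuple are illustrative assumptions for this sketch, not part of the patch itself.

    import functorch
    import torch

    import torchopt


    # Illustrative sizes and data; the real test draws its batches from a loader.
    input_size, num_samples = 10, 32
    dtype = torch.float64

    X = torch.randn(num_samples, input_size, dtype=dtype)
    y = torch.randn(num_samples, dtype=dtype)
    init_params = torch.randn(input_size, dtype=dtype)
    l2reg = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)


    def ridge_objective(params, l2reg, data):
        """Ridge objective: mean squared error plus an L2 penalty on the weights."""
        X_tr, y_tr = data
        residuals = X_tr @ params - y_tr
        regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
        return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss


    def ridge_solver_cg(params, l2reg, data):
        """Solve the ridge normal equations by conjugate gradient."""
        X_tr, y_tr = data

        def matvec(u):
            return X_tr.T @ (X_tr @ u)

        solve = torchopt.linear_solve.solve_cg(
            ridge=len(y_tr) * l2reg.item(),
            init=params,
            maxiter=20,
        )
        return solve(matvec=matvec, b=X_tr.T @ y_tr)


    # Optimality condition: the gradient of the inner objective w.r.t. the
    # parameters vanishes at the solution.
    optimality_fn = functorch.grad(ridge_objective)
    solution = ridge_solver_cg(init_params, l2reg, (X, y))

    # Vector-Jacobian product of the root w.r.t. the hyperparameter ``l2reg``
    # (argument index 1 of ``optimality_fn``), contracted with the vector ``g``.
    g = torch.ones_like(solution).view(1, -1)
    vjps = torchopt.diff.implicit.root_vjp(
        optimality_fn=optimality_fn,
        solution=solution.view(1, -1),
        args=(l2reg, (X, y)),
        grad_outputs=g,
        output_is_tensor=True,
        argnums=(1,),
        solve=torchopt.linear_solve.solve_cg(),
    )
    vjp_l2reg = vjps[1]  # entry read out the same way as in test_rr_root_vjp above
    print(vjp_l2reg)

Feeding basis vectors in place of `g`, as the test does, recovers the full Jacobian of the solution with respect to `l2reg` column by column, which is what the finite-difference check in `ridge_solver_jac` is compared against.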