diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst
index d00e2333..bad262b3 100644
--- a/docs/source/api/api.rst
+++ b/docs/source/api/api.rst
@@ -158,12 +158,17 @@ Implicit Differentiation
 
     custom_root
     nn.ImplicitMetaGradientModule
+    root_vjp
 
 Custom Solvers
 ~~~~~~~~~~~~~~
 
 .. autofunction:: custom_root
 
+VJPs of Root
+~~~~~~~~~~~~
+
+.. autofunction:: root_vjp
 
 Implicit Meta-Gradient Module
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst
index 5544c25f..885c28d1 100644
--- a/docs/source/implicit_diff/implicit_diff.rst
+++ b/docs/source/implicit_diff/implicit_diff.rst
@@ -64,6 +64,40 @@ This can be implemented with:
         new_theta = fixed_point_function(phi, theta)
         return torchopt.pytree.tree_sub(new_theta, theta)
 
+VJPs of Root
+------------
+
+.. autosummary::
+
+    torchopt.diff.implicit.root_vjp
+
+We also provide a lower-level routine, :func:`torchopt.diff.implicit.root_vjp`, that computes the vector-Jacobian product (VJP) of the root of a function. In implicit differentiation, this is the building block for back-propagating through the optimal solution of the inner-level problem.
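+
+A minimal sketch of calling it directly (the quadratic inner problem below is purely illustrative):
+
+.. code-block:: python
+
+    import torch
+    import torchopt
+
+    def optimality_fn(x, theta):
+        # Stationarity condition of the inner problem min_x 0.5 * ||x - theta||^2
+        return x - theta
+
+    theta = torch.tensor([1.0, 2.0])
+    solution = theta.clone()  # the root: optimality_fn(solution, theta) == 0
+
+    vjps = torchopt.diff.implicit.root_vjp(
+        optimality_fn=optimality_fn,
+        solution=(solution,),
+        args=(theta,),
+        grad_outputs=(torch.ones_like(solution),),
+        output_is_tensor=True,
+        argnums=(1,),  # differentiate w.r.t. theta
+    )
+    # Here the root equals theta itself, so the VJP w.r.t. theta (vjps[1])
+    # matches grad_outputs.
+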
 Custom Solvers
 --------------
 
diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt
index 49fdbb69..ea458c22 100644
--- a/docs/source/spelling_wordlist.txt
+++ b/docs/source/spelling_wordlist.txt
@@ -175,3 +175,4 @@ ctx
 Duchi
 invertible
 AdaGrad
+vjp
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index db19f829..c2709006 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -480,6 +480,73 @@ def outer_level(p, xs, ys):
     helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor)
 
 
+@helpers.parametrize(
+    dtype=[torch.float64, torch.float32],
+)
+def test_rr_root_vjp(
+    dtype: torch.dtype,
+):
+    helpers.seed_everything(42)
+    input_size = 10
+
+    init_params_torch = torch.randn(input_size, dtype=dtype)
+    l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)
+
+    loader = get_rr_dataset_torch()
+
+    def ridge_objective(params, l2reg, data):
+        """Ridge objective function."""
+        X_tr, y_tr = data
+        residuals = X_tr @ params - y_tr
+        regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
+        return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss
+
+    def ridge_solver_cg(params, l2reg, data):
+        """Solve ridge regression by conjugate gradient."""
+        X_tr, y_tr = data
+
+        def matvec(u):
+            return X_tr.T @ (X_tr @ u)
+
+        solve = torchopt.linear_solve.solve_cg(
+            ridge=len(y_tr) * l2reg.item(),
+            init=params,
+            maxiter=20,
+        )
+
+        return solve(matvec=matvec, b=X_tr.T @ y_tr)
+
+    def ridge_solver_jac(params, l2reg, data, eps=1e-8):
+        """Numerical Jacobian of the ridge solution w.r.t. l2reg (central differences)."""
+        return (
+            ridge_solver_cg(params, l2reg + eps, data) - ridge_solver_cg(params, l2reg - eps, data)
+        ) / (2 * eps)
+
+    for xs, ys, _, _ in loader:
+        xs = xs.to(dtype=dtype)
+        ys = ys.to(dtype=dtype)
+
+        optimality_fn = functorch.grad(ridge_objective)
+        solution = ridge_solver_cg(init_params_torch, l2reg_torch, (xs, ys))
+
+        def vjp(g):
+            return torchopt.diff.implicit.root_vjp(
+                optimality_fn=optimality_fn,  # noqa: B023
+                solution=(solution,),  # noqa: B023
+                args=(l2reg_torch, (xs, ys)),  # noqa: B023
+                grad_outputs=(g,),
+                output_is_tensor=True,
+                argnums=(1,),
+                solve=torchopt.linear_solve.solve_cg(),
+            )
+
+        I = torch.eye(len(solution), dtype=dtype)  # noqa: E741
+        # J = functorch.vmap(vjp)(I)
+        J = torch.stack([vjp(I[i])[1] for i in range(input_size)])
+        J_num = ridge_solver_jac(init_params_torch, l2reg_torch, (xs, ys), eps=5e-3)
+        helpers.assert_all_close(J, J_num)
+
+
 @pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
 @helpers.parametrize(
     dtype=[torch.float64, torch.float32],
diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py
index 4e50b615..e7219d6a 100644
--- a/torchopt/diff/implicit/__init__.py
+++ b/torchopt/diff/implicit/__init__.py
@@ -17,6 +17,7 @@
 from torchopt.diff.implicit import nn
 from torchopt.diff.implicit.decorator import custom_root
 from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
+from torchopt.diff.implicit.utils import root_vjp
 
-__all__ = ['custom_root', 'ImplicitMetaGradientModule']
+__all__ = ['custom_root', 'ImplicitMetaGradientModule', 'root_vjp']
 
diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py
index 031aa11f..34074762 100644
--- a/torchopt/diff/implicit/decorator.py
+++ b/torchopt/diff/implicit/decorator.py
@@ -343,9 +343,8 @@ def backward(  # pylint: disable=too-many-locals
             raise TypeError(
                 f'keyword arguments to solver_fn could not be resolved to positional '
                 f'arguments based on the signature {reference_signature}. This can '
-                f'happen under custom_root if optimality_fn takes catch-all **kwargs, or '
-                f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, '
-                f'both of which are currently unsupported.',
+                f'happen under custom_root if optimality_fn takes catch-all **kwargs, '
+                f'which is currently unsupported.',
             )
 
     # Compute VJPs w.r.t. args.
diff --git a/torchopt/diff/implicit/utils.py b/torchopt/diff/implicit/utils.py
new file mode 100644
index 00000000..0d32fd59
--- /dev/null
+++ b/torchopt/diff/implicit/utils.py
@@ -0,0 +1,95 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for implicit differentiation."""
+
+# pylint: disable=invalid-name
+
+from __future__ import annotations
+
+from typing import Callable
+
+from torchopt import linear_solve
+from torchopt.diff.implicit.decorator import Args, _root_vjp
+from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors
+
+
+__all__ = ['root_vjp']
+
+
+# pylint: disable-next=too-many-arguments
+def root_vjp(
+    optimality_fn: Callable[..., TensorOrTensors],
+    solution: TupleOfTensors,
+    args: Args,
+    grad_outputs: TupleOfTensors,
+    output_is_tensor: bool,
+    argnums: tuple[int, ...],
+    solve: Callable[..., TensorOrTensors] | None = None,
+) -> TupleOfOptionalTensors:
+    """Return the vector-Jacobian product of a root.
+
+    The root is the ``solution`` of ``optimality_fn(solution, *args) == 0``.
+
+    Args:
+        optimality_fn (callable): An optimality condition function,
+            ``optimality_fn(params, *args)``. The invariant is
+            ``optimality_fn(solution, *args) == 0`` at ``solution``.
+        solution (tuple of Tensors): The solution / root of ``optimality_fn``.
+        args (Args): Tuple of arguments with respect to which we wish to
+            differentiate ``solution``.
+        grad_outputs (tuple of Tensors): The "vector" in the vector-Jacobian
+            product, usually the gradients w.r.t. each output of ``optimality_fn``.
+        output_is_tensor (bool): Whether the output of ``optimality_fn`` is a
+            single tensor.
+        argnums (tuple of int): The positional arguments of ``optimality_fn`` to
+            compute gradients with respect to. Position ``0`` refers to the
+            solution itself, so valid entries start at ``1``.
+        solve (callable, optional): A linear solver of the form ``solve(matvec, b)``.
+            (default: :func:`linear_solve.solve_normal_cg`)
+
+    Returns:
+        A tuple of vector-Jacobian products of the same length as ``(solution, *args)``.
+        The leading entry, corresponding to ``solution`` itself, is always :data:`None`;
+        entry ``i`` holds the VJP w.r.t. the argument at position ``i`` of
+        ``optimality_fn``, or :data:`None` if ``i`` is not in ``argnums``.
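+
+    Example::
+
+        # A minimal sketch with a closed-form root: coef * x - 1 == 0.
+        coef = torch.tensor(2.0)
+        solution = torch.ones(3) / coef
+
+        vjps = root_vjp(
+            optimality_fn=lambda x, coef: coef * x - 1.0,
+            solution=(solution,),
+            args=(coef,),
+            grad_outputs=(torch.ones(3),),
+            output_is_tensor=True,
+            argnums=(1,),
+        )
+        # vjps == (None, vjp_wrt_coef); entry 0, for the solution, is None.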
+    """
+    if solve is None:
+        solve = linear_solve.solve_normal_cg()
+
+    return _root_vjp(
+        optimality_fn=optimality_fn,
+        solution=solution,
+        args=args,
+        grad_outputs=grad_outputs,
+        output_is_tensor=output_is_tensor,
+        argnums=argnums,
+        solve=solve,
+    )
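
As an end-to-end illustration of the new API, the following sketch (not part of the patch) checks root_vjp against an analytically known Jacobian. It assumes the returned tuple is indexed by the positional arguments of optimality_fn, with entry 0 (the solution itself) always None, which is the convention the new test relies on:

    import torch
    import torchopt

    def optimality_fn(x, a, theta):
        # Gradient of the inner objective 0.5 * a * ||x||^2 - theta @ x.
        return a * x - theta

    a = torch.tensor(2.0)
    theta = torch.tensor([1.0, 2.0, 3.0])
    solution = theta / a  # closed-form root: a * x - theta == 0

    g = torch.ones_like(solution)
    vjps = torchopt.diff.implicit.root_vjp(
        optimality_fn=optimality_fn,
        solution=(solution,),
        args=(a, theta),
        grad_outputs=(g,),
        output_is_tensor=True,
        argnums=(1, 2),
    )

    # Analytic references: dx*/da = -theta / a**2 and dx*/dtheta = I / a,
    # so vjps[1] should match -(g @ theta) / a**2 and vjps[2] should match g / a.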