feat: make root_vjp public #157

Open · wants to merge 4 commits into main
5 changes: 5 additions & 0 deletions docs/source/api/api.rst
@@ -158,12 +158,17 @@ Implicit Differentiation

    custom_root
    nn.ImplicitMetaGradientModule
    root_vjp

Custom Solvers
~~~~~~~~~~~~~~

.. autofunction:: custom_root

VJPs of Root
~~~~~~~~~~~~

.. autofunction:: root_vjp

Implicit Meta-Gradient Module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9 changes: 9 additions & 0 deletions docs/source/implicit_diff/implicit_diff.rst
@@ -64,6 +64,15 @@ This can be implemented with:
new_theta = fixed_point_function(phi, theta)
return torchopt.pytree.tree_sub(new_theta, theta)

VJPs of Root
------------

.. autosummary::

    torchopt.diff.implicit.root_vjp

We also provide a lower-level routine for computing the VJP of the root of a function, which is useful for computing VJPs of the inner-level optimal solution in the context of implicit differentiation.
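
A minimal sketch of calling :func:`torchopt.diff.implicit.root_vjp` directly is shown below. The toy quadratic problem (``objective`` and ``targets``) is purely illustrative, and the calling convention mirrors the one exercised in ``tests/test_implicit.py``:

.. code-block:: python

    import functorch
    import torch
    import torchopt

    def objective(params, targets):
        # Toy inner-level objective; its unique optimum is params == targets.
        return 0.5 * torch.sum(torch.square(params - targets))

    # Optimality condition: the gradient of the objective w.r.t. params is zero at the root.
    optimality_fn = functorch.grad(objective)

    targets = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    solution = targets.detach().clone()  # the root of optimality_fn(., targets)

    # Contract the Jacobian of the root w.r.t. ``targets`` against a basis vector to
    # recover one row of that Jacobian.
    e0 = torch.zeros_like(solution)
    e0[0] = 1.0
    vjps = torchopt.diff.implicit.root_vjp(
        optimality_fn=optimality_fn,
        solution=solution.view(1, -1),
        args=(targets,),
        grad_outputs=e0.view(1, -1),
        output_is_tensor=True,
        argnums=(1,),
        solve=torchopt.linear_solve.solve_normal_cg(),
    )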

Custom Solvers
--------------

1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
@@ -175,3 +175,4 @@ ctx
Duchi
invertible
AdaGrad
vjp
69 changes: 69 additions & 0 deletions tests/test_implicit.py
@@ -480,6 +480,75 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor)


@helpers.parametrize(
    dtype=[torch.float64, torch.float32],
)
def test_rr_root_vjp(
    dtype: torch.dtype,
):
    helpers.seed_everything(42)
    helpers.dtype_torch2numpy(dtype)
    input_size = 10

    init_params_torch = torch.randn(input_size, dtype=dtype)
    l2reg_torch = torch.rand(1, dtype=dtype).squeeze_().requires_grad_(True)

    loader = get_rr_dataset_torch()

    def ridge_objective(params, l2reg, data):
        """Ridge objective function."""
        X_tr, y_tr = data
        residuals = X_tr @ params - y_tr
        regularization_loss = 0.5 * l2reg * torch.sum(torch.square(params))
        return 0.5 * torch.mean(torch.square(residuals)) + regularization_loss

    def ridge_solver_cg(params, l2reg, data):
        """Solve ridge regression by conjugate gradient."""
        X_tr, y_tr = data

        def matvec(u):
            return X_tr.T @ (X_tr @ u)

        solve = torchopt.linear_solve.solve_cg(
            ridge=len(y_tr) * l2reg.item(),
            init=params,
            maxiter=20,
        )

        return solve(matvec=matvec, b=X_tr.T @ y_tr)

    def ridge_solver_jac(params, l2reg, data, eps=1e-8):
        return (
            ridge_solver_cg(params, l2reg + eps, data) - ridge_solver_cg(params, l2reg - eps, data)
        ) / (2 * eps)

    for xs, ys, xq, yq in loader:
        xs = xs.to(dtype=dtype)
        ys = ys.to(dtype=dtype)
        xq = xq.to(dtype=dtype)
        yq = yq.to(dtype=dtype)

        optimality_fn = functorch.grad(ridge_objective)
        solution = ridge_solver_cg(init_params_torch, l2reg_torch, (xs, ys))

        def vjp(g):
            return torchopt.diff.implicit.root_vjp(
                optimality_fn=optimality_fn,  # noqa: B023
                solution=solution.view(1, -1),  # noqa: B023
                args=(l2reg_torch, (xs, ys)),  # noqa: B023
                grad_outputs=g,
                output_is_tensor=True,
                argnums=(1,),
                solve=torchopt.linear_solve.solve_cg(),
            )

        I = torch.eye(len(solution), dtype=dtype)  # noqa: E741
        # J = functorch.vmap(vjp)(I)
        # Assemble the Jacobian of the solution w.r.t. l2reg from VJPs against the
        # standard basis vectors, then compare with the finite-difference estimate.
        J = torch.stack([vjp(I[:, i].view(1, -1))[1] for i in range(I.shape[1])])
        J_num = ridge_solver_jac(init_params_torch, l2reg_torch, (xs, ys), eps=5e-3)
        helpers.assert_all_close(J, J_num)


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
3 changes: 2 additions & 1 deletion torchopt/diff/implicit/__init__.py
@@ -17,6 +17,7 @@
from torchopt.diff.implicit import nn
from torchopt.diff.implicit.decorator import custom_root
from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
from torchopt.diff.implicit.utils import root_vjp


__all__ = ['custom_root', 'ImplicitMetaGradientModule']
__all__ = ['custom_root', 'ImplicitMetaGradientModule', 'root_vjp']
5 changes: 2 additions & 3 deletions torchopt/diff/implicit/decorator.py
@@ -343,9 +343,8 @@ def backward( # pylint: disable=too-many-locals
raise TypeError(
f'keyword arguments to solver_fn could not be resolved to positional '
f'arguments based on the signature {reference_signature}. This can '
f'happen under custom_root if optimality_fn takes catch-all **kwargs, or '
f'under custom_fixed_point if fixed_point_fn takes catch-all **kwargs, '
f'both of which are currently unsupported.',
f'happen under custom_root if optimality_fn takes catch-all **kwargs, '
f'which is currently unsupported.',
)

# Compute VJPs w.r.t. args.
77 changes: 77 additions & 0 deletions torchopt/diff/implicit/utils.py
@@ -0,0 +1,77 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implicit Meta-Gradient."""

# pylint: disable=invalid-name

from __future__ import annotations

from typing import Callable

from torchopt import linear_solve
from torchopt.diff.implicit.decorator import Args, _root_vjp
from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors


__all__ = ['root_vjp']


# pylint: disable-next=too-many-arguments,too-many-locals,too-many-branches
def root_vjp(
    optimality_fn: Callable[..., TensorOrTensors],
    solution: TupleOfTensors,
    args: Args,
    grad_outputs: TupleOfTensors,
    output_is_tensor: bool,
    argnums: tuple[int, ...],
    solve: Callable[..., TensorOrTensors] | None = None,
) -> TupleOfOptionalTensors:
"""Return vector-Jacobian product of a root.

The root is the `solution` of ``optimality_fn(solution, *args) == 0``.

Args:
optimality_fun (callable): An equation function, ``optimality_fn(params, *args)``. The
invariant is ``optimality_fn(solution, *args) == 0`` at ``solution``.
solution (tuple of Tensors): solution / root of `optimality_fun`.
args (Args): tuple containing the arguments with respect to which we wish to
differentiate ``solution`` against.
grad_outputs (tuple of Tensors): The "vector" in the vector-Jacobian product.
Usually gradients w.r.t. each output. None values can be specified for scalar
Tensors or ones that don't require grad. If a None value would be acceptable
for all grad_tensors, then this argument is optional. Default: None.
output_is_tensor (bool): Whether the output of ``optimality_fn`` is a single tensor.
argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The
``argnums`` can be an integer or a tuple of integers.
solve (callable, optional): A linear solver of the form ``solve(matvec, b)``.
(default: :func:`linear_solve.solve_normal_cg`)

Returns:
tuple of the same length as ``len(args)`` containing the vector-Jacobian products w.r.t.
each argument. Each ``vjps[i]`` has the same pytree structure as
``args[i]``.
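
    Example:
        A minimal sketch; the quadratic toy problem below is purely illustrative, and the calling
        convention mirrors the one exercised in ``tests/test_implicit.py``::

            import functorch
            import torch
            import torchopt

            def objective(params, targets):
                # Toy inner-level objective whose unique optimum is params == targets.
                return 0.5 * torch.sum(torch.square(params - targets))

            # Optimality condition: the gradient of the objective w.r.t. params is zero.
            optimality_fn = functorch.grad(objective)

            targets = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
            solution = targets.detach().clone()  # root of optimality_fn(., targets)

            vjps = torchopt.diff.implicit.root_vjp(
                optimality_fn=optimality_fn,
                solution=solution.view(1, -1),
                args=(targets,),
                grad_outputs=torch.ones_like(solution).view(1, -1),
                output_is_tensor=True,
                argnums=(1,),
            )
            # vjps[1] holds the VJP with respect to ``targets``.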
"""
    if solve is None:
        solve = linear_solve.solve_normal_cg()

    return _root_vjp(
        optimality_fn=optimality_fn,
        solution=solution,
        args=args,
        grad_outputs=grad_outputs,
        output_is_tensor=output_is_tensor,
        argnums=argnums,
        solve=solve,
    )