Skip to content

Commit

Permalink
ENH: add auto_weighting decorator for adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Oct 5, 2017
1 parent 7d1cab6 commit 5da4039
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 3 deletions.
214 changes: 212 additions & 2 deletions odl/space/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from odl.set import RealNumbers, ComplexNumbers
from odl.space.entry_points import ntuples_impl, fn_impl
from odl.util import (
is_real_floating_dtype, is_complex_floating_dtype, is_scalar_dtype)
from odl.space.weighting import ArrayWeighting, ConstWeighting
from odl.util import is_scalar_dtype, OptionalArgDecorator


def vector(array, dtype=None, impl='numpy'):
Expand Down Expand Up @@ -244,6 +244,216 @@ def rn(size, dtype=None, impl='numpy', **kwargs):
return rn


class auto_weighting(OptionalArgDecorator):

"""Make an unweighted adjoint automatically account for weightings.
Depending on the weightings, the correction is achieved by composing
the unweighted operator with either `ScalingOperator` or
`ConstantOperator`. The following rules are applied for the domain
weighting ``w``, the range weighting ``v`` and the provided unweighted
adjoint ``B^*``:
- If both ``w`` and ``v`` are arrays, return ::
(1 / w) * (B^*) * v
- If ``w`` is an array and ``v`` a constant, return ::
(v / w) * (B^*)
- If ``w`` is a constant and ``v`` an array, return ::
(B^*) * (w / v)
- If both ``w`` and ``v`` are constants, return ::
(B^*) * (v / w)
if ``B.range.size < B.domain.size``, otherwise ::
(v / w) * (B^*)
- Ignore constants 1.0.
To avoid the inconvenience of dealing with `OperatorComp` objects,
the given operator is monkey-patched instead of composed.
Parameters
----------
unweighted_adjoint : `Operator`
Unweighted variant of the adjoint. It will be patched with a
new ``_call()`` method.
The weightings of ``domain`` and ``range`` of the operator
must be `ArrayWeighting` or `ConstWeighting`.
optimize : bool, optional
If ``True``, merge and move around constant weightings for
highest expected efficiency.
Notes
-----
Consider a linear operator :math:`A: X_w \\to Y_v` between spaces with
weights :math:`w` and :math:`v`, respectively, along with the same
operator :math:`B: X \\to Y` defined between the unweighted variants of
the spaces. (This means that :math:`B f = A f` for all
:math:`f \\in X \cong X_w`).
Then, the adjoint of :math:`A` is related to the adjoint of :math:`B`
as follows:
.. math::
\\langle Af, g \\rangle_{Y_v} =
\\langle Bf, v \cdot g \\rangle_Y =
\\langle f, B^*(v \cdot g) \\rangle_X =
\\langle f, w^{-1}\, B^*(v \cdot g) \\rangle_{X_w}.
Thus, from the existing unweighted adjoint :math:`B^*` one can compute
the weighted one as :math:`A^* = w^{-1}\, B^*(v\, \cdot)`.
Depending on the types of weighting, this expression can be simplified
further, e.g., a constant weight can be absorbed into the other weight.
"""

@staticmethod
def _wrapper(unweighted_adjoint, optimize=True):
"""Return the weighted variant of the unweighted adjoint."""
# Support decorating the `adjoint` property directly
import inspect
from functools import wraps
from odl.operator.operator import Operator

if (inspect.isfunction(unweighted_adjoint) and
unweighted_adjoint.__name__ == 'adjoint'):
# We need this level of indirection since `self` needs to
# be filled in with the instance, but we decorate at class
# level.
@wraps(unweighted_adjoint)
def weighted_adjoint(self):
adj = unweighted_adjoint(self)
if not isinstance(adj, Operator):
raise TypeError('`adjoint` did not return an `Operator`')
if adj is self:
raise TypeError(
'returning `self` in an `adjoint` property using '
'`auto_weighting` is not allowed')

# This is for cached adjoints: don't double-wrap
if hasattr(adj, '_call_unweighted'):
return adj
else:
return auto_weighting._instance_wrapper(adj, optimize)

return weighted_adjoint

else:
raise TypeError(
"`auto_weighting` can only be applied to 'adjoint' methods "
'(@auto_weighting decorator)')

@staticmethod
def _instance_wrapper(unweighted_adjoint, optimize=True):
"""Wrapper for `Operator` instances."""
# Use notions of the original operator, not the adjoint
dom_weighting = unweighted_adjoint.range.weighting
ran_weighting = unweighted_adjoint.domain.weighting

if isinstance(dom_weighting, ArrayWeighting):
dom_w_type = 'array'
dom_w = dom_weighting.array
elif isinstance(dom_weighting, ConstWeighting):
dom_w_type = 'const'
dom_w = dom_weighting.const
else:
raise TypeError(
'weighting of `unweighted_adjoint.range` must be of '
'type `ArrayWeighting` or `ConstWeighting`, got {}'
''.format(type(dom_weighting)))

if isinstance(ran_weighting, ArrayWeighting):
ran_w_type = 'array'
ran_w = ran_weighting.array
elif isinstance(ran_weighting, ConstWeighting):
ran_w_type = 'const'
ran_w = ran_weighting.const
else:
raise TypeError(
'weighting of `unweighted_adjoint.domain` must be of '
'type `ArrayWeighting` or `ConstWeighting`, got {}'
''.format(type(ran_weighting)))

# Compute the effective weights and mark constants 1.0 as to be
# skipped
if not optimize:
new_dom_w, new_ran_w = dom_w, ran_w
skip_dom = dom_w_type == 'const' and dom_w == 1.0
skip_ran = ran_w_type == 'const' and ran_w == 1.0
elif dom_w_type == 'array' and ran_w_type == 'array':
new_dom_w, new_ran_w = dom_w, ran_w
skip_dom = skip_ran = False
elif dom_w_type == 'array' and ran_w_type == 'const':
new_dom_w = dom_w / ran_w
new_ran_w = 1.0
skip_dom = False
skip_ran = True
elif dom_w_type == 'const' and ran_w_type == 'array':
new_dom_w = 1.0
new_ran_w = ran_w / dom_w
skip_dom = True
skip_ran = False
elif dom_w_type == 'const' and ran_w_type == 'const':
if unweighted_adjoint.domain.size < unweighted_adjoint.range.size:
new_dom_w = 1.0
new_ran_w = ran_w / dom_w
skip_dom = True
skip_ran = False
else:
new_dom_w = dom_w / ran_w
new_ran_w = 1.0
skip_dom = False
skip_ran = True

# Define the new `_call` depending on original signature
self = unweighted_adjoint

# Monkey-patching starts here
if self._call_has_out and self._call_out_optional:
def _call(x, out=None):
if not skip_ran:
x = new_ran_w * x
out = self._call_unweighted(x, out=out)
if not skip_dom:
out /= new_dom_w
return out

self._call_unweighted = self._call_in_place
self._call_in_place = self._call_out_of_place = _call

elif self._call_has_out and not self._call_out_optional:
def _call(x, out):
if not skip_ran:
x = new_ran_w * x
self._call_unweighted(x, out=out)
if not skip_dom:
out /= new_dom_w
return out

self._call_unweighted = self._call_in_place
self._call_in_place = _call

else:
def _call(x):
if not skip_ran:
x = new_ran_w * x
out = self._call_unweighted(x)
if not skip_dom:
out /= new_dom_w
return out

self._call_unweighted = self._call_out_of_place
self._call_out_of_place = _call

return self


if __name__ == '__main__':
# pylint: disable=wrong-import-position
from odl.util.testutils import run_doctests
Expand Down
4 changes: 3 additions & 1 deletion odl/space/weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import numpy as np

from odl.space.base_ntuples import FnBaseVector
from odl.util import array1d_repr, arraynd_repr, signature_string, indent_rows
from odl.util import (
array1d_repr, arraynd_repr, signature_string, indent_rows,
OptionalArgDecorator)


__all__ = ('MatrixWeighting', 'ArrayWeighting', 'ConstWeighting',
Expand Down

0 comments on commit 5da4039

Please sign in to comment.