Skip to content

Commit

Permalink
MAINT: move adjoint weighting tests to oputils_test
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Dec 17, 2017
1 parent 81ea2f6 commit 77e1fba
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 206 deletions.
9 changes: 5 additions & 4 deletions odl/operator/oputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,17 +482,18 @@ def adjoint(self):
Parameters
----------
unweighted_adjoint : `Operator`
Unweighted variant of the adjoint. It will be patched with a
new ``_call()`` method.
unweighted_adjoint : function
Method on an `Operator` class that returns the unweighted variant
of the adjoint. It will be patched with a new ``_call()`` method.
The weightings of ``domain`` and ``range`` of the operator
must be `ArrayWeighting` or `ConstWeighting`.
optimize : bool, optional
If ``True``, merge and move around constant weightings for
highest expected efficiency.
**Note:** Merging of a constant weight and an array weight will
result in a copy of the array, doubling the amount of required memory.
result in a copy of the array, doubling the amount of required
memory.
Notes
-----
Expand Down
205 changes: 202 additions & 3 deletions odl/test/operator/oputils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
import pytest

import odl
from odl.operator.oputils import matrix_representation, power_method_opnorm
from odl.space.pspace import ProductSpace
from odl.operator.oputils import (
matrix_representation, power_method_opnorm, auto_adjoint_weighting)
from odl.operator.pspace_ops import ProductSpaceOperator
from odl.util.testutils import almost_equal
from odl.space.pspace import ProductSpace
from odl.util.testutils import (
almost_equal, all_equal, simple_fixture, noise_element)


optimize_weighting = simple_fixture('optimize', [True, False])
call_variant = simple_fixture('call_variant', ['oop', 'ip', 'dual'])
weighting = simple_fixture('weighting', [1.0, 2.0, [1.0, 2.0]])


def test_matrix_representation():
Expand Down Expand Up @@ -250,5 +257,197 @@ def test_power_method_opnorm_exceptions():
power_method_opnorm(op, maxiter=1, xstart=op.domain.one())


def test_auto_weighting(call_variant, weighting, optimize_weighting):
"""Test the auto_weighting decorator for different adjoint variants."""

class ScalingOpBase(odl.Operator):

def __init__(self, dom, ran, c):
super(ScalingOpBase, self).__init__(dom, ran, linear=True)
self.c = c

if call_variant == 'oop':

class ScalingOp(ScalingOpBase):

def _call(self, x):
return self.c * x

@property
@auto_adjoint_weighting(optimize=optimize_weighting)
def adjoint(self):
return ScalingOp(self.range, self.domain, self.c)

elif call_variant == 'ip':

class ScalingOp(ScalingOpBase):

def _call(self, x, out):
out[:] = self.c * x
return out

@property
@auto_adjoint_weighting(optimize=optimize_weighting)
def adjoint(self):
return ScalingOp(self.range, self.domain, self.c)

elif call_variant == 'dual':

class ScalingOp(ScalingOpBase):

def _call(self, x, out=None):
if out is None:
out = self.c * x
else:
out[:] = self.c * x
return out

@property
@auto_adjoint_weighting(optimize=optimize_weighting)
def adjoint(self):
return ScalingOp(self.range, self.domain, self.c)

else:
assert False

# Test Rn space
rn = odl.rn(2)
rn_w = odl.rn(2, weighting=weighting)
op1 = ScalingOp(rn, rn_w, np.random.uniform(-2, 2))
op2 = ScalingOp(rn_w, rn, np.random.uniform(-2, 2))

for op in [op1, op2]:
dom_el = noise_element(op.domain)
ran_el = noise_element(op.range)
assert pytest.approx(op(dom_el).inner(ran_el),
dom_el.inner(op.adjoint(ran_el)))

# Test product space
pspace = odl.ProductSpace(odl.rn(3), 2)
pspace_w = odl.ProductSpace(odl.rn(3), 2, weighting=weighting)
op1 = ScalingOp(pspace, pspace_w, np.random.uniform(-2, 2))
op2 = ScalingOp(pspace_w, pspace, np.random.uniform(-2, 2))

for op in [op1, op2]:
dom_el = noise_element(op.domain)
ran_el = noise_element(op.range)
assert pytest.approx(op(dom_el).inner(ran_el),
dom_el.inner(op.adjoint(ran_el)))

# Test product space of product space
ppspace = odl.ProductSpace(odl.ProductSpace(odl.rn(3), 2), 2)
ppspace_w = odl.ProductSpace(
odl.ProductSpace(odl.rn(3), 2, weighting=weighting),
2, weighting=weighting)
op1 = ScalingOp(ppspace, ppspace_w, np.random.uniform(-2, 2))
op2 = ScalingOp(ppspace_w, ppspace, np.random.uniform(-2, 2))

for op in [op1, op2]:
dom_el = noise_element(op.domain)
ran_el = noise_element(op.range)
assert pytest.approx(op(dom_el).inner(ran_el),
dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_noarg():
"""Test the auto_weighting decorator without the optimize argument."""
rn = odl.rn(2)
weighting_const = np.random.uniform(0.5, 2)
rn_w = odl.rn(2, weighting=weighting_const)

class ScalingOp(odl.Operator):

def __init__(self, dom, ran, c):
super(ScalingOp, self).__init__(dom, ran, linear=True)
self.c = c

def _call(self, x):
return self.c * x

@property
@auto_adjoint_weighting
def adjoint(self):
return ScalingOp(self.range, self.domain, self.c)

op1 = ScalingOp(rn, rn, np.random.uniform(-2, 2))
op2 = ScalingOp(rn_w, rn_w, np.random.uniform(-2, 2))
op3 = ScalingOp(rn, rn_w, np.random.uniform(-2, 2))
op4 = ScalingOp(rn_w, rn, np.random.uniform(-2, 2))

for op in [op1, op2, op3, op4]:
dom_el = noise_element(op.domain)
ran_el = noise_element(op.range)
assert pytest.approx(op(dom_el).inner(ran_el),
dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_cached_adjoint():
"""Check if auto_weighting plays well with adjoint caching."""
rn = odl.rn(2)
weighting_const = np.random.uniform(0.5, 2)
rn_w = odl.rn(2, weighting=weighting_const)

class ScalingOp(odl.Operator):

def __init__(self, dom, ran, c):
super(ScalingOp, self).__init__(dom, ran, linear=True)
self.c = c
self._adjoint = None

def _call(self, x):
return self.c * x

@property
@auto_adjoint_weighting
def adjoint(self):
if self._adjoint is None:
self._adjoint = ScalingOp(self.range, self.domain, self.c)
return self._adjoint

op = ScalingOp(rn, rn_w, np.random.uniform(-2, 2))
dom_el = noise_element(op.domain)
op_eval_before = op(dom_el)

adj = op.adjoint
adj_again = op.adjoint
assert adj_again is adj

# Check that original op is intact
assert not hasattr(op, '_call_unweighted') # op shouldn't be mutated
op_eval_after = op(dom_el)
assert all_equal(op_eval_before, op_eval_after)

dom_el = noise_element(op.domain)
ran_el = noise_element(op.range)
op(dom_el)
op.adjoint(ran_el)
assert pytest.approx(op(dom_el).inner(ran_el),
dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_raise_on_return_self():
"""Check that auto_weighting raises when adjoint returns self."""
rn = odl.rn(2)

class InvalidScalingOp(odl.Operator):

def __init__(self, dom, ran, c):
super(InvalidScalingOp, self).__init__(dom, ran, linear=True)
self.c = c

def _call(self, x):
return self.c * x

@property
@auto_adjoint_weighting
def adjoint(self):
return self

# This would be a vaild situation for adjoint just returning self
op = InvalidScalingOp(rn, rn, 1.5)
with pytest.raises(TypeError):
op.adjoint


if __name__ == '__main__':
odl.util.test_file(__file__)
Loading

0 comments on commit 77e1fba

Please sign in to comment.