Skip to content

Commit

Permalink
TST: add tests for auto_weighting decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Oct 5, 2017
1 parent d1e4462 commit c211b66
Showing 1 changed file with 172 additions and 1 deletion.
173 changes: 172 additions & 1 deletion odl/test/space/space_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

import odl
from odl import vector
from odl.util.testutils import all_equal
from odl.space.space_utils import auto_weighting
from odl.util.testutils import all_equal, simple_fixture, noise_element


# pytest fixtures built via ODL's `simple_fixture` helper: the first argument
# is the fixture name, the second the list of parametrization values.  Test
# functions below receive these as arguments of the same name.
auto_weighting_optimize = simple_fixture('optimize', [True, False])
call_variant = simple_fixture('call_variant', ['oop', 'ip', 'dual'])


def test_vector_numpy():
Expand Down Expand Up @@ -77,5 +82,171 @@ def test_vector_numpy():
vector([[1, 0], [0, 1]])


def test_auto_weighting(call_variant, auto_weighting_optimize):
    """Test the auto_weighting decorator for different adjoint variants.

    The operator ``x -> c * x`` is defined between differently weighted
    spaces, once with an out-of-place ``_call``, once in-place, and once
    with the dual-use signature.  In every combination the decorated
    ``adjoint`` must satisfy the adjointness condition
    ``<A x, y> == <x, A^* y>`` with respect to the (weighted) inner
    products of domain and range.
    """
    rn = odl.rn(2)
    rn_w = odl.rn(2, weighting=2)

    class ScalingOpBase(odl.Operator):

        def __init__(self, dom, ran, c):
            super(ScalingOpBase, self).__init__(dom, ran, linear=True)
            self.c = c

    if call_variant == 'oop':

        class ScalingOp(ScalingOpBase):

            def _call(self, x):
                return self.c * x

            @property
            @auto_weighting(optimize=auto_weighting_optimize)
            def adjoint(self):
                return ScalingOp(self.range, self.domain, self.c)

    elif call_variant == 'ip':

        class ScalingOp(ScalingOpBase):

            def _call(self, x, out):
                out[:] = self.c * x
                return out

            @property
            @auto_weighting(optimize=auto_weighting_optimize)
            def adjoint(self):
                return ScalingOp(self.range, self.domain, self.c)

    elif call_variant == 'dual':

        class ScalingOp(ScalingOpBase):

            def _call(self, x, out=None):
                if out is None:
                    out = self.c * x
                else:
                    out[:] = self.c * x
                return out

            @property
            @auto_weighting(optimize=auto_weighting_optimize)
            def adjoint(self):
                return ScalingOp(self.range, self.domain, self.c)

    else:
        assert False, 'unknown call_variant {!r}'.format(call_variant)

    # All four weighting combinations of domain and range
    op1 = ScalingOp(rn, rn, 1.5)
    op2 = ScalingOp(rn_w, rn_w, 1.5)
    op3 = ScalingOp(rn, rn_w, 1.5)
    op4 = ScalingOp(rn_w, rn, 1.5)

    for op in [op1, op2, op3, op4]:
        dom_el = noise_element(op.domain)
        ran_el = noise_element(op.range)
        # NB: the comparison must be written `a == pytest.approx(b)`.
        # The former `assert pytest.approx(a, b)` was always truthy (an
        # approx object is truthy, and `b` was consumed as the `rel`
        # tolerance), so the assertion could never fail.
        assert op(dom_el).inner(ran_el) == pytest.approx(
            dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_noarg():
    """Test the auto_weighting decorator without the optimize argument.

    When used bare (``@auto_weighting`` instead of
    ``@auto_weighting(optimize=...)``), the decorated ``adjoint`` must
    still satisfy ``<A x, y> == <x, A^* y>`` for all weighting
    combinations of domain and range.
    """
    rn = odl.rn(2)
    rn_w = odl.rn(2, weighting=2)

    class ScalingOp(odl.Operator):

        def __init__(self, dom, ran, c):
            super(ScalingOp, self).__init__(dom, ran, linear=True)
            self.c = c

        def _call(self, x):
            return self.c * x

        @property
        @auto_weighting
        def adjoint(self):
            return ScalingOp(self.range, self.domain, self.c)

    # All four weighting combinations of domain and range
    op1 = ScalingOp(rn, rn, 1.5)
    op2 = ScalingOp(rn_w, rn_w, 1.5)
    op3 = ScalingOp(rn, rn_w, 1.5)
    op4 = ScalingOp(rn_w, rn, 1.5)

    for op in [op1, op2, op3, op4]:
        dom_el = noise_element(op.domain)
        ran_el = noise_element(op.range)
        # NB: must be `a == pytest.approx(b)`; the former
        # `assert pytest.approx(a, b)` was always truthy and could
        # never fail.
        assert op(dom_el).inner(ran_el) == pytest.approx(
            dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_cached_adjoint():
    """Check if auto_weighting plays well with adjoint caching.

    The operator caches its adjoint in ``self._adjoint``.  Repeated
    accesses must return the same object, the original operator must
    remain unmutated, and the adjointness condition must still hold.
    """
    rn = odl.rn(2)
    rn_w = odl.rn(2, weighting=2)

    class ScalingOp(odl.Operator):

        def __init__(self, dom, ran, c):
            super(ScalingOp, self).__init__(dom, ran, linear=True)
            self.c = c
            self._adjoint = None

        def _call(self, x):
            return self.c * x

        @property
        @auto_weighting
        def adjoint(self):
            if self._adjoint is None:
                self._adjoint = ScalingOp(self.range, self.domain, self.c)
            return self._adjoint

    op = ScalingOp(rn, rn_w, 1.5)
    dom_el = noise_element(op.domain)
    op_eval_before = op(dom_el)

    # Caching: repeated property access must yield the identical object
    adj = op.adjoint
    adj_again = op.adjoint
    assert adj_again is adj

    # Check that original op is intact
    assert not hasattr(op, '_call_unweighted')  # op shouldn't be mutated
    op_eval_after = op(dom_el)
    assert all_equal(op_eval_before, op_eval_after)

    dom_el = noise_element(op.domain)
    ran_el = noise_element(op.range)
    op(dom_el)
    op.adjoint(ran_el)
    # NB: must be `a == pytest.approx(b)`; the former
    # `assert pytest.approx(a, b)` was always truthy and could never fail.
    assert op(dom_el).inner(ran_el) == pytest.approx(
        dom_el.inner(op.adjoint(ran_el)))


def test_auto_weighting_raise_on_return_self():
    """Check that auto_weighting raises when adjoint returns self."""
    rn = odl.rn(2)

    class InvalidScalingOp(odl.Operator):

        def __init__(self, dom, ran, c):
            super(InvalidScalingOp, self).__init__(dom, ran, linear=True)
            self.c = c
            self._adjoint = None

        def _call(self, x):
            return self.c * x

        @property
        @auto_weighting
        def adjoint(self):
            # Returning the operator itself is not supported by the
            # decorator, even though for an unweighted symmetric scaling
            # it would be a valid adjoint.
            return self

    bad_op = InvalidScalingOp(rn, rn, 1.5)
    with pytest.raises(TypeError):
        bad_op.adjoint


if __name__ == '__main__':
    # Run this module's tests through ODL's pytest wrapper when the file
    # is executed directly.
    odl.util.test_file(__file__)

0 comments on commit c211b66

Please sign in to comment.