diff --git a/odl/test/space/space_utils_test.py b/odl/test/space/space_utils_test.py index efb9f7cc360..f69dced7e4f 100644 --- a/odl/test/space/space_utils_test.py +++ b/odl/test/space/space_utils_test.py @@ -13,7 +13,12 @@ import odl from odl import vector -from odl.util.testutils import all_equal +from odl.space.space_utils import auto_weighting +from odl.util.testutils import all_equal, simple_fixture, noise_element + + +auto_weighting_optimize = simple_fixture('optimize', [True, False]) +call_variant = simple_fixture('call_variant', ['oop', 'ip', 'dual']) def test_vector_numpy(): @@ -77,5 +82,171 @@ def test_vector_numpy(): vector([[1, 0], [0, 1]]) +def test_auto_weighting(call_variant, auto_weighting_optimize): + """Test the auto_weighting decorator for different adjoint variants.""" + rn = odl.rn(2) + rn_w = odl.rn(2, weighting=2) + + class ScalingOpBase(odl.Operator): + + def __init__(self, dom, ran, c): + super(ScalingOpBase, self).__init__(dom, ran, linear=True) + self.c = c + + if call_variant == 'oop': + + class ScalingOp(ScalingOpBase): + + def _call(self, x): + return self.c * x + + @property + @auto_weighting(optimize=auto_weighting_optimize) + def adjoint(self): + return ScalingOp(self.range, self.domain, self.c) + + elif call_variant == 'ip': + + class ScalingOp(ScalingOpBase): + + def _call(self, x, out): + out[:] = self.c * x + return out + + @property + @auto_weighting(optimize=auto_weighting_optimize) + def adjoint(self): + return ScalingOp(self.range, self.domain, self.c) + + elif call_variant == 'dual': + + class ScalingOp(ScalingOpBase): + + def _call(self, x, out=None): + if out is None: + out = self.c * x + else: + out[:] = self.c * x + return out + + @property + @auto_weighting(optimize=auto_weighting_optimize) + def adjoint(self): + return ScalingOp(self.range, self.domain, self.c) + + else: + assert False + + op1 = ScalingOp(rn, rn, 1.5) + op2 = ScalingOp(rn_w, rn_w, 1.5) + op3 = ScalingOp(rn, rn_w, 1.5) + op4 = ScalingOp(rn_w, rn, 1.5) + + for op in [op1, op2, op3, op4]: + dom_el = noise_element(op.domain) + ran_el = noise_element(op.range) + assert pytest.approx(op(dom_el).inner(ran_el), + dom_el.inner(op.adjoint(ran_el))) + + +def test_auto_weighting_noarg(): + """Test the auto_weighting decorator without the optimize argument.""" + rn = odl.rn(2) + rn_w = odl.rn(2, weighting=2) + + class ScalingOp(odl.Operator): + + def __init__(self, dom, ran, c): + super(ScalingOp, self).__init__(dom, ran, linear=True) + self.c = c + + def _call(self, x): + return self.c * x + + @property + @auto_weighting + def adjoint(self): + return ScalingOp(self.range, self.domain, self.c) + + op1 = ScalingOp(rn, rn, 1.5) + op2 = ScalingOp(rn_w, rn_w, 1.5) + op3 = ScalingOp(rn, rn_w, 1.5) + op4 = ScalingOp(rn_w, rn, 1.5) + + for op in [op1, op2, op3, op4]: + dom_el = noise_element(op.domain) + ran_el = noise_element(op.range) + assert pytest.approx(op(dom_el).inner(ran_el), + dom_el.inner(op.adjoint(ran_el))) + + +def test_auto_weighting_cached_adjoint(): + """Check if auto_weighting plays well with adjoint caching.""" + rn = odl.rn(2) + rn_w = odl.rn(2, weighting=2) + + class ScalingOp(odl.Operator): + + def __init__(self, dom, ran, c): + super(ScalingOp, self).__init__(dom, ran, linear=True) + self.c = c + self._adjoint = None + + def _call(self, x): + return self.c * x + + @property + @auto_weighting + def adjoint(self): + if self._adjoint is None: + self._adjoint = ScalingOp(self.range, self.domain, self.c) + return self._adjoint + + op = ScalingOp(rn, rn_w, 1.5) + dom_el = noise_element(op.domain) + op_eval_before = op(dom_el) + + adj = op.adjoint + adj_again = op.adjoint + assert adj_again is adj + + # Check that original op is intact + assert not hasattr(op, '_call_unweighted') # op shouldn't be mutated + op_eval_after = op(dom_el) + assert all_equal(op_eval_before, op_eval_after) + + dom_el = noise_element(op.domain) + ran_el = noise_element(op.range) + op(dom_el) + op.adjoint(ran_el) + assert pytest.approx(op(dom_el).inner(ran_el), + dom_el.inner(op.adjoint(ran_el))) + + +def test_auto_weighting_raise_on_return_self(): + """Check that auto_weighting raises when adjoint returns self.""" + rn = odl.rn(2) + + class InvalidScalingOp(odl.Operator): + + def __init__(self, dom, ran, c): + super(InvalidScalingOp, self).__init__(dom, ran, linear=True) + self.c = c + self._adjoint = None + + def _call(self, x): + return self.c * x + + @property + @auto_weighting + def adjoint(self): + return self + + # This would be a vaild situation for adjont just returning self + op = InvalidScalingOp(rn, rn, 1.5) + with pytest.raises(TypeError): + op.adjoint + + if __name__ == '__main__': odl.util.test_file(__file__)