diff --git a/odl/diagnostics/operator.py b/odl/diagnostics/operator.py index b92b07767bc..3c8095a4f6d 100644 --- a/odl/diagnostics/operator.py +++ b/odl/diagnostics/operator.py @@ -1,4 +1,6 @@ -# Copyright 2014-2017 The ODL contributors +# coding: utf-8 + +# Copyright 2014-2017 The ODL contributors # # This file is part of ODL. # @@ -9,380 +11,1012 @@ """Standardized tests for `Operator`'s.""" from __future__ import print_function, division, absolute_import -from builtins import object +from functools import partial import numpy as np +import sys from odl.diagnostics.examples import samples -from odl.operator import power_method_opnorm -from odl.util.testutils import FailCounter - - -__all__ = ('OperatorTest',) - - -class OperatorTest(object): - - """Automated tests for `Operator` implementations. - - This class allows users to automatically test various - features of an Operator such as linearity, the adjoint definition and - definition of the derivative. - """ - - def __init__(self, operator, operator_norm=None, verbose=True, tol=1e-5): - """Initialize a new instance. - - Parameters - ---------- - operator : `Operator` - The operator to run tests on - operator_norm : float, optional - The norm of the operator, used for error estimates. If - ``None`` is given, the norm is estimated during - initialization. - verbose : bool, optional - If ``True``, print additional info text. - tol : float, optional - Tolerance parameter used as a base for the actual tolerance - in the tests. Depending on the expected accuracy, the actual - tolerance used in a test can be a factor times this number. - """ - self.operator = operator - self.verbose = False - if operator_norm is None: - self.operator_norm = self.norm() - else: - self.operator_norm = float(operator_norm) +from odl.operator.operator import Operator - self.verbose = bool(verbose) - self.tol = float(tol) +__all__ = ('check_operator', 'check_operator_properties', + 'check_operator_norm', 'check_operator_linearity', + 'check_operator_adjoint', 'check_operator_derivative') - def log(self, message): - """Print message if ``self.verbose == True``.""" - if self.verbose: - print(message) +VERBOSITY_LEVELS = {'DEBUG': 0, 'INFO': 1, 'CHECK': 2, 'WARNING': 3, + 'ERROR': 4, 'SECTION': 99, 'QUIET': 100} - def norm(self): - """Estimate the operator norm of the operator. - The norm is estimated by calculating +#TODO: add `logfile` parameter - ``A(x).norm() / x.norm()`` - for some nonzero ``x`` +def _log(message, verbosity, level='DEBUG', file=sys.stderr): + """Log a message depending on verbosity.""" + if not message: + # Make sure an empty string is printed as well + message = '\n' - Returns - ------- - norm : float - Estimate of operator norm + if level == verbosity == 'QUIET': + # Special case, don't print 'QUIET :' + print(message, file=file) - References - ---------- - Wikipedia article on `Operator norm - `_. - """ - self.log('\n== Calculating operator norm ==\n') + elif VERBOSITY_LEVELS[level] >= VERBOSITY_LEVELS[verbosity]: + message_lines = str(message).splitlines() + message_lines = ['{:<7}: '.format(level) + line + for line in message_lines] + print('\n'.join(message_lines), file=file) - operator_norm = max(power_method_opnorm(self.operator, maxiter=2, - xstart=x) - for name, x in samples(self.operator.domain) - if name != 'Zero') - self.log('Norm is at least: {}'.format(operator_norm)) - self.operator_norm = operator_norm - return operator_norm +def _log_linewidth(): + """Return the available log linewidth based on ``np.get_printoptions``.""" + start_len = max(len(k) for k in VERBOSITY_LEVELS) + len(': ') + end_len = len(' NOT OK') + return np.get_printoptions()['linewidth'] - start_len - end_len - def self_adjoint(self): - """Verify `` == ``.""" - left_inner_vals = [] - right_inner_vals = [] - with FailCounter( - test_name='Verifying the identity = ', - err_msg='error = | - | / ||A|| ||x|| ||y||', - logger=self.log) as counter: +def _check_cond(cond, cond_str, logger, level_true, level_false): + """Check a condition and log, returning 1 on failure for ERROR.""" + log_linewidth = _log_linewidth() + parts = [cond_str[i * log_linewidth: (i + 1) * log_linewidth] + for i in range(int(np.ceil(len(cond_str) / log_linewidth)))] + if cond: + log_fmt = '{{:<{}}} OK'.format(log_linewidth) + parts[-1] = log_fmt.format(parts[-1]) + level = level_true + failed = 0 + else: + log_fmt = '{{:<{}}} NOT OK'.format(log_linewidth) + parts[-1] = log_fmt.format(parts[-1]) + level = level_false + failed = 1 if level_false == 'ERROR' else 0 - for [name_x, x], [name_y, y] in samples(self.operator.domain, - self.operator.range): - x_norm = x.norm() - y_norm = y.norm() + for part in parts: + logger(part, level=level) - l_inner = self.operator(x).inner(y) - r_inner = x.inner(self.operator(y)) + return failed - denom = self.operator_norm * x_norm * y_norm - error = 0 if denom == 0 else abs(l_inner - r_inner) / denom - if error > self.tol: - counter.fail('x={:25s} y={:25s} : error={:6.5f}' - ''.format(name_x, name_y, error)) +def _get_derivative(op, arg=None): + """Return a tuple ``(has_deriv, deriv)`` from ``op``. - left_inner_vals.append(l_inner) - right_inner_vals.append(r_inner) + Calls ``op.derivative(arg)`` if ``arg`` is given, otherwise uses + ``op.domain.one()`` if possible, or ``op.domain.element()``. + """ + if arg is None: + try: + arg = op.domain.one() + except (AttributeError, NotImplementedError): + arg = op.domain.element() + try: + deriv = op.derivative(arg) + has_deriv = True + except NotImplementedError: + has_deriv = False + deriv = None + return has_deriv, deriv + + +def _get_inverse(op): + """Return a tuple ``(has_inverse, inverse)`` from ``op``.""" + try: + inverse = op.inverse + has_inverse = True + except NotImplementedError: + has_inverse = False + inverse = None + return has_inverse, inverse + + +def _get_adjoint(op): + """Return a tuple ``(has_adjoint, adjoint)`` from ``op``.""" + try: + adjoint = op.adjoint + has_adjoint = True + except NotImplementedError: + has_adjoint = False + adjoint = None + return has_adjoint, adjoint + + +def _get_opnorm(op, maxiter=2): + """Return a tuple ``(norm_exact, norm_est)`` from ``op``. + + If ``op.norm(estimate=False)`` is not implemented, ``norm_exact`` is + ``None``, likewise for ``norm_est`` and ``op.norm(estimate=True)``. + The ``maxiter`` parameter is used for the norm estimate iteration. + """ + try: + norm_exact = op.norm(estimate=False) + except NotImplementedError: + norm_exact = None + try: + norm_est = op.norm(estimate=True, maxiter=maxiter) + except (TypeError, ValueError, NotImplementedError): + norm_est = None + + return norm_exact, norm_est + + +def print_inputs(args, kwargs, verbosity): + """Print all function inputs for a certain verbosity level.""" + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() + + log('', level='DEBUG') + log('Inputs', level='DEBUG') + log('-' * log_linewidth, level='DEBUG') + for arg in args: + log(repr(arg), level='DEBUG') + for key, val in kwargs.items(): + log('{} = {!r}'.format(key, val), level='DEBUG') + + +def check_operator_properties(operator, verbosity='INFO', deriv_arg=None): + """Check and return basic operator properties. + + This function checks whether ``derivative``, ``inverse`` and ``adjoint`` + are implemented. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the check. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + deriv_arg : ``operator.domain`` element-like, optional + Argument to ``operator.derivative``. For the default ``None``, + ``operator.domain.one()`` is used if possible, else an uninitialized + ``operator.domain.element()``. + + Returns + ------- + result : dict + Dictionary with the following keys: + + - ``num_failed(int)``: Number of failed checks. + - ``deriv(Operator or None)``: The derivative at ``deriv_arg`` if + implemented, else ``None``. + - ``inverse(Operator or None)``: The inverse if implemented, + else ``None``. + - ``adjoint(Operator or None)``: The adjoint if implemented, + else ``None``. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + + num_failed = 0 + + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() + + log('', level='SECTION') + log('Basic properties', level='SECTION') + log('=' * log_linewidth, level='SECTION') + + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, 'deriv_arg': deriv_arg}, + verbosity=verbosity) + + log('## Getting operator properties...', level='DEBUG') + has_deriv, deriv = _get_derivative(op, arg=deriv_arg) + has_inverse, inverse = _get_inverse(op) + has_adjoint, adjoint = _get_adjoint(op) + log('## Done.', level='DEBUG') + + # --- Default properties --- # + + log('op.domain = {!r}'.format(op.domain), level='INFO') + log('op.range = {!r}'.format(op.range), level='INFO') + log('op.is_linear is {}'.format(op.is_linear), level='INFO') + log('operator.is_functional is {}'.format(op.is_functional), + level='INFO') + + # --- Derivative --- # + + log('', level='SECTION') + log('Derivative', level='SECTION') + log('-' * log_linewidth, level='SECTION') + + if has_deriv: + log('op.derivative implemented', level='INFO') + log('op.derivative(x) = {!r}'.format(deriv), level='INFO') + log('[x = {!r}]'.format(deriv_arg), level='DEBUG') + else: + log('op.derivative NOT implemented', level='INFO') + + # --- Inverse --- # + + log('', level='SECTION') + log('Inverse', level='SECTION') + log('-' * log_linewidth, level='SECTION') + + if has_inverse: + log('op.inverse implemented', level='INFO') + log('inverse = {!r}'.format(inverse), level='INFO') + else: + log('op.inverse NOT implemented', level='INFO') + + if has_inverse: + num_failed += _check_cond( + inverse.domain == op.range, 'op.inverse.domain == op.range', + log, level_true='CHECK', level_false='ERROR') + num_failed += _check_cond( + inverse.range == op.domain, 'op.inverse.range == op.domain', + log, level_true='CHECK', level_false='ERROR') + + # --- Adjoint --- # + + log('', level='SECTION') + log('Adjoint', level='SECTION') + log('-' * log_linewidth, level='SECTION') + + if has_adjoint: + log('op.adjoint implemented', level='INFO') + log('adjoint = {!r}'.format(adjoint), level='INFO') + else: + log('op.adjoint NOT implemented', level='INFO') + + # --- Summary --- # + + if verbosity == 'QUIET': + log('properties: {} failed'.format(num_failed), level='QUIET') + else: + failed_level = 'INFO' if num_failed == 0 else 'ERROR' + log('', level=failed_level) + log('## Number of failed checks: {}'.format(num_failed), + level=failed_level) + + return dict(num_failed=num_failed, + deriv=deriv, inverse=inverse, adjoint=adjoint) + + +def check_operator_norm(operator, verbosity='INFO', tol=1e-5, + norm_kwargs=None): + """Check and return the operator norm. + + This function checks whether ``norm()`` is available, with both + ``estimate=True`` and ``estimate=False``. If both are available, it is + verified that the estimate is less than or equal to the exact norm. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the check. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + tol : float, optional + Relative tolerance for the norm comparison. + norm_kwargs : dict, optional + Keyword arguments to be used as follows:: + + operator.norm(estimate=False, **norm_kwargs) + + The default ``None`` is equivalent to ``{'maxiter': 10}``. + + Returns + ------- + result : dict + Dictionary with the following keys: + + - ``num_failed(int)``: Number of failed checks. + - ``opnorm_exact(float or None)``: The exact operator norm if + ``operator.norm(estimate=False)`` is implemented, + otherwise ``None``. + - ``opnorm_est(float or None)``: The estimate for the operator norm + using a power iteration if applicable, otherwise ``None``. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + tol = float(tol) + if norm_kwargs is None: + norm_kwargs = {'maxiter': 10} + + num_failed = 0 + + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() + + log('', level='SECTION') + log('Operator norm', level='SECTION') + log('=' * log_linewidth, level='SECTION') + + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, + 'tol': tol, + 'norm_kwargs': norm_kwargs}, + verbosity=verbosity) + + # --- Exact norm --- # + + try: + norm_exact = op.norm(estimate=False) + has_exact_norm = True + except NotImplementedError: + norm_exact = None + has_exact_norm = False + + if has_exact_norm: + log('Exact norm `op.norm(estimate=False)` implemented.', + level='INFO') + log('Exact norm: {}'.format(norm_exact), level='INFO') + else: + log('Exact norm `op.norm(estimate=False)` NOT implemented.', + level='INFO') + + # --- Norm estimate --- # + + has_adjoint, _ = _get_adjoint(op) + if has_adjoint: + log('## Computing operator norm estimate...', level='DEBUG') + norm_est = op.norm(estimate=True, **norm_kwargs) + log('## Done.', level='DEBUG') + log('Estimated norm: {}'.format(norm_est), level='INFO') + if has_exact_norm: + num_failed += _check_cond( + norm_est <= norm_exact * (1 + tol), + 'estimated norm <= exact norm', + log, level_true='CHECK', level_false='ERROR') + else: + log('Operator has no adjoint, skipping norm estimate.', level='INFO') + norm_est = None + + # --- Summary --- # + + if verbosity == 'QUIET': + log('norm: {} failed'.format(num_failed), level='QUIET') + else: + failed_level = 'INFO' if num_failed == 0 else 'ERROR' + log('', level=failed_level) + log('## Number of failed checks: {}'.format(num_failed), + level=failed_level) + + return dict(num_failed=num_failed, + norm_exact=norm_exact, norm_est=norm_est) + + +def check_operator_linearity(operator, verbosity='INFO', opnorm=None, + tol=1e-5): + """Check whether the operator really is linear. + + This function verifies additivity :: + + A(x + y) = A(x) + A(y) + + and scale invariance :: + + A(s * x) = s * A(x) + + for vectors ``x``, ``y`` and scalars ``s``. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the check. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + opnorm : float, optional + Operator norm used to scale the error in order to make it + scale-invariant. For ``None``, it is retrieved or computed on the fly. + tol : float, optional + Relative tolerance parameter for the error in the checks. + + Returns + ------- + result : dict + Dictionary with the following keys: + + - ``num_failed(int)``: Number of failed checks. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + tol = float(tol) - scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0] - self.log('\nThe adjoint seems to be scaled according to:') - self.log('(x, Ay) / (Ax, y) = {}. Should be 1.0'.format(scale)) + num_failed = 0 - def _adjoint_definition(self): - """Verify `` == ``.""" - left_inner_vals = [] - right_inner_vals = [] + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() - with FailCounter( - test_name='Verifying the identity = ', - err_msg='error = || / ||A|| ||x|| ||y||', - logger=self.log) as counter: + log('', level='SECTION') + log('Linearity', level='SECTION') + log('=' * log_linewidth, level='SECTION') - for [name_x, x], [name_y, y] in samples(self.operator.domain, - self.operator.range): - x_norm = x.norm() - y_norm = y.norm() + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, + 'opnorm': opnorm, + 'tol': tol}, + verbosity=verbosity) - l_inner = self.operator(x).inner(y) - r_inner = x.inner(self.operator.adjoint(y)) + if opnorm is None: + try: + opnorm = op.norm(estimate=False) + except NotImplementedError: + try: + opnorm = op.norm(estimate=True, maxiter=2) + except (TypeError, ValueError, NotImplementedError): + pass + + if opnorm is None: + log('unable to get or compute operator norm, using opnorm=1.0', + level='WARNING') + opnorm = 1.0 + elif opnorm == 0: + log('opnorm = 0 given, using 1.0 instead', level='WARNING') + opnorm = 1.0 + + # --- Scale invariance --- # + + log('', level='SECTION') + log('Scale invariance', level='SECTION') + log('-' * log_linewidth, level='SECTION') + log('err = ||op(s*x) - s * op(x)|| / (|s| * ||op|| * ||x||)', + level='INFO') + log('-' * log_linewidth, level='INFO') + for (name_x, x), (_, s) in samples(op.domain, op.domain.field): + s_op_x = s * op(x) + op_s_x = op(s * x) + + denom = abs(s) * opnorm * x.norm() + if denom == 0: + denom = 1.0 + + err = (op_s_x - s_op_x).norm() / denom + num_failed += _check_cond( + err <= tol, + 'x={:<20} s={:< 7.2f} err={:.1}'.format(name_x, s, err), + log, level_true='CHECK', level_false='ERROR') + + # Compute only if necessary + op_s_x_norm = op_s_x.norm() if verbosity == 'DEBUG' else 1.0 + s_op_x_norm = s_op_x.norm() if verbosity == 'DEBUG' else 1.0 + log('||op(s*x)||={:.3}'.format(op_s_x_norm), level='DEBUG') + log('||s * op(x)||={:.3}'.format(s_op_x_norm), level='DEBUG') + log('|s|*||op||*||x||={:.3}'.format(denom), level='DEBUG') + + log('-' * log_linewidth, level='INFO') + + # --- Additivity --- # + + log('', level='SECTION') + log('Additivity', level='SECTION') + log('-' * log_linewidth, level='SECTION') + log('err = ||op(x + y) - op(x) - op(y)|| / (||op|| * (||x|| + ||y||))', + level='INFO') + log('-' * log_linewidth, level='INFO') + + for (name_x, x), (name_y, y) in samples(op.domain, op.domain): + op_x = op(x) + op_y = op(y) + op_x_y = op(x + y) + + denom = opnorm * (x.norm() + y.norm()) + if denom == 0: + denom = 1.0 + + err = (op_x_y - op_x - op_y).norm() / denom + num_failed += _check_cond( + err <= tol, + 'x={:<20} y={:<20} err={:.1}'.format(name_x, name_y, err), + log, level_true='CHECK', level_false='ERROR') + + # Compute only if necessary + op_x_y_norm = op_x_y.norm() if verbosity == 'DEBUG' else 1.0 + op_x_op_y_norm = (op_x + op_y).norm() if verbosity == 'DEBUG' else 1.0 + log('||op(x + y)||={:.3}'.format(op_x_y_norm), level='DEBUG') + log('||op(x) + op(y)||={:.3}'.format(op_x_op_y_norm), level='DEBUG') + log('||op||*(||x||+||y||)={:.3}'.format(denom), level='DEBUG') + + log('-' * log_linewidth, level='INFO') + + # --- Summary --- # + + if verbosity == 'QUIET': + log('linearity: {} failed'.format(num_failed), level='QUIET') + else: + failed_level = 'INFO' if num_failed == 0 else 'ERROR' + log('', level=failed_level) + log('## Number of failed checks: {}'.format(num_failed), + level=failed_level) + + return dict(num_failed=num_failed) + + +def check_operator_adjoint(operator, verbosity='INFO', opnorm=None, tol=1e-5): + """Check whether the adjoint satisfies its mathematical properties. + + This function verifies the adjointness property :: + + _Y = _X + + and whether the adjoint of the adjoint is equivalent to the original + operator. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the check. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + opnorm : float, optional + Operator norm used to scale the error in order to make it + scale-invariant. For ``None``, it is retrieved or computed on the fly. + tol : float, optional + Relative tolerance parameter for the error in the checks. + + Returns + ------- + result : dict + Dictionary with the following keys: + + - ``num_failed(int)``: Number of failed checks. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + tol = float(tol) - denom = self.operator_norm * x_norm * y_norm - error = 0 if denom == 0 else abs(l_inner - r_inner) / denom + num_failed = 0 - if error > self.tol: - counter.fail('x={:25s} y={:25s} : error={:6.5f}' - ''.format(name_x, name_y, error)) + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() - left_inner_vals.append(l_inner) - right_inner_vals.append(r_inner) + log('', level='SECTION') + log('Adjoint', level='SECTION') + log('=' * log_linewidth, level='SECTION') - scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0] - self.log('\nThe adjoint seems to be scaled according to:') - self.log('(x, A^T y) / (Ax, y) = {}. Should be 1.0'.format(scale)) + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, + 'opnorm': opnorm, + 'tol': tol}, + verbosity=verbosity) - def _adjoint_of_adjoint(self): - """Verify ``(A^*)^* == A``""" + if opnorm is None: try: - self.operator.adjoint.adjoint - except AttributeError: - print('A^* has no adjoint') - return - - if self.operator.adjoint.adjoint is self.operator: - self.log('(A^*)^* == A') - return - - with FailCounter( - test_name='\nVerifying the identity Ax = (A^*)^* x', - err_msg='error = ||Ax - (A^*)^* x|| / ||A|| ||x||', - logger=self.log) as counter: - for [name_x, x] in self.operator.domain.examples: - opx = self.operator(x) - op_adj_adj_x = self.operator.adjoint.adjoint(x) - - denom = self.operator_norm * x.norm() - if denom == 0: - error = 0 - else: - error = (opx - op_adj_adj_x).norm() / denom - - if error > self.tol: - counter.fail('x={:25s} : error={:6.5f}' - ''.format(name_x, error)) - - def adjoint(self): - """Verify that `Operator.adjoint` works appropriately. - - References - ---------- - Wikipedia article on `Adjoint - `_. - """ + opnorm = op.norm(estimate=False) + except NotImplementedError: + try: + opnorm = op.norm(estimate=True, maxiter=2) + except (TypeError, ValueError, NotImplementedError): + pass + + if opnorm is None: + log('unable to get or compute operator norm, using opnorm=1.0', + level='WARNING') + opnorm = 1.0 + elif opnorm == 0: + log('opnorm = 0 given, using 1.0 instead', level='WARNING') + opnorm = 1.0 + + has_adjoint, adjoint = _get_adjoint(op) + if not has_adjoint: + log('Operator adjoint not implemented, skipping checks', level='INFO') + return dict(num_failed=num_failed) + + # --- Basic properties --- # + + num_failed += _check_cond( + adjoint.is_linear, 'op.adjoint.is_linear', + log, level_true='CHECK', level_false='WARNING') + num_failed += _check_cond( + adjoint.domain == op.range, 'op.adjoint.domain == op.range', + log, level_true='CHECK', level_false='ERROR') + num_failed += _check_cond( + adjoint.range == op.domain, 'op.adjoint.range == op.domain', + log, level_true='CHECK', level_false='ERROR') + + # --- Adjoint definition --- # + + log('', level='SECTION') + log('Adjoint definition', level='SECTION') + log('-' * log_linewidth, level='SECTION') + log('err = | - | / (||op|| * ||x|| * ||y||)', + level='INFO') + log('-' * log_linewidth, level='INFO') + inner1_vals = [] + inner2_vals = [] + num_failed_def = 0 + for (name_x, x), (name_y, y) in samples(op.domain, op.domain): + inner1 = op(x).inner(y) + inner2 = x.inner(adjoint(y)) + inner1_vals.append(inner1) + inner2_vals.append(inner2) + + denom = opnorm * x.norm() * y.norm() + if denom == 0: + denom = 1.0 + + err = abs(inner1 - inner2) / denom + num_failed_def += _check_cond( + err <= tol, + 'x={:<20} y={:<20} err={:.1}'.format(name_x, name_y, err), + log, level_true='CHECK', level_false='ERROR') + + log('={:.3}'.format(inner1), level='DEBUG') + log('={:.3}'.format(inner2), level='DEBUG') + log('||op||*||x||*||y||={:.3}'.format(denom), level='DEBUG') + + prop_level = 'DEBUG' if num_failed_def == 0 else 'ERROR' + factor = np.polyfit(inner1_vals, inner2_vals, deg=1)[0] + log('', level=prop_level) + log('Proportionality constant: ~ factor * ', + level=prop_level) + log('with factor = {:.3}'.format(factor), level=prop_level) + + num_failed += num_failed_def + log('-' * log_linewidth, level='INFO') + + # --- Adjoint of adjoint --- # + + log('', level='SECTION') + log('Adjoint of adjoint', level='SECTION') + log('-' * log_linewidth, level='SECTION') + log('err = ||op(x) - adj.adjoint(x)|| / (||op|| * ||x||)', + level='INFO') + log('-' * log_linewidth, level='INFO') + + for (name_x, x) in samples(op.domain): + op_x = op(x) + adj_adj_x = adjoint.adjoint(x) + + denom = opnorm * x.norm() + if denom == 0: + denom = 1.0 + + err = (op_x - adj_adj_x).norm() / denom + num_failed += _check_cond( + err <= tol, + 'x={:<20} err={:.1}'.format(name_x, err), + log, level_true='CHECK', level_false='ERROR') + + # Compute only if necessary + op_x_norm = op_x.norm() if verbosity == 'DEBUG' else 1.0 + adj_adj_x_norm = adj_adj_x.norm() if verbosity == 'DEBUG' else 1.0 + log('||op(x)||={:.3}'.format(op_x_norm), level='DEBUG') + log('||adj.adjoint(x)||={:.3}'.format(adj_adj_x_norm), level='DEBUG') + log('||op||*||x||={:.3}'.format(denom), level='DEBUG') + + log('-' * log_linewidth, level='INFO') + + # --- Summary --- # + + if verbosity == 'QUIET': + log('adjoint: {} failed'.format(num_failed), level='QUIET') + else: + failed_level = 'INFO' if num_failed == 0 else 'ERROR' + log('', level=failed_level) + log('## Number of failed checks: {}'.format(num_failed), + level=failed_level) + + return dict(num_failed=num_failed) + + +def check_operator_derivative(operator, verbosity='INFO', tol=1e-4): + """Check whether the derivative satisfies its mathematical properties. + + This function verifies that the ``derivative`` can be approximated + by finite differences in chosen directions (Gâteaux derivative) :: + + A'(x)(v) ~ [A(x + tv) - A(x)] / t for t --> 0. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the check. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + tol : float, optional + Relative tolerance parameter for the error in the checks. Since + derivative checking is prone to numerical instability, this tolerance + needs to be larger than in other, more stable checks. + + Returns + ------- + result : dict + Dictionary with the following keys: + + - ``num_failed(int)``: Number of failed checks. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + tol = float(tol) + + num_failed = 0 + + log = partial(_log, verbosity=verbosity) + log_linewidth = _log_linewidth() + + log('', level='SECTION') + log('Derivative', level='SECTION') + log('=' * log_linewidth, level='SECTION') + + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, + 'tol': tol}, + verbosity=verbosity) + + has_deriv, deriv = _get_derivative(op) + if not has_deriv: + log('Operator derivative not implemented, skipping checks', + level='INFO') + return dict(num_failed=num_failed) + + # --- Basic properties --- # + + num_failed += _check_cond( + deriv.is_linear, 'op.derivative(x).is_linear', + log, level_true='CHECK', level_false='WARNING') + num_failed += _check_cond( + deriv.domain == op.domain, 'op.derivative(x).domain == op.domain', + log, level_true='CHECK', level_false='ERROR') + num_failed += _check_cond( + deriv.range == op.range, 'op.derivative(x).range == op.range', + log, level_true='CHECK', level_false='ERROR') + + if op.is_linear: + num_failed += _check_cond( + deriv is op, 'op.is_linear and op is op.derivative(x)', + log, level_true='CHECK', level_false='WARNING') + + # --- Directional (Gâteaux) derivative --- # + + log('', level='SECTION') + log('Directional derivative', level='SECTION') + log('-' * log_linewidth, level='SECTION') + log('err = ||(op(x + c*dx) - op(x)) / c - deriv(x)(dx)|| / ', level='INFO') + log(' (||deriv(x)|| * ||dx||)', level='INFO') + log('-' * log_linewidth, level='INFO') + + for (name_x, x), (name_dx, dx) in samples(op.domain, op.domain): + # Precompute some values + deriv_x = op.derivative(x) + deriv_x_dx = deriv_x(dx) + + num_failed += _check_cond( + deriv_x.is_linear, 'op.derivative(x).is_linear', + log, level_true='DEBUG', level_false='WARNING') + num_failed += _check_cond( + deriv_x.domain == op.domain, + 'op.derivative(x).domain == op.domain', + log, level_true='DEBUG', level_false='ERROR') + num_failed += _check_cond( + deriv_x.range == op.range, 'op.derivative(x).range == op.range', + log, level_true='DEBUG', level_false='ERROR') + + op_x = op(x) try: - self.operator.adjoint + deriv_x_norm = deriv_x.norm(estimate=False) except NotImplementedError: - print('Operator has no adjoint') - return + try: + deriv_x_norm = deriv_x.norm(estimate=True, maxiter=2) + except (ValueError, TypeError, NotImplementedError): + deriv_x_norm = 1.0 + + denom = deriv_x_norm * dx.norm() + if denom == 0: + denom = 1.0 + + # Compute finite difference with decreasing step size, where the + # range depends on the data type precision + + # Start value; float32: c = 1e-3, float64: c = 1e-5 + c = np.cbrt(np.finfo(op.domain.dtype).resolution) + deriv_ok = False + + cs = [] + diff_norms = [] + errs = [] + while c >= 10 * np.finfo(op.domain.dtype).resolution: + finite_diff = (op(x + c * dx) - op_x) / c + err = (finite_diff - deriv_x_dx).norm() / denom + + cs.append(c) + errs.append(err) + # Compute only if needed + if verbosity == 'DEBUG': + diff_norms.append(finite_diff.norm()) + + if err < tol: + deriv_ok = True + break + + c /= 10.0 + + num_failed += _check_cond( + deriv_ok, 'x={:<20} dx={:<20} minerr={:.1}' + ''.format(name_x, name_dx, min(errs)), + log, level_true='CHECK', level_false='ERROR') + + if verbosity == 'DEBUG': + deriv_x_dx_norm = deriv_x_dx.norm() + + log('||deriv(x)||*||dx||={:.3}'.format(denom), level='DEBUG') + for c, err, diff_norm in zip(cs, errs, diff_norms): + log('c={:.1} err={:.3}'.format(c, err), level='DEBUG') + log('||(op(x + c*dx) - op(x)) / c||={:.3}'.format(diff_norm), + level='DEBUG') + log('||deriv(x)(dx)||={:.3}'.format(deriv_x_dx_norm), + level='DEBUG') + + log('-' * log_linewidth, level='INFO') + + # --- Summary --- # + + if verbosity == 'QUIET': + log('derivative: {} failed'.format(num_failed), level='QUIET') + else: + failed_level = 'INFO' if num_failed == 0 else 'ERROR' + log('', level=failed_level) + log('## Number of failed checks: {}'.format(num_failed), + level=failed_level) + + return dict(num_failed=num_failed) + + +def check_operator(operator, verbosity='INFO', checks=None, tol=1e-5, + deriv_arg=None, norm_kwargs=None): + """Run a set of standard tests on the provided operator. + + Parameters + ---------- + operator : `Operator` + The operator on which to run the checks. + verbosity : str, optional + Level of output verbosity. Possible values and corresponding print + outputs are: + + - ``'DEBUG'``: Everything + - ``'INFO'``: Informational context, warnings and errors + - ``'WARNING'``: Warnings and errors + - ``'ERROR'``: Errors + - ``'QUIET'``: Only a summary at the end + + checks : sequence of str, optional + Checks that should be run. Available checks are: + + - ``'properties'``: Basic checks for domain, range etc., see + `check_operator_properties` + - ``'norm'``: Check exact vs. estimated operator norm if available, + see `check_operator_norm` + - ``'linearity'``: Test for scale invariance and additivity, see + `check_operator_linearity` + - ``'adjoint'``: Check adjointness properties, see + `check_operator_adjoint` + - ``'derivative'``: Verify the directional derivative using finite + differences (note that this may be subject to numerical instability), + see `check_operator_derivative` + + For the default ``None``, the first 4 checks are run if + ``operator.is_linear``, otherwise the first and the last. + + tol : float, optional + Tolerance parameter used as a base for the actual tolerance + in the tests. Depending on the expected accuracy, the actual + tolerance used in a test can be a factor times this number. + deriv_arg : ``operator.domain`` element-like, optional + Argument to ``operator.derivative`` for checking its presence. For + the default ``None``, ``operator.domain.one()`` is used if possible, + else an uninitialized ``operator.domain.element()``. + norm_kwargs : dict, optional + Keyword arguments to be used as follows:: + + operator.norm(estimate=False, **norm_kwargs) + + The default ``None`` is equivalent to ``{'maxiter': 10}``. + """ + assert isinstance(operator, Operator), 'bad type {}'.format(type(operator)) + op = operator + verbosity, verb_in = str(verbosity).upper(), verbosity + assert verbosity in VERBOSITY_LEVELS, 'bad verbosity {!r}'.format(verb_in) + all_checks = {'properties', 'norm', 'linearity', 'adjoint', 'derivative'} + if checks is None: + if op.is_linear: + checks = ('properties', 'norm', 'linearity', 'adjoint') + else: + checks = ('properties', 'derivative') + checks, chk_in = tuple(str(c).lower() for c in checks), checks + assert set(checks).issubset(all_checks), 'invalid checks {}'.format(chk_in) + tol = float(tol) + if norm_kwargs is None: + norm_kwargs = {'maxiter': 10} - self.log('\n== Verifying operator adjoint ==\n') + log = partial(_log, verbosity=verbosity) - domain_range_ok = True - if self.operator.domain != self.operator.adjoint.range: - print('*** ERROR: A.domain != A.adjoint.range ***') - domain_range_ok = False + log('', level='SECTION') + log('Operator check', level='SECTION') + log('==============', level='SECTION') + log('==============', level='SECTION') - if self.operator.range != self.operator.adjoint.domain: - print('*** ERROR: A.range != A.adjoint.domain ***') - domain_range_ok = False + print_inputs( + args=[op], + kwargs={'verbosity': verbosity, + 'checks': checks, + 'tol': tol, + 'deriv_arg': deriv_arg, + 'norm_kwargs': norm_kwargs}, + verbosity=verbosity) - if domain_range_ok: - self.log('Domain and range of adjoint are OK.') - else: - print('Domain and range of adjoint are not OK, exiting.') - return - - self._adjoint_definition() - self._adjoint_of_adjoint() - - def _derivative_convergence(self): - """Verify that the derivative is a first-order approximation. - - The code verifies if - - ``||A(x+c*p) - A(x) - A'(x)(c*p)|| / c = o(c)`` - - for ``c --> 0``. - """ - with FailCounter( - test_name='Verifying that derivative is a first-order ' - 'approximation', - err_msg="error = inf_c ||A(x+c*p)-A(x)-A'(x)(c*p)|| / c", - logger=self.log) as counter: - for [name_x, x], [name_dx, dx] in samples(self.operator.domain, - self.operator.domain): - # Precompute some values - deriv = self.operator.derivative(x) - derivdx = deriv(dx) - opx = self.operator(x) - - c = 1e-4 # initial step - derivative_ok = False - - minerror = float('inf') - while c > 1e-14: - exact_step = self.operator(x + dx * c) - opx - expected_step = c * derivdx - err = (exact_step - expected_step).norm() / c - - # Need to be slightly more generous here due to possible - # numerical instabilities. - # TODO: perform more tests to find a good threshold here. - if err < 10 * self.tol: - derivative_ok = True - break - else: - minerror = min(minerror, err) - - c /= 10.0 - - if not derivative_ok: - counter.fail('x={:15s} p={:15s}, error={}' - ''.format(name_x, name_dx, minerror)) - - def derivative(self): - """Verify that `Operator.derivative` works appropriately. - - The code verifies if - - ``||A(x+c*p) - A(x) - A'(x)(c*p)|| / c = o(c)`` - - for ``c --> 0`` using a selection of elements ``x`` and ``p``. - - References - ---------- - Wikipedia article on `Derivative - `_. - Wikipedia article on `Frechet derivative - `_. - """ - self.log('\n== Verifying operator derivative ==') + if 'properties' in checks: + res_props = check_operator_properties(op, verbosity, deriv_arg) - try: - deriv = self.operator.derivative(self.operator.domain.zero()) + if 'norm' in checks: + res_norm = check_operator_norm(op, verbosity, tol, norm_kwargs) + norm_exact = res_norm['norm_exact'] + norm_est = res_norm['norm_est'] + opnorm = norm_exact if norm_exact is not None else norm_est + else: + opnorm = None - if not deriv.is_linear: - print('Derivative is not a linear operator') - return - except NotImplementedError: - print('Operator has no derivative') - return - - if self.operator.is_linear and deriv is self.operator: - self.log('A is linear and A.derivative is A') - return - - self._derivative_convergence() - - def _scale_invariance(self): - """Verify ``A(c*x) = c * A(x)``.""" - with FailCounter( - test_name='Verifying homogeneity under scalar multiplication', - err_msg='error = ||A(c*x)-c*A(x)|| / |c| ||A|| ||x||', - logger=self.log) as counter: - for [name_x, x], [_, scale] in samples(self.operator.domain, - self.operator.domain.field): - opx = self.operator(x) - scaled_opx = self.operator(scale * x) - - denom = self.operator_norm * scale * x.norm() - error = (0 if denom == 0 - else (scaled_opx - opx * scale).norm() / denom) - - if error > self.tol: - counter.fail('x={:25s} scale={:7.2f} error={:6.5f}' - ''.format(name_x, scale, error)) - - def _addition_invariance(self): - """Verify ``A(x+y) = A(x) + A(y)``.""" - with FailCounter( - test_name='Verifying distributivity under vector addition', - err_msg='error = ||A(x+y) - A(x) - A(y)|| / ' - '||A||(||x|| + ||y||)', - logger=self.log) as counter: - for [name_x, x], [name_y, y] in samples(self.operator.domain, - self.operator.domain): - opx = self.operator(x) - opy = self.operator(y) - opxy = self.operator(x + y) - - denom = self.operator_norm * (x.norm() + y.norm()) - error = (0 if denom == 0 - else (opxy - opx - opy).norm() / denom) - - if error > self.tol: - counter.fail('x={:25s} y={:25s} error={:6.5f}' - ''.format(name_x, name_y, error)) - - def linear(self): - """Verify that the operator is actually linear.""" - if not self.operator.is_linear: - print('Operator is not linear') - return - - self.log('\n== Verifying operator linearity ==\n') - - # Test if zero gives zero - result = self.operator(self.operator.domain.zero()) - result_norm = result.norm() - if result_norm != 0.0: - print("||A(0)||={:6.5f}. Should be 0.0000".format(result_norm)) - - self._scale_invariance() - self._addition_invariance() - - def run_tests(self): - """Run all tests on this operator.""" - print('\n== RUNNING ALL TESTS ==') - print('Operator = {}'.format(self.operator)) - - self.norm() - - if self.operator.is_linear: - self.linear() - self.adjoint() - else: - self.derivative() + if 'linearity' in checks: + res_lin = check_operator_linearity(op, verbosity, opnorm, tol) + + if 'adjoint' in checks: + res_adj = check_operator_adjoint(op, verbosity, opnorm, tol) - def __str__(self): - return '{}({})'.format(self.__class__.__name__, self.operator) + if 'derivative' in checks: + res_deriv = check_operator_derivative(op, verbosity, 10 * tol) - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self.operator) + # TODO: do stuff with results if __name__ == '__main__': import odl space = odl.uniform_discr([0, 0], [1, 1], [3, 3]) # Linear operator - I = odl.IdentityOperator(space) - OperatorTest(I, verbose=False).run_tests() + op = odl.ScalingOperator(space, 2.0) + check_operator(op, verbosity='QUIET') # Nonlinear operator op(x) = x**4 op = odl.PowerOperator(space, 4) - OperatorTest(op).run_tests() + check_operator(op, verbosity='QUIET') diff --git a/odl/space/base_tensors.py b/odl/space/base_tensors.py index bc96b392347..512a1f04ee4 100644 --- a/odl/space/base_tensors.py +++ b/odl/space/base_tensors.py @@ -412,23 +412,39 @@ def examples(self): rand_state = np.random.get_state() np.random.seed(1337) + example_yielded = False + + try: + yield ('zero', self.zero()) + example_yielded = True + except NotImplementedError: + pass + try: + yield ('one', self.one()) + example_yielded = True + except NotImplementedError: + pass + if is_numeric_dtype(self.dtype): - yield ('Linearly spaced samples', self.element( + yield ('linspace(0, 1)', self.element( np.linspace(0, 1, self.size).reshape(self.shape))) - yield ('Normally distributed noise', + yield ('rand_norm(0, 1)', self.element(np.random.standard_normal(self.shape))) + example_yielded = True if self.is_real: - yield ('Uniformly distributed noise', + yield ('rand_uni(0, 1)', self.element(np.random.uniform(size=self.shape))) + example_yielded = True elif self.is_complex: - yield ('Uniformly distributed noise', + yield ('rand_uni(0+0j, 1+1j)', self.element(np.random.uniform(size=self.shape) + np.random.uniform(size=self.shape) * 1j)) - else: - # TODO: return something that always works, like zeros or ones? - raise NotImplementedError('no examples available for non-numeric' - 'data type') + example_yielded = True + + if not example_yielded: + raise NotImplementedError('no examples available for space {!r}' + ''.format(self)) np.random.set_state(rand_state) diff --git a/odl/util/testutils.py b/odl/util/testutils.py index 8a455ba4ac1..29f8bdabbe6 100644 --- a/odl/util/testutils.py +++ b/odl/util/testutils.py @@ -447,7 +447,7 @@ def noise_elements(space, n=1): class FailCounter(object): - """Used to count the number of failures of something + """Context manager used to count the number of failures of something. Usage:: @@ -461,12 +461,16 @@ class FailCounter(object): ``*** FAILED 1 TEST CASE(S) ***`` """ - def __init__(self, test_name, err_msg=None, logger=print): + def __init__(self, test_name, err_msg=None, logger=None): self.num_failed = 0 self.test_name = test_name self.err_msg = err_msg self.fail_strings = [] - self.log = logger + + def default_logger(msg, level='INFO'): + print(msg, file=sys.stderr) + + self.log = default_logger if logger is None else logger def __enter__(self): return self @@ -480,18 +484,17 @@ def fail(self, string=None): self.fail_strings += [str(string)] def __exit__(self, type, value, traceback): + self.log(self.test_name, level='INFO') if self.num_failed == 0: - self.log('{:<70}: Completed all test cases.' - ''.format(self.test_name)) + self.log('All test cases passed.', level='INFO') else: - print(self.test_name) - for fail_string in self.fail_strings: - print(fail_string) + self.log(fail_string, level='ERROR') if self.err_msg is not None: - print(self.err_msg) - print('*** FAILED {} TEST CASE(S) ***'.format(self.num_failed)) + self.log(self.err_msg, level='ERROR') + self.log('*** FAILED {} TEST CASE(S) ***'.format(self.num_failed), + level='ERROR') class Timer(object): diff --git a/odl/util/utility.py b/odl/util/utility.py index 2e609927068..8e8823034be 100644 --- a/odl/util/utility.py +++ b/odl/util/utility.py @@ -18,13 +18,12 @@ import sys -__all__ = ('array_str', 'dtype_str', 'dtype_repr', 'npy_printoptions', - 'signature_string', 'indent', - 'is_numeric_dtype', 'is_int_dtype', 'is_floating_dtype', - 'is_real_dtype', 'is_real_floating_dtype', - 'is_complex_floating_dtype', 'real_dtype', 'complex_dtype', - 'is_string', 'nd_iterator', 'conj_exponent', 'writable_array', - 'run_from_ipython', 'NumpyRandomSeed', 'cache_arguments', 'unique') +__all__ = ('NumpyRandomSeed', 'array_str', 'cache_arguments', 'complex_dtype', + 'conj_exponent', 'dtype_repr', 'dtype_str', 'indent', + 'is_complex_floating_dtype', 'is_floating_dtype', 'is_int_dtype', + 'is_numeric_dtype', 'is_real_dtype', 'is_real_floating_dtype', + 'is_string', 'nd_iterator', 'npy_printoptions', 'real_dtype', + 'run_from_ipython', 'signature_string', 'unique', 'writable_array') TYPE_MAP_R2C = {np.dtype(dtype): np.result_type(dtype, 1j)