diff --git a/CHANGELOG.md b/CHANGELOG.md index 7310eca6..8b3a6818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * `Transformation` * `Shear` * `Convolve` + * `InterpolatedImage` and `Interpolant` * Added implementation of fundamental operations: * `drawImage` * `drawReal` @@ -24,10 +25,5 @@ * Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects * Caveats - * Currently the FFT convolution does not perform kwrapping of hermitian images, - so it will lead to erroneous results on underesolved images that need k-space wrapping. - Wrapping for real images is implemented. K-space images arise from doing convolutions - via FFTs and so one would expect that underresolved images with convolutions may not be - rendered as accurately. * Real space convolution and photon shooting methods are not yet implemented in drawImage. diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 2c360fea..775900a2 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -43,16 +43,14 @@ from .gaussian import Gaussian from .box import Box, Pixel from .gsobject import GSObject - -# Interpolation from .moffat import Moffat - from .sum import Add, Sum from .transform import Transform, Transformation from .convolve import Convolve, Convolution, Deconvolution, Deconvolve # WCS from .wcs import ( + BaseWCS, AffineTransform, JacobianWCS, OffsetWCS, @@ -77,7 +75,11 @@ Quintic, Lanczos, ) +from .interpolatedimage import InterpolatedImage, _InterpolatedImage # packages kept separate from . import bessel from . import fits + +# this one is specific to jax_galsim +from . import core diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 0a349039..3f34d1b1 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs): # Save the construction parameters (as they are at this point) as attributes so they # can be inspected later if necessary. if bool(real_space): - raise NotImplementedError("Real space convolutions are not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") self._real_space = bool(real_space) # Figure out what gsparams to use @@ -296,7 +296,7 @@ def _max_sb(self): return self.flux / jnp.sum(jnp.array(area_list)) def _xValue(self, pos): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") def _kValue(self, kpos): kv_list = [ @@ -305,10 +305,10 @@ def _kValue(self, kpos): return jnp.prod(jnp.array(kv_list)) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") def _shoot(self, photons, rng): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Photon shooting convolutions are not implemented") def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c41a8a40..554cc46a 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -1,6 +1,16 @@ from functools import partial import jax +import jax.numpy as jnp + + +@jax.jit +def compute_major_minor_from_jacobian(jac): + h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) + h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) + major = 0.5 * jnp.abs(h1 + h2) + minor = 0.5 * jnp.abs(h1 - h2) + return major, minor def convert_to_float(x): @@ -43,6 +53,54 @@ def cast_scalar_to_int(x): return x +def is_equal_with_arrays(x, y): + """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + if isinstance(x, list): + if isinstance(y, list) and len(x) == len(y): + for vx, vy in zip(x, y): + if not is_equal_with_arrays(vx, vy): + return False + return True + else: + return False + elif isinstance(x, tuple): + if isinstance(y, tuple) and len(x) == len(y): + for vx, vy in zip(x, y): + if not is_equal_with_arrays(vx, vy): + return False + return True + else: + return False + elif isinstance(x, set): + if isinstance(y, set) and len(x) == len(y): + for vx, vy in zip(sorted(x), sorted(y)): + if not is_equal_with_arrays(vx, vy): + return False + return True + else: + return False + elif isinstance(x, dict): + if isinstance(y, dict) and len(x) == len(y): + for kx, vx in x.items(): + if kx not in y or (not is_equal_with_arrays(vx, y[kx])): + return False + return True + else: + return False + elif isinstance(x, jax.Array) and jnp.ndim(x) > 0: + if isinstance(y, jax.Array) and y.shape == x.shape: + return jnp.array_equal(x, y) + else: + return False + elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( + isinstance(y, jax.Array) and jnp.ndim(y) == 0 + ): + # this case covers comparing an array scalar to a python scalar or vice versa + return jnp.array_equal(x, y) + else: + return x == y + + def _recurse_list_to_tuple(x): if isinstance(x, list): return tuple(_recurse_list_to_tuple(v) for v in x) diff --git a/jax_galsim/core/wrap_image.py b/jax_galsim/core/wrap_image.py index 94ced023..74f7b26c 100644 --- a/jax_galsim/core/wrap_image.py +++ b/jax_galsim/core/wrap_image.py @@ -4,7 +4,7 @@ @jax.jit -def wrap_nonhermition(im, xmin, ymin, nxwrap, nywrap): +def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap): def _body_j(j, vals): i, im = vals @@ -33,3 +33,107 @@ def _body_i(i, vals): im = jax.lax.fori_loop(0, im.shape[0], _body_i, im) return im + + +@jax.jit +def expand_hermitian_x(im): + return jnp.concatenate([im[:, 1:][::-1, ::-1].conjugate(), im], axis=1) + + +@jax.jit +def contract_hermitian_x(im): + return im[:, im.shape[1] // 2 :] + + +@jax.jit +def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_x(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_x(im_exp) + + +@jax.jit +def expand_hermitian_y(im): + return jnp.concatenate([im[1:, :][::-1, ::-1].conjugate(), im], axis=0) + + +@jax.jit +def contract_hermitian_y(im): + return im[im.shape[0] // 2 :, :] + + +@jax.jit +def wrap_hermitian_y(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_y(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_y(im_exp) + + +# I am leaving this code here for posterity. It has a bug that I cannot find. +# It tries to be more clever instead of simply expanding the hermitian image to +# it's full shape, wrapping everything, and then contracting. -MRB +# @jax.jit +# def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): +# def _body_j(j, vals): +# i, im = vals + +# # first do zero or positive x freq +# im_y = i + im_ymin +# im_x = j + im_xmin +# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin +# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin +# wrap_yind = wrap_y - im_ymin +# wrap_xind = wrap_x - im_xmin +# im = jax.lax.cond( +# wrap_xind >= 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y) != 0, +# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j]), +# lambda im, wrap_yind, wrap_xind: im, +# im, +# wrap_yind, +# wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ) + +# # now do neg x freq +# im_y = -im_y +# im_x = -im_x +# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin +# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin +# wrap_yind = wrap_y - im_ymin +# wrap_xind = wrap_x - im_xmin +# im = jax.lax.cond( +# im_x != 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# wrap_xind >= 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# (jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y)) != 0, +# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j].conjugate()), +# lambda im, wrap_yind, wrap_xind: im, +# im, +# wrap_yind, +# wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ) + +# return [i, im] + +# def _body_i(i, vals): +# im = vals +# _, im = jax.lax.fori_loop(0, im.shape[1], _body_j, [i, im]) +# return im + +# im = jax.lax.fori_loop(0, im.shape[0], _body_i, im) +# return im diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 1d1f46c7..49458091 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -4,6 +4,7 @@ import numpy as np from jax._src.numpy.util import _wraps +from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.utilities import parse_pos_args @@ -178,7 +179,7 @@ def __neg__(self): def __eq__(self, other): return (self is other) or ( (type(other) is self.__class__) - and (self.tree_flatten() == other.tree_flatten()) + and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) ) @_wraps(_galsim.GSObject.xValue) @@ -771,7 +772,7 @@ def drawFFT_makeKImage(self, image): with jax.ensure_compile_time_eval(): Nk = self.gsparams.maximum_fft_size N = Nk - dk = 2.0 * np.pi / (N * image.scale) + dk = 2.0 * np.pi / (N * image.scale) else: # Start with what this profile thinks a good size would be given the image's pixel scale. N = self.getGoodImageSize(image.scale) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8b24c31d..bfe32b7b 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -631,9 +631,9 @@ def _wrap(self, bounds, hermx, hermy): Equivalent to ``image.wrap(bounds, hermitian=='x', hermitian=='y')``. """ if not hermx and not hermy: - from jax_galsim.core.wrap_image import wrap_nonhermition + from jax_galsim.core.wrap_image import wrap_nonhermitian - self._array = wrap_nonhermition( + self._array = wrap_nonhermitian( self._array, # zero indexed location of subimage bounds.xmin - self.xmin, @@ -642,8 +642,31 @@ def _wrap(self, bounds, hermx, hermy): bounds.xmax - bounds.xmin + 1, bounds.ymax - bounds.ymin + 1, ) + elif hermx and not hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_x + + self._array = wrap_hermitian_x( + self._array, + -self.xmax, + self.ymin, + -bounds.xmax + 1, + bounds.ymin, + 2 * bounds.xmax, + bounds.ymax - bounds.ymin + 1, + ) + elif not hermx and hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_y + + self._array = wrap_hermitian_y( + self._array, + self.xmin, + -self.ymax, + bounds.xmin, + -bounds.ymax + 1, + bounds.xmax - bounds.xmin + 1, + 2 * bounds.ymax, + ) - # FIXME: Wrapping not yet implemented for hermitian images return self.subImage(bounds) @_wraps(_galsim.Image.calculate_fft) @@ -659,11 +682,15 @@ def calculate_fft(self): "calculate_fft requires that the image has a PixelScale wcs." ) - No2 = jnp.maximum( - -self.bounds.xmin, - self.bounds.xmax + 1, - -self.bounds.ymin, - self.bounds.ymax + 1, + No2 = max( + max( + -self.bounds.xmin, + self.bounds.xmax + 1, + ), + max( + -self.bounds.ymin, + self.bounds.ymax + 1, + ), ) full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) @@ -680,7 +707,11 @@ def calculate_fft(self): dk = jnp.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._image = jnp.fft.rfft2(ximage._image) + # we shift the image before and after the FFT to match the layout of the modes + # used by GalSim + out._array = jnp.fft.fftshift( + jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 + ) out *= dx * dx out.setOrigin(0, -No2) @@ -707,7 +738,10 @@ def calculate_inverse_fft(self): self.bounds, ) - No2 = jnp.maximum(self.bounds.xmax, -self.bounds.ymin, self.bounds.ymax) + No2 = max( + max(self.bounds.xmax, -self.bounds.ymin), + self.bounds.ymax, + ) target_bounds = BoundsI(0, No2, -No2, No2 - 1) if self.bounds == target_bounds: @@ -729,7 +763,10 @@ def calculate_inverse_fft(self): # For the inverse, we need a bit of extra space for the fft. out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) - out_extra._image = jnp.fft.irfft2(kimage._image) + # we shift the image before and after the FFT to match the layout used by galsim + out_extra._array = jnp.fft.fftshift( + jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) + ) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) out *= (dk * No2 / jnp.pi) ** 2 @@ -744,19 +781,19 @@ def good_fft_size(cls, input_size): going to be performing FFTs on an image, these will tend to be faster at performing the FFT. """ + # we use the math module here since this function should not be jitted. + import math + # Reference from GalSim C++ # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/Image.cpp#L1009 - input_size = int(input_size) - if input_size <= 2: - return 2 # Reduce slightly to eliminate potential rounding errors: insize = (1.0 - 1.0e-5) * input_size - log2n = jnp.log(2.0) * jnp.ceil(jnp.log(insize) / jnp.log(2.0)) - log2n3 = jnp.log(3.0) + jnp.log(2.0) * jnp.ceil( - (jnp.log(insize) - jnp.log(3.0)) / jnp.log(2.0) + log2n = math.log(2.0) * math.ceil(math.log(insize) / math.log(2.0)) + log2n3 = math.log(3.0) + math.log(2.0) * math.ceil( + (math.log(insize) - math.log(3.0)) / math.log(2.0) ) - log2n3 = max(log2n3, jnp.log(6.0)) # must be even number - Nk = int(jnp.ceil(jnp.exp(min(log2n, log2n3)) - 1.0e-5)) + log2n3 = max(log2n3, math.log(6.0)) # must be even number + Nk = max(int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)), 2) return Nk def copyFrom(self, rhs): diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 5ef0b5f2..9ed587e6 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -12,6 +12,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si +from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams @@ -137,7 +138,7 @@ def _i(self): def __eq__(self, other): return (self is other) or ( type(other) is self.__class__ - and self.tree_flatten() == other.tree_flatten() + and is_equal_with_arrays(self.tree_flatten()[1], other.tree_flatten()[1]) ) def __ne__(self, other): @@ -158,8 +159,11 @@ def xval(self, x): an array. """ if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) + raise GalSimValueError("xval only takes scalar or 1D array values", x) + return self._xval_noraise(x) + + def _xval_noraise(self, x): return self.__class__._xval(x) def kval(self, k): @@ -176,6 +180,9 @@ def kval(self, k): if jnp.ndim(k) > 1: raise GalSimValueError("kval only takes scalar or 1D array values", k) + return self._kval_noraise(k) + + def _kval_noraise(self, k): return self.__class__._uval(k / 2.0 / jnp.pi) def unit_integrals(self, max_len=None): @@ -267,20 +274,7 @@ def __init__(self, tol=None, gsparams=None): gsparams = GSParams(kvalue_accuracy=tol) self._gsparams = GSParams.check(gsparams) - def xval(self, x): - """Calculate the value of the interpolant kernel at one or more x values - - Parameters: - x: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel. - - Returns: - xval: The value(s) at the x location(s). If x was an array, then this is also - an array. - """ - if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) - + def _xval_noraise(self, x): return Delta._xval(x, self._gsparams.kvalue_accuracy) @jax.jit @@ -1336,66 +1330,41 @@ def __init__( conserve_dc=True, tol=None, gsparams=None, - _K=None, - _C=None, - _umax=None, - _du=None, ): if tol is not None: from galsim.deprecated import depr depr("tol", 2.2, "gsparams=GSParams(kvalue_accuracy=tol)") gsparams = GSParams(kvalue_accuracy=tol) - self._n = int(n) - self._conserve_dc = bool(conserve_dc) + self._n = n + self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) - if _C is None or _K is None: - _K = [0.0] + [Lanczos._raw_uval(i + 1.0, n).item() for i in range(5)] - _C = [0.0] * 6 - _C[0] = 1.0 + 2.0 * ( - _K[1] * (1.0 + 3.0 * _K[1] + _K[2] + _K[3]) - + _K[2] - + _K[3] - + _K[4] - + _K[5] - ) - _C[1] = -_K[1] * (1.0 + 4.0 * _K[1] + _K[2] + 2.0 * _K[3]) - _C[2] = _K[1] * (_K[1] - 2.0 * _K[2] + _K[3]) - _K[2] - _C[3] = _K[1] * (_K[1] - 2.0 * _K[3]) - _K[3] - _C[4] = _K[1] * _K[3] - _K[4] - _C[5] = -_K[5] - _K = tuple(_K) - _C = tuple(_C) - self._K = _K - self._C = _C - else: - self._K = _K - self._C = _C - - self._K_arr = jnp.array(self._K, dtype=float) - self._C_arr = jnp.array(self._C, dtype=float) - - if _du is None: - _du = ( - self._gsparams.table_spacing - * jnp.power(self._gsparams.kvalue_accuracy / 200.0, 0.25) - / self._n - ).item() - self._du = _du - else: - self._du = _du - - if _umax is None: - self._umax = _find_umax_lanczos( - self._du, - self._n, - self._conserve_dc, - self._C, - self._gsparams.kvalue_accuracy, - ).item() - else: - self._umax = _umax + @property + def _C_arr(self): + return self._C_arr_vals[self._n] + + @property + def _K_arr(self): + return self._K_arr_vals[self._n] + + @property + def _du(self): + return ( + self._gsparams.table_spacing + * jnp.power(self._gsparams.kvalue_accuracy / 200.0, 0.25) + / self._n + ) + + @property + def _umax(self): + return _find_umax_lanczos( + self._du, + self._n, + self._conserve_dc, + self._C_arr, + self._gsparams.kvalue_accuracy, + ) def tree_flatten(self): """This function flattens the Interpolant into a list of children @@ -1408,13 +1377,6 @@ def tree_flatten(self): "n": self._n, "conserve_dc": self._conserve_dc, } - if hasattr(self, "_du"): - aux_data["_du"] = self._du - if hasattr(self, "_umax"): - aux_data["_umax"] = self._umax - if hasattr(self, "_K"): - aux_data["_K"] = self._K - aux_data["_C"] = self._C return (children, aux_data) @classmethod @@ -1508,20 +1470,7 @@ def _no_dcval(val, x, n, _K): _K, ) - def xval(self, x): - """Calculate the value of the interpolant kernel at one or more x values - - Parameters: - x: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel. - - Returns: - xval: The value(s) at the x location(s). If x was an array, then this is also - an array. - """ - if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) - + def _xval_noraise(self, x): return Lanczos._xval(x, self._n, self._conserve_dc, self._K_arr) def _raw_uval(u, n): @@ -1578,20 +1527,7 @@ def _no_dcval(retval, u, n, _C): _C, ) - def kval(self, k): - """Calculate the value of the interpolant kernel in Fourier space at one or more k values. - - Parameters: - k: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel in Fourier space. - - Returns: - kval: The k-value(s) at the k location(s). If k was an array, then this is also - an array. - """ - if jnp.ndim(k) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", k) - + def _kval_noraise(self, k): return Lanczos._uval(k / 2.0 / jnp.pi, self._n, self._conserve_dc, self._C_arr) def urange(self): @@ -1679,3 +1615,30 @@ def _body(vals): _body, [0.0, 0.0], )[0] + + +@jax.jit +def _compute_C_K_lanczos(n): + _K = jnp.concatenate( + (jnp.zeros(1), Lanczos._raw_uval(jnp.arange(5) + 1.0, n)), axis=0 + ) + _C = jnp.zeros(6) + _C = _C.at[0].set( + 1.0 + + 2.0 + * (_K[1] * (1.0 + 3.0 * _K[1] + _K[2] + _K[3]) + _K[2] + _K[3] + _K[4] + _K[5]) + ) + _C = _C.at[1].set(-_K[1] * (1.0 + 4.0 * _K[1] + _K[2] + 2.0 * _K[3])) + _C = _C.at[2].set(_K[1] * (_K[1] - 2.0 * _K[2] + _K[3]) - _K[2]) + _C = _C.at[3].set(_K[1] * (_K[1] - 2.0 * _K[3]) - _K[3]) + _C = _C.at[4].set(_K[1] * _K[3] - _K[4]) + _C = _C.at[5].set(-_K[5]) + + return _C, _K + + +Lanczos._C_arr_vals = {} +Lanczos._K_arr_vals = {} +for n in range(1, 31): + Lanczos._C_arr_vals[n] = _compute_C_K_lanczos(n)[0] + Lanczos._K_arr_vals[n] = _compute_C_K_lanczos(n)[1] diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py new file mode 100644 index 00000000..586652a1 --- /dev/null +++ b/jax_galsim/interpolatedimage.py @@ -0,0 +1,1194 @@ +import copy +import math +import textwrap +from functools import partial + +import galsim as _galsim +import jax +import jax.numpy as jnp +from galsim.errors import ( + GalSimIncompatibleValuesError, + GalSimRangeError, + GalSimUndefinedBoundsError, + GalSimValueError, +) +from galsim.utilities import doc_inherit +from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class + +from jax_galsim import fits +from jax_galsim.bounds import BoundsI +from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable +from jax_galsim.gsobject import GSObject +from jax_galsim.gsparams import GSParams +from jax_galsim.image import Image +from jax_galsim.interpolant import Quintic +from jax_galsim.position import PositionD +from jax_galsim.transform import Transformation +from jax_galsim.utilities import convert_interpolant, lazy_property +from jax_galsim.wcs import BaseWCS, PixelScale + +# These keys are removed from the public API of +# InterpolatedImage so that it matches the galsim +# one. +# The DirMeta class does this along with the changes to +# __getattribute__ and __dir__ below. +_KEYS_TO_REMOVE = [ + "flux_ratio", + "jac", + "offset", + "original", +] + + +# magic from https://stackoverflow.com/questions/46120462/how-to-override-the-dir-method-for-a-class +class DirMeta(type): + def __dir__(cls): + keys = set(list(cls.__dict__.keys()) + dir(cls.__base__)) + keys -= set(_KEYS_TO_REMOVE) + return list(keys) + + +@_wraps( + _galsim.InterpolatedImage, + lax_description=textwrap.dedent( + """The JAX equivalent of galsim.InterpolatedImage does not support + + - noise padding + - the pad_image options + - depixelize + - most of the type checks and dtype casts done by galsim + """ + ), +) +@register_pytree_node_class +class InterpolatedImage(Transformation, metaclass=DirMeta): + _req_params = {"image": str} + _opt_params = { + "x_interpolant": str, + "k_interpolant": str, + "normalization": str, + "scale": float, + "flux": float, + "pad_factor": float, + "noise_pad_size": float, + "noise_pad": str, + "pad_image": str, + "calculate_stepk": bool, + "calculate_maxk": bool, + "use_true_center": bool, + "depixelize": bool, + "offset": PositionD, + "hdu": int, + } + _takes_rng = True + + def __init__( + self, + image, + x_interpolant=None, + k_interpolant=None, + normalization="flux", + scale=None, + wcs=None, + flux=None, + pad_factor=4.0, + noise_pad_size=0, + noise_pad=0.0, + rng=None, + pad_image=None, + calculate_stepk=True, + calculate_maxk=True, + use_cache=True, + use_true_center=True, + depixelize=False, + offset=None, + gsparams=None, + _force_stepk=0.0, + _force_maxk=0.0, + _recenter_image=True, # this option is used by _InterpolatedImage below + hdu=None, + _obj=None, + ): + # If the "image" is not actually an image, try to read the image as a file. + if isinstance(image, str): + image = fits.read(image, hdu=hdu) + elif not isinstance(image, Image): + raise TypeError("Supplied image must be an Image or file name") + + self._jax_children = ( + image, + dict( + scale=scale, + wcs=wcs, + flux=flux, + pad_image=pad_image, + offset=offset, + ), + ) + self._jax_aux_data = dict( + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + gsparams=GSParams.check(gsparams), + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + _recenter_image=_recenter_image, + hdu=hdu, + ) + + if _obj is not None: + obj = _obj + else: + obj = _InterpolatedImageImpl( + image, + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + scale=scale, + wcs=wcs, + flux=flux, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + pad_image=pad_image, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + offset=offset, + gsparams=GSParams.check(gsparams), + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + hdu=hdu, + _recenter_image=_recenter_image, + ) + + # we don't use the parent init but instead set things by hand to + # avoid computations upon init + self._gsparams = GSParams.check(gsparams, obj.gsparams) + self._propagate_gsparams = True + if self._propagate_gsparams: + obj = obj.withGSParams(self._gsparams) + self._original = obj + self._params = { + "offset": PositionD(0.0, 0.0), + } + self._jax_children[1]["_obj"] = obj + + @property + def _flux_ratio(self): + return self._original._flux_ratio / self._original._wcs.pixelArea() + + @property + def _jac(self): + return self._original._jac_arr.reshape((2, 2)) + + def __getattribute__(self, name): + if name in _KEYS_TO_REMOVE: + raise AttributeError(f"{self.__class__} has no attribute '{name}'") + return super().__getattribute__(name) + + def __dir__(self): + allattrs = set(self.__dict__.keys() + dir(self.__class__)) + allattrs -= set(_KEYS_TO_REMOVE) + return list(allattrs) + + # the galsim tests use this internal attribute + # so we add it here + @property + def _xim(self): + return self._original._xim + + @property + def _maxk(self): + if self._jax_aux_data["_force_maxk"] > 0: + return self._jax_aux_data["_force_maxk"] + else: + return super()._maxk + + @property + def _stepk(self): + if self._jax_aux_data["_force_stepk"] > 0: + return self._jax_aux_data["_force_stepk"] + else: + return super()._stepk + + @property + def x_interpolant(self): + """The real-space `Interpolant` for this profile.""" + return self._original._x_interpolant + + @property + def k_interpolant(self): + """The Fourier-space `Interpolant` for this profile.""" + return self._original._k_interpolant + + @property + def image(self): + """The underlying `Image` being interpolated.""" + return self._original._image + + def __hash__(self): + # Definitely want to cache this, since the size of the image could be large. + if not hasattr(self, "_hash"): + self._hash = hash( + ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant) + ) + self._hash ^= hash( + ( + ensure_hashable(self.flux), + ensure_hashable(self._stepk), + ensure_hashable(self._maxk), + ensure_hashable(self._original._jax_aux_data["pad_factor"]), + ) + ) + self._hash ^= hash( + ( + self._original._xim.bounds, + self._original._image.bounds, + self._original._pad_image.bounds, + ) + ) + # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0 + # (which is also common). I guess because they are only different in 2 bits. + # This mucking of the numbers seems to help make the hash more reliably different for + # these two cases. Note: "sometiems" because of this: + # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions + self._hash ^= hash( + ( + ensure_hashable(self._original._offset.x * 1.234), + ensure_hashable(self._original._offset.y * 0.23424), + ) + ) + self._hash ^= hash(self.gsparams) + self._hash ^= hash(self._original._wcs) + # Just hash the diagonal. Much faster, and usually is unique enough. + # (Let python handle collisions as needed if multiple similar IIs are used as keys.) + self._hash ^= hash(ensure_hashable(self._original._pad_image.array)) + return self._hash + + def __repr__(self): + s = "galsim.InterpolatedImage(%r, %r, %r, wcs=%r" % ( + self._original.image, + self.x_interpolant, + self.k_interpolant, + self._original._wcs, + ) + # Most things we keep even if not required, but the pad_image is large, so skip it + # if it's really just the same as the main image. + if self._original._pad_image.bounds != self._original.image.bounds: + s += ", pad_image=%r" % (self._pad_image) + s += ", pad_factor=%f, flux=%r, offset=%r" % ( + ensure_hashable(self._original._jax_aux_data["pad_factor"]), + ensure_hashable(self.flux), + self._original._offset, + ) + s += ( + ", use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)" + % ( + self.gsparams, + ensure_hashable(self._stepk), + ensure_hashable(self._maxk), + ) + ) + return s + + def __str__(self): + return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) + + def __eq__(self, other): + return self is other or ( + isinstance(other, InterpolatedImage) + and self._xim == other._xim + and self.x_interpolant == other.x_interpolant + and self.k_interpolant == other.k_interpolant + and self.flux == other.flux + and self._original._offset == other._original._offset + and self.gsparams == other.gsparams + and self._stepk == other._stepk + and self._maxk == other._maxk + ) + + def tree_flatten(self): + """This function flattens the InterpolatedImage into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + return (self._jax_children, copy.copy(self._jax_aux_data)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + val = {} + val.update(aux_data) + val.update(children[1]) + return cls(children[0], **val) + + @_wraps(_galsim.InterpolatedImage.withGSParams) + def withGSParams(self, gsparams=None, **kwargs): + if gsparams == self.gsparams: + return self + # Checking gsparams + gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + # Flattening the representation to instantiate a clean new object + children, aux_data = self.tree_flatten() + aux_data["gsparams"] = gsparams + ret = self.tree_unflatten(aux_data, children) + + return ret + + +@partial(jax.jit, static_argnums=(1,)) +def _zeropad_image(arr, npad): + return jnp.pad(arr, npad, mode="constant", constant_values=0.0) + + +@register_pytree_node_class +class _InterpolatedImageImpl(GSObject): + _cache_noise_pad = {} + + _has_hard_edges = False + _is_axisymmetric = False + _is_analytic_x = True + _is_analytic_k = True + + def __init__( + self, + image, + x_interpolant=None, + k_interpolant=None, + normalization="flux", + scale=None, + wcs=None, + flux=None, + pad_factor=4.0, + noise_pad_size=0, + noise_pad=0.0, + rng=None, + pad_image=None, + calculate_stepk=True, + calculate_maxk=True, + use_cache=True, + use_true_center=True, + depixelize=False, + offset=None, + gsparams=None, + _force_stepk=0.0, + _force_maxk=0.0, + hdu=None, + _recenter_image=True, + ): + # this class does a ton of munging of the inputs that I don't want to reconstruct when + # flattening and unflattening the class. + # thus I am going to make some refs here so we have it when we need it + self._workspace = {} + self._jax_children = ( + image, + dict( + scale=scale, + wcs=wcs, + flux=flux, + pad_image=pad_image, + offset=offset, + ), + ) + self._jax_aux_data = dict( + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + gsparams=gsparams, + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + _recenter_image=_recenter_image, + hdu=hdu, + ) + + # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor + if not image.bounds.isDefined(): + raise GalSimUndefinedBoundsError( + "Supplied image does not have bounds defined." + ) + + # check what normalization was specified for the image: is it an image of surface + # brightness, or flux? + if normalization.lower() not in ("flux", "f", "surface brightness", "sb"): + raise GalSimValueError( + "Invalid normalization requested.", + normalization, + ("flux", "f", "surface brightness", "sb"), + ) + + # Set up the interpolants if none was provided by user, or check that the user-provided ones + # are of a valid type + self._gsparams = GSParams.check(gsparams) + if x_interpolant is None: + self._x_interpolant = Quintic(gsparams=self._gsparams) + else: + self._x_interpolant = convert_interpolant(x_interpolant).withGSParams( + self._gsparams + ) + if k_interpolant is None: + self._k_interpolant = Quintic(gsparams=self._gsparams) + else: + self._k_interpolant = convert_interpolant(k_interpolant).withGSParams( + self._gsparams + ) + + if pad_image is not None: + raise NotImplementedError("pad_image not implemented in jax_galsim.") + + if pad_factor <= 0.0: + raise GalSimRangeError( + "Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.0 + ) + + if noise_pad_size: + raise NotImplementedError( + "InterpolatedImages do not support noise padding in jax_galsim." + ) + else: + if noise_pad: + raise NotImplementedError( + "InterpolatedImages do not support noise padding in jax_galsim." + ) + + if scale is not None: + if wcs is not None: + raise GalSimIncompatibleValuesError( + "Cannot provide both scale and wcs to InterpolatedImage", + scale=self._jax_children[1]["scale"], + wcs=self._jax_children[1]["wcs"], + ) + elif wcs is not None: + if not isinstance(wcs, BaseWCS): + raise TypeError("wcs parameter is not a galsim.BaseWCS instance") + else: + if self._jax_children[0].wcs is None: + raise GalSimIncompatibleValuesError( + "No information given with Image or keywords about pixel scale!", + scale=self._jax_children[1]["scale"], + wcs=self._jax_children[1]["wcs"], + image=self._jax_children[0], + ) + + @doc_inherit + def withGSParams(self, gsparams=None, **kwargs): + if gsparams == self.gsparams: + return self + # Checking gsparams + gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + # Flattening the representation to instantiate a clean new object + children, aux_data = self.tree_flatten() + aux_data["gsparams"] = gsparams + ret = self.tree_unflatten(aux_data, children) + + return ret + + def tree_flatten(self): + """This function flattens the InterpolatedImage into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + return (self._jax_children, copy.copy(self._jax_aux_data)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + val = {} + val.update(aux_data) + val.update(children[1]) + ret = cls(children[0], **val) + return ret + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_workspace") + return d + + def __setstate__(self, d): + self.__dict__ = d + self._workspace = {} + + @property + def x_interpolant(self): + """The real-space `Interpolant` for this profile.""" + return self._x_interpolant + + @property + def k_interpolant(self): + """The Fourier-space `Interpolant` for this profile.""" + return self._k_interpolant + + @lazy_property + def image(self): + """The underlying `Image` being interpolated.""" + return self._xim[self._image.bounds] + + @property + def _flux(self): + return self._image_flux + + @lazy_property + def _centroid(self): + x, y = self._pad_image.get_pixel_centers() + tot = jnp.sum(self._pad_image.array) + xpos = jnp.sum(x * self._pad_image.array) / tot + ypos = jnp.sum(y * self._pad_image.array) / tot + return PositionD(xpos, ypos) + + @lazy_property + def _max_sb(self): + return jnp.max(jnp.abs(self._pad_image.array)) + + @lazy_property + def _flux_ratio(self): + if self._jax_children[1]["flux"] is None: + flux = self._image_flux + if self._jax_aux_data["normalization"].lower() in ( + "surface brightness", + "sb", + ): + flux *= self._wcs.pixelArea() + else: + flux = self._jax_children[1]["flux"] + + # If the user specified a flux, then set the flux ratio for the transform that wraps + # this class + return flux / self._image_flux + + @lazy_property + def _image_flux(self): + return jnp.sum(self._image.array, dtype=float) + + @lazy_property + def _offset(self): + # Figure out the offset to apply based on the original image (not the padded one). + # We will apply this below in _sbp. + offset = self._parse_offset(self._jax_children[1]["offset"]) + return self._adjust_offset( + self._image.bounds, offset, None, self._jax_aux_data["use_true_center"] + ) + + @lazy_property + def _image(self): + # Store the image as an attribute and make sure we don't change the original image + # in anything we do here. (e.g. set scale, etc.) + if self._jax_aux_data["depixelize"]: + # FIXME: no depixelize in jax_galsim + # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) + raise NotImplementedError( + "InterpolatedImages do not support 'depixelize' in jax_galsim." + ) + else: + image = self._jax_children[0].view(dtype=float) + + if self._jax_aux_data["_recenter_image"]: + image.setCenter(0, 0) + + return image + + @lazy_property + def _wcs(self): + im_cen = ( + self._jax_children[0].true_center + if self._jax_aux_data["use_true_center"] + else self._jax_children[0].center + ) + + # error checking was done on init + if self._jax_children[1]["scale"] is not None: + wcs = PixelScale(self._jax_children[1]["scale"]) + elif self._jax_children[1]["wcs"] is not None: + wcs = self._jax_children[1]["wcs"] + else: + wcs = self._jax_children[0].wcs + + return wcs.local(image_pos=im_cen) + + @lazy_property + def _jac_arr(self): + image = self._jax_children[0] + im_cen = ( + image.true_center if self._jax_aux_data["use_true_center"] else image.center + ) + return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel() + + @lazy_property + def _xim(self): + pad_factor = self._jax_aux_data["pad_factor"] + + # The size of the final padded image is the largest of the various size specifications + pad_size = max(self._image.array.shape) + if pad_factor > 1.0: + pad_size = int(math.ceil(pad_factor * pad_size)) + # And round up to a good fft size + pad_size = Image.good_fft_size(pad_size) + + xim = Image( + _zeropad_image( + self._image.array, (pad_size - max(self._image.array.shape)) // 2 + ), + wcs=PixelScale(1.0), + ) + xim.setCenter(0, 0) + # after the call to setCenter you get a WCS with an offset in + # it instead of a pure pixel scale + xim.wcs = PixelScale(1.0) + + # Now place the given image in the center of the padding image: + xim[self._image.bounds] = self._image + + return xim + + @lazy_property + def _pad_image(self): + # These next two allow for easy pickling/repring. We don't need to serialize all the + # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. + xim = self._xim + nz_bounds = self._image.bounds + return xim[nz_bounds] + + @lazy_property + def _kim(self): + return self._xim.calculate_fft() + + @lazy_property + def _pos_neg_fluxes(self): + # record pos and neg fluxes now too + # see code here: https://github.com/GalSim-developers/GalSim/blob/releases/2.5/src/SBInterpolatedImage.cpp#L1225 + pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) + nflux = jnp.abs( + jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) + ) + pint = self._x_interpolant.positive_flux + nint = self._x_interpolant.negative_flux + pint2d = pint * pint + nint * nint + nint2d = 2 * pint * nint + return [ + pint2d * pflux + nint2d * nflux, + pint2d * nflux + nint2d * pflux, + ] + + @property + def _positive_flux(self): + return self._pos_neg_fluxes[0] + + @property + def _negative_flux(self): + return self._pos_neg_fluxes[1] + + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + + @lazy_property + def _maxk(self): + if self._jax_aux_data["_force_maxk"]: + _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) + return self._jax_aux_data["_force_maxk"] * minor + else: + return self._getMaxK(self._jax_aux_data["calculate_maxk"]) + + @lazy_property + def _stepk(self): + if self._jax_aux_data["_force_stepk"]: + _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) + return self._jax_aux_data["_force_stepk"] * minor + else: + return self._getStepK(self._jax_aux_data["calculate_stepk"]) + + def _getStepK(self, calculate_stepk): + # GalSim cannot automatically know what stepK and maxK are appropriate for the + # input image. So it is usually worth it to do a manual calculation (below). + # + # However, there is also a hidden option to force it to use specific values of stepK and + # maxK (caveat user!). The values of _force_stepk and _force_maxk should be provided in + # terms of physical scale, e.g., for images that have a scale length of 0.1 arcsec, the + # stepK and maxK should be provided in units of 1/arcsec. Then we convert to the 1/pixel + # units required by the C++ layer below. Also note that profile recentering for even-sized + # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly + # below what is provided here, while maxK is preserved. + if calculate_stepk: + if calculate_stepk is True: + im = self.image + else: + # If not a bool, then value is max_stepk + R = (jnp.ceil(jnp.pi / calculate_stepk)).astype(int) + b = BoundsI(-R, R, -R, R) + b = self.image.bounds & b + im = self.image[b] + thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux + # this line appears buggy in galsim - I expect they meant to use im + R = _calculate_size_containing_flux(im, thresh) + else: + R = max(*self.image.array.shape) / 2.0 - 0.5 + return self._getSimpleStepK(R) + + def _getSimpleStepK(self, R): + # Add xInterp range in quadrature just like convolution: + R2 = self._x_interpolant.xrange + R = jnp.hypot(R, R2) + stepk = jnp.pi / R + return stepk + + def _getMaxK(self, calculate_maxk): + if calculate_maxk: + _uscale = 1 / (2 * jnp.pi) + _maxk = self._x_interpolant.urange() / _uscale + + if calculate_maxk is True: + maxk = _find_maxk( + self._kim, _maxk, self._gsparams.maxk_threshold * self.flux + ) + else: + maxk = _find_maxk( + self._kim, calculate_maxk, self._gsparams.maxk_threshold * self.flux + ) + + return maxk + else: + return self._x_interpolant.krange + + def _xValue(self, pos): + x = jnp.array([pos.x], dtype=float) + y = jnp.array([pos.y], dtype=float) + return _xValue_arr( + x, + y, + self._offset.x, + self._offset.y, + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + )[0] + + def _kValue(self, kpos): + kx = jnp.array([kpos.x], dtype=float) + ky = jnp.array([kpos.y], dtype=float) + return _kValue_arr( + kx, + ky, + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, + self._k_interpolant, + )[0] + + def _shoot(self, photons, rng): + raise NotImplementedError("Photon shooting not implemented.") + + def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): + jacobian = jnp.eye(2) if jac is None else jac + + flux_scaling *= image.scale**2 + + # Create an array of coordinates + coords = jnp.stack(image.get_pixel_centers(), axis=-1) + coords = coords * image.scale # Scale by the image pixel scale + coords = coords - jnp.asarray(offset) # Add the offset + + # Apply the jacobian transformation + inv_jacobian = jnp.linalg.inv(jacobian) + _, logdet = jnp.linalg.slogdet(inv_jacobian) + coords = jnp.dot(coords, inv_jacobian.T) + flux_scaling *= jnp.exp(logdet) + + im = _xValue_arr( + coords[..., 0], + coords[..., 1], + self._offset.x, + self._offset.y, + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + ) + + # Apply the flux scaling + im = (im * flux_scaling).astype(image.dtype) + + # Return an image + return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + + def _drawKImage(self, image, jac=None): + jacobian = jnp.eye(2) if jac is None else jac + + # Create an array of coordinates + coords = jnp.stack(image.get_pixel_centers(), axis=-1) + coords = coords * image.scale # Scale by the image pixel scale + coords = jnp.dot(coords, jacobian) + + im = _kValue_arr( + coords[..., 0], + coords[..., 1], + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, + self._k_interpolant, + ) + im = (im).astype(image.dtype) + + # Return an image + return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + + +@_wraps(_galsim._InterpolatedImage) +def _InterpolatedImage( + image, + x_interpolant=Quintic(), + k_interpolant=Quintic(), + use_true_center=True, + offset=None, + gsparams=None, + force_stepk=0.0, + force_maxk=0.0, +): + return InterpolatedImage( + image, + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + use_true_center=use_true_center, + offset=offset, + gsparams=gsparams, + calculate_maxk=False, + calculate_stepk=False, + pad_factor=1.0, + flux=jnp.sum(image.array), + _force_stepk=force_stepk, + _force_maxk=force_maxk, + _recenter_image=False, + ) + + +def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): + vals = _draw_with_interpolant_xval( + x + x_offset, + y + y_offset, + xmin, + ymin, + arr, + x_interpolant, + ) + return vals + + +@partial(jax.jit, static_argnames=("interp",)) +def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): + """This helper function interpolates an image (`zp`) with an interpolant `interp` + at the pixel locations given by `x`, `y`. The lower-left corner of the image is + `xmin` / `ymin`. + + A more standard C/C++ code would have a set of nested for loops that iterates over each + location to interpolate and then over the nterpolation kernel. + + In JAX, we instead write things such that the loop over the points to be interpolated + is vectorized in the code. We represent the loops over the interpolation kernel as explicit + for loops. + """ + # the vectorization over the interpolation points is easier to think about + # if they are in a 1D array. So we use ravel to flatten them and then reshape + # at the end. + orig_shape = x.shape + + # the variables here are + # x/y: the x/y coordinates of the points to be interpolated + # xi/yi: the index of the nerest pixel below the point + # xp/yp: the x/y coordinate of the nearest pixel below the point + # nx/ny: the size of the x/y arrays + x = x.ravel() + xi = jnp.floor(x - xmin).astype(jnp.int32) + xp = xi + xmin + nx = zp.shape[1] + + y = y.ravel() + yi = jnp.floor(y - ymin).astype(jnp.int32) + yp = yi + ymin + ny = zp.shape[0] + + # this function is the inner loop over the x direction + # the variables are + # i: the index of the location in the interpolation kernel + # z: the final interpolated values + # wy: the weight of the interpolation kernel in the y direction + # msky: a mask that is true if the y index is in bounds + # yind: the y index of the interpolation point needed by the kernel + def _body_1d(i, args): + z, wy, msky, yind, xi, xp, zp = args + + # this block computes the x weight using the + # offset in the interpolation kernel i + xind = xi + i + mskx = (xind >= 0) & (xind < nx) + _x = x - (xp + i) + wx = interp._xval_noraise(_x) + + # the actual interpolation is done here. + # we use jnp.where to only do the interpolation + # where the x and y indices are in bounds. + # the total weight is the product of the x and y weights. + w = wx * wy + msk = msky & mskx + z += jnp.where(msk, zp[yind, xind] * w, 0) + + return [z, wy, msky, yind, xi, xp, zp] + + # this function is the outer loop over the y direction + # the variables are + # i: the index of the location in the interpolation kernel + # z: the final interpolated values + def _body(i, args): + z, xi, yi, xp, yp, zp = args + + # this block computes the x weight using the + # offset in the interpolation kernel i + yind = yi + i + msk = (yind >= 0) & (yind < ny) + _y = y - (yp + i) + wy = interp._xval_noraise(_y) + + # this call computes the interpolant for each x locatoon that gets + # paired with this y location + z = jax.lax.fori_loop( + -interp.xrange, interp.xrange + 1, _body_1d, [z, wy, msk, yind, xi, xp, zp] + )[0] + return [z, xi, yi, xp, yp, zp] + + # the actual loop call for y is here + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body, + [jnp.zeros(x.shape, dtype=float), xi, yi, xp, yp, zp], + )[0] + + # we reshape on the way out to match the input shape + return z.reshape(orig_shape) + + +def _kValue_arr( + kx, + ky, + x_offset, + y_offset, + kxmin, + kymin, + arr, + scale, + x_interpolant, + k_interpolant, +): + # phase factor due to offset + # not we shift by -offset which explains the sign + # in the exponent + pfac = jnp.exp(1j * (kx * x_offset + ky * y_offset)) + + kxi = kx / scale + kyi = ky / scale + + _uscale = 1.0 / (2.0 * jnp.pi) + _maxk_xint = x_interpolant.urange() / _uscale / scale + + # here we do the actual inteprolation in k space + val = _draw_with_interpolant_kval( + kxi, + kyi, + kymin, # this is not a bug! we need the minimum for the full periodic space + kymin, + arr, + k_interpolant, + ) + + # finally we multiply by the FFT of the real-space interpolation function + # and mask any values that are outside the range of the real-space interpolation + # FFT + msk = (jnp.abs(kxi) <= _maxk_xint) & (jnp.abs(kyi) <= _maxk_xint) + xint_val = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky) + return jnp.where(msk, val * xint_val * pfac, 0.0) + + +@partial(jax.jit, static_argnames=("interp",)) +def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): + """This function interpolates complex k-space images and follows the + same basic structure as _draw_with_interpolant_xval above. + + The key difference is that the k-space images are Hermitian and so + only half of the data is actually in memory. We account for this by + computing all of the interpolation weights and indicies as if we had + the full image. Then finally, if we need a value that is not in memory, + we get it from the values we have via the Hermitian symmetry. + """ + # all of the code below is almost line-for-line the same as the + # _draw_with_interpolant_xval function above. + orig_shape = kx.shape + kx = kx.ravel() + kxi = jnp.floor(kx - kxmin).astype(jnp.int32) + kxp = kxi + kxmin + # this is the number of pixels in the half image and is needed + # for computing values via Hermition symmetry below + nkx_2 = zp.shape[1] - 1 + nkx = nkx_2 * 2 + + ky = ky.ravel() + kyi = jnp.floor(ky - kymin).astype(jnp.int32) + kyp = kyi + kymin + nky = zp.shape[0] + + def _body_1d(i, args): + z, wky, kyind, kxi, nkx, nkx_2, kxp, zp = args + + kxind = (kxi + i) % nkx + _kx = kx - (kxp + i) + wkx = interp._xval_noraise(_kx) + + # this is the key difference from the xval function + # we need to use the Hermitian symmetry to get the + # values that are not in memory + # in memory we have the values at nkx_2 to nkx - 1 + # the Hermitian symmetry is that + # f(ky, kx) = conjugate(f(-kx, -ky)) + # In indices this is a symmetric flip about the central + # pixels at kx = ky = 0. + # we do not need to mask any values that run off the edge of the image + # since we rewrap them using the periodicity of the image. + val = jnp.where( + kxind < nkx_2, + zp[(nky - kyind) % nky, nkx - kxind - nkx_2].conjugate(), + zp[kyind, kxind - nkx_2], + ) + z += val * wkx * wky + + return [z, wky, kyind, kxi, nkx, nkx_2, kxp, zp] + + def _body(i, args): + z, kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp = args + kyind = (kyi + i) % nky + _ky = ky - (kyp + i) + wky = interp._xval_noraise(_ky) + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body_1d, + [z, wky, kyind, kxi, nkx, nkx_2, kxp, zp], + )[0] + return [z, kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp] + + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body, + [jnp.zeros(kx.shape, dtype=complex), kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp], + )[0] + return z.reshape(orig_shape) + + +@jax.jit +def _flux_frac(a, x, y, cenx, ceny): + def _body(d, args): + res, a, dx, dy, cenx, ceny = args + msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) + + res = res.at[d].set( + jnp.sum( + jnp.where( + msk, + a, + 0.0, + ) + ) + ) + + return [res, a, dx, dy, cenx, ceny] + + res = jnp.zeros(a.shape[0], dtype=float) - jnp.inf + return jax.lax.fori_loop( + 0, a.shape[0], _body, [res, a, x - cenx, y - ceny, cenx, ceny] + )[0] + + +@jax.jit +def _calculate_size_containing_flux(image, thresh): + cenx, ceny = image.center.x, image.center.y + x, y = image.get_pixel_centers() + fluxes = _flux_frac(image.array, x, y, cenx, ceny) + msk = fluxes >= -jnp.inf + fluxes = jnp.where(msk, fluxes, jnp.max(fluxes)) + d = jnp.arange(image.array.shape[0]) + 1.0 + # below we use a linear interpolation table to find the maximum size + # in pixels that contains a given flux (called thresh here) + # expfac controls how much we oversample the interpolation table + # in order to return a more accurate result + # we have it hard coded at 4 to compromise between speed and accuracy + expfac = 4.0 + dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0 + fluxes = jnp.interp(dint, d, fluxes) + msk = fluxes <= thresh + return ( + jnp.argmax( + jnp.where( + msk, + dint, + -jnp.inf, + ) + ) + / expfac + + 1.0 + ) + + +@jax.jit +def _inner_comp_find_maxk(arr, thresh, kx, ky): + msk = (arr * arr.conjugate()).real > thresh * thresh + max_kx = jnp.max( + jnp.where( + msk, + jnp.abs(kx), + -jnp.inf, + ) + ) + max_ky = jnp.max( + jnp.where( + msk, + jnp.abs(ky), + -jnp.inf, + ) + ) + return jnp.maximum(max_kx, max_ky) + + +@jax.jit +def _find_maxk(kim, max_maxk, thresh): + kx, ky = kim.get_pixel_centers() + kx *= kim.scale + ky *= kim.scale + # this minimum bounds the empirically determined + # maxk from the image (computed by _inner_comp_find_maxk) + # by max_maxk from above + return jnp.minimum( + _inner_comp_find_maxk(kim.array, thresh, kx, ky), + max_maxk, + ) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index f9ec4839..a13dcd60 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -147,12 +147,12 @@ def __init__( @property def beta(self): """The beta parameter of this `Moffat` profile.""" - return self.params["beta"] + return self._params["beta"] @property def trunc(self): """The truncation radius (if any) of this `Moffat` profile.""" - return self.params["trunc"] + return self._params["trunc"] @property def scale_radius(self): diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index c64863f8..aa34eec5 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -4,7 +4,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.position import PositionD @@ -59,17 +59,38 @@ def __init__( "flux_ratio": flux_ratio, } - if isinstance(obj, Transformation): + # this import is here to avoid circular imports + # we do not want to mess with the transform properties of the interpolated image + from .interpolatedimage import InterpolatedImage + + if isinstance(obj, Transformation) and not isinstance(obj, InterpolatedImage): # Combine the two affine transformations into one. - dx, dy = self._fwd(obj.offset.x, obj.offset.y) - self._params["offset"].x += dx - self._params["offset"].y += dy - self._params["jac"] = self._jac.dot(obj.jac) - self._params["flux_ratio"] *= obj._params["flux_ratio"] - self._original = obj.original + dx, dy = self._fwd(obj._params["offset"].x, obj._params["offset"].y) + self._offset.x += dx + self._offset.y += dy + self._params["jac"] = self._jac.dot(obj._jac) + self._params["flux_ratio"] *= obj._flux_ratio + self._original = obj._original else: self._original = obj + ############################################################## + # The internal code of the methods of the Transform class + # should only aceess _offset, _flux_ratio, and _jac. It + # should not pull these directly from _params. + # Things are structured this way since the interpolated image + # class inherits and overrides these methods. + + @property + def _offset(self): + return self._params["offset"] + + # we use this property so that the interpolated image can override + # how flux ratio is computer / stored + @property + def _flux_ratio(self): + return self._params["flux_ratio"] + @property def _jac(self): jac = self._params["jac"] @@ -79,7 +100,7 @@ def _jac(self): lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), jac, ) - return jnp.asarray(jac, dtype=float).reshape(2, 2) + return jnp.asarray(jac, dtype=float).reshape((2, 2)) @property def original(self): @@ -94,17 +115,18 @@ def jac(self): @property def offset(self): """The offset of the transformation.""" - return self._params["offset"] + return self._offset @property def flux_ratio(self): """The flux ratio of the transformation.""" - return self._params["flux_ratio"] + return self._flux_ratio @property def _flux(self): return self._flux_scaling * self._original.flux + @_wraps(_galsim.Transformation.withGSParams) def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -113,11 +135,11 @@ def withGSParams(self, gsparams=None, **kwargs): Unless you set ``propagate_gsparams=False``, this method will also update the gsparams of the wrapped component object. """ - if gsparams == self.gsparams: + if gsparams == self._gsparams: return self chld, aux = self.tree_flatten() - aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) + aux["gsparams"] = GSParams.check(gsparams, self._gsparams, **kwargs) if self._propagate_gsparams: new_obj = chld[0].withGSParams(aux["gsparams"]) chld = (new_obj,) + chld[1:] @@ -127,11 +149,11 @@ def withGSParams(self, gsparams=None, **kwargs): def __eq__(self, other): return self is other or ( isinstance(other, Transformation) - and self.original == other.original - and jnp.array_equal(self.jac, other.jac) - and self.offset == other.offset - and self.flux_ratio == other.flux_ratio - and self.gsparams == other.gsparams + and self._original == other._original + and jnp.array_equal(self._jac, other._jac) + and self._offset == other._params["offset"] + and self._flux_ratio == other._flux_ratio + and self._gsparams == other._gsparams and self._propagate_gsparams == other._propagate_gsparams ) @@ -139,12 +161,12 @@ def __hash__(self): return hash( ( "galsim.Transformation", - self.original, + self._original, ensure_hashable(self._jac.ravel()), - ensure_hashable(self.offset.x), - ensure_hashable(self.offset.y), - ensure_hashable(self.flux_ratio), - self.gsparams, + ensure_hashable(self._offset.x), + ensure_hashable(self._offset.y), + ensure_hashable(self._flux_ratio), + self._gsparams, self._propagate_gsparams, ) ) @@ -154,11 +176,11 @@ def __repr__(self): "galsim.Transformation(%r, jac=%r, offset=%r, flux_ratio=%r, gsparams=%r, " "propagate_gsparams=%r)" ) % ( - self.original, + self._original, ensure_hashable(self._jac.ravel()), - self.offset, - ensure_hashable(self.flux_ratio), - self.gsparams, + self._offset, + ensure_hashable(self._flux_ratio), + self._gsparams, self._propagate_gsparams, ) @@ -200,15 +222,15 @@ def _str_from_jac(cls, jac): return "" def __str__(self): - s = str(self.original) + s = str(self._original) s += self._str_from_jac(self._jac) - if self.offset.x != 0 or self.offset.y != 0: + if self._offset.x != 0 or self._offset.y != 0: s += ".shift(%s,%s)" % ( - ensure_hashable(self.offset.x), - ensure_hashable(self.offset.y), + ensure_hashable(self._offset.x), + ensure_hashable(self._offset.y), ) - if self.flux_ratio != 1.0: - s += " * %s" % ensure_hashable(self.flux_ratio) + if self._flux_ratio != 1.0: + s += " * %s" % ensure_hashable(self._flux_ratio) return s @property @@ -227,11 +249,11 @@ def _invjac(self): # than flux_ratio, which is really an amplitude scaling. @property def _amp_scaling(self): - return self._params["flux_ratio"] + return self._flux_ratio @property def _flux_scaling(self): - return jnp.abs(self._det) * self._params["flux_ratio"] + return jnp.abs(self._det) * self._flux_ratio def _fwd(self, x, y): res = jnp.dot(self._jac, jnp.array([x, y])) @@ -246,36 +268,25 @@ def _inv(self, x, y): return res[0], res[1] def _kfactor(self, kx, ky): - kx *= -1j * self.offset.x - ky *= -1j * self.offset.y + kx *= -1j * self._offset.x + ky *= -1j * self._offset.y kx += ky return self._flux_scaling * jnp.exp(kx) - def _major_minor(self): - if not hasattr(self, "_major"): - h1 = jnp.hypot( - self._jac[0, 0] + self._jac[1, 1], self._jac[0, 1] - self._jac[1, 0] - ) - h2 = jnp.hypot( - self._jac[0, 0] - self._jac[1, 1], self._jac[0, 1] + self._jac[1, 0] - ) - self._major = 0.5 * abs(h1 + h2) - self._minor = 0.5 * abs(h1 - h2) - @property def _maxk(self): - self._major_minor() - return self._original.maxk / self._minor + _, minor = compute_major_minor_from_jacobian(self._jac) + return self._original.maxk / minor @property def _stepk(self): - self._major_minor() - stepk = self._original.stepk / self._major + major, _ = compute_major_minor_from_jacobian(self._jac) + stepk = self._original.stepk / major # If we have a shift, we need to further modify stepk # stepk = Pi/R # R <- R + |shift| # stepk <- Pi/(Pi/stepk + |shift|) - dr = jnp.hypot(self.offset.x, self.offset.y) + dr = jnp.hypot(self._offset.x, self._offset.y) stepk = jnp.pi / (jnp.pi / stepk + dr) return stepk @@ -289,7 +300,7 @@ def _is_axisymmetric(self): self._original.is_axisymmetric and self._jac[0, 0] == self._jac[1, 1] and self._jac[0, 1] == -self._jac[1, 0] - and self.offset == PositionD(0.0, 0.0) + and self._offset == PositionD(0.0, 0.0) ) @property @@ -304,7 +315,7 @@ def _is_analytic_k(self): def _centroid(self): cen = self._original.centroid cen = PositionD(self._fwd(cen.x, cen.y)) - cen += self.offset + cen += self._offset return cen @property @@ -315,12 +326,16 @@ def _positive_flux(self): def _negative_flux(self): return self._flux_scaling * self._original.negative_flux + @property + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + @property def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self.offset + pos -= self._offset inv_pos = PositionD(self._inv(pos.x, pos.y)) return self._original._xValue(inv_pos) * self._amp_scaling @@ -331,12 +346,12 @@ def _kValue(self, kpos): def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): dx, dy = offset if jac is not None: - x1 = jac.dot(self.offset._array) + x1 = jac.dot(self._offset._array) dx += x1[0] dy += x1[1] else: - dx += self.offset.x - dy += self.offset.y + dx += self._offset.x + dy += self._offset.y flux_scaling *= self._flux_scaling jac = ( self._jac @@ -360,7 +375,7 @@ def _drawKImage(self, image, jac=None): image = self._original._drawKImage(image, jac1) _jac = jnp.eye(2) if jac is None else jac - image = apply_kImage_phases(self.offset, image, _jac) + image = apply_kImage_phases(self._offset, image, _jac) image = image * self._flux_scaling return image @@ -372,7 +387,7 @@ def tree_flatten(self): children = (self._original, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = { - "gsparams": self.gsparams, + "gsparams": self._gsparams, "propagate_gsparams": self._propagate_gsparams, } return (children, aux_data) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index b9401fb8..d6c74af7 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -1,3 +1,5 @@ +import functools + import galsim as _galsim import jax.numpy as jnp from jax._src.numpy.util import _wraps @@ -7,6 +9,29 @@ printoptions = _galsim.utilities.printoptions +@_wraps( + _galsim.utilities.lazy_property, + lax_description=( + "The LAX version of this decorator uses an `_workspace` attribute " + "attached to the object so that the cache can easily be discarded " + "for certain operations." + ), +) +def lazy_property(func): + attname = func.__name__ + "_cached" + + @property + @functools.wraps(func) + def _func(self): + if not hasattr(self, "_workspace"): + self._workspace = {} + if attname not in self._workspace: + self._workspace[attname] = func(self) + return self._workspace[attname] + + return _func + + @_wraps(_galsim.utilities.parse_pos_args) def parse_pos_args(args, kwargs, name1, name2, integer=False, others=[]): def canindex(arg): @@ -92,6 +117,17 @@ def g1g2_to_e1e2(g1, g2): return e1, e2 +@_wraps(_galsim.utilities.convert_interpolant) +def convert_interpolant(interpolant): + from jax_galsim.interpolant import Interpolant + + if isinstance(interpolant, Interpolant): + return interpolant + else: + # Will raise an appropriate exception if this is invalid. + return Interpolant.from_name(interpolant) + + @_wraps(_galsim.utilities.unweighted_moments) def unweighted_moments(image, origin=None): from jax_galsim.position import PositionD diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 3cebeeb0..51c3a832 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -126,6 +126,12 @@ def local(self, image_pos=None, world_pos=None, color=None): raise TypeError("image_pos must be a PositionD or PositionI argument") return self._local(image_pos, color) + @_wraps(_galsim.BaseWCS.jacobian) + def jacobian(self, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._toJacobian() + @_wraps(_galsim.BaseWCS.affine) def affine(self, image_pos=None, world_pos=None, color=None): if color is None: diff --git a/tests/GalSim b/tests/GalSim index b018d57f..1ed5131a 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit b018d57fba88eabbaacf40d34d3029a77e7071f2 +Subproject commit 1ed5131a54b4dbee384fee6b82b5e2e478ef0492 diff --git a/tests/conftest.py b/tests/conftest.py index 5db13ef4..17175c1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,20 @@ -import inspect -import os -import sys -from functools import lru_cache -from unittest.mock import patch - -import galsim -import pytest -import yaml - # Define the accuracy for running the tests from jax.config import config -import jax_galsim - config.update("jax_enable_x64", True) +import inspect # noqa: E402 +import os # noqa: E402 +import sys # noqa: E402 +from functools import lru_cache # noqa: E402 +from unittest.mock import patch # noqa: E402 + +import galsim # noqa: E402 +import pytest # noqa: E402 +import yaml # noqa: E402 + +import jax_galsim # noqa: E402 + # Identify the path to this current file test_directory = os.path.dirname(os.path.abspath(__file__)) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index e9aa68ef..a4715dbb 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -12,6 +12,7 @@ enabled_tests: - test_shear.py - test_shear_position.py - test_wcs.py + - test_interpolatedimage.py coord: - test_angle.py - test_angleunit.py @@ -23,7 +24,9 @@ enabled_tests: # correspond to features that are not implemented yet # in jax_galsim allowed_failures: - - "NotImplementedError" + - "Phot shooting not yet implemented in drawImage" + - "Real-space convolutions are not implemented" + - "Photon shooting convolutions are not implemented" - "module 'jax_galsim' has no attribute 'Airy'" - "module 'jax_galsim' has no attribute 'Kolmogorov'" - "module 'jax_galsim' has no attribute 'Sersic'" @@ -50,6 +53,10 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'nCr'" - "'Image' object has no attribute 'bin'" - "has no attribute 'shoot'" + - "module 'jax_galsim' has no attribute 'integ'" + - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" + - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" + - "'Image' object has no attribute 'FindAdaptiveMom'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - "ValueError not raised by greatCirclePoint" @@ -61,3 +68,7 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'horner2d'" - "'Image' object has no attribute 'FindAdaptiveMom'" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" + - "'Image' object has no attribute 'addNoise'" + - "Transform does not support callable arguments." + - "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." + - "jax_galsim does not support the galsim WCS class GSFitsWCS" diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index c14e87ca..1722b9b0 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -4502,59 +4502,57 @@ def test_wrap(): im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" ) - # FIXME: turn on when hermitian wrapping is implemented - if False: - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_array_almost_equal( - im2_wrap.array, - im_test[b2].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) + im2_wrap = im2.wrap(b2, hermitian="y") + # print('im_test = ',im_test[b2].array) + # print('im2_wrap = ',im2_wrap.array) + # print('diff = ',im2_wrap.array-im_test[b2].array) + np.testing.assert_array_almost_equal( + im2_wrap.array, + im_test[b2].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im2_wrap.array, + im2[b2].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" + ) - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_array_almost_equal( - im3_wrap.array, - im_test[b3].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") + im3_wrap = im3.wrap(b3, hermitian="x") + # print('im_test = ',im_test[b3].array) + # print('im3_wrap = ',im3_wrap.array) + # print('diff = ',im3_wrap.array-im_test[b3].array) + np.testing.assert_array_almost_equal( + im3_wrap.array, + im_test[b3].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im3_wrap.array, + im3[b3].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" + ) + + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + b2 = galsim.BoundsI(-K + 1, K, 0, L) + b3 = galsim.BoundsI(0, K, -L + 1, L) + assert_raises(TypeError, im.wrap, bounds=None) + assert_raises(ValueError, im3.wrap, b, hermitian="x") + assert_raises(ValueError, im3.wrap, b2, hermitian="x") + assert_raises(ValueError, im.wrap, b3, hermitian="x") + assert_raises(ValueError, im2.wrap, b, hermitian="y") + assert_raises(ValueError, im2.wrap, b3, hermitian="y") + assert_raises(ValueError, im.wrap, b2, hermitian="y") + assert_raises(ValueError, im.wrap, b, hermitian="invalid") + assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") @timer diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 35719624..6a21de43 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -37,6 +37,7 @@ def test_api_same(): "One of scale_radius, half_light_radius, or fwhm must be specified", "Arguments to Sum must be GSObjects", "'ArrayImpl' object has no attribute 'gsparams'", + "Supplied image must be an Image or file name", "Argument to Deconvolution must be a GSObject.", ] @@ -75,6 +76,19 @@ def _attempt_init(cls, kwargs): else: raise e + if cls in [jax_galsim.InterpolatedImage]: + try: + return cls( + jax_galsim.ImageD(jnp.arange(100).reshape((10, 10))), + scale=1.3, + **kwargs + ) + except Exception as e: + if any(estr in repr(e) for estr in OK_ERRORS): + pass + else: + raise e + return None @@ -100,6 +114,8 @@ def _kfun(x, prof): def _run_object_checks(obj, cls, kind): if kind == "pickle-eval-repr": + from numpy import array # noqa: F401 + # eval repr is identity mapping assert eval(repr(obj)) == obj @@ -351,6 +367,7 @@ def test_api_gsobject(kind): assert "Moffat" in cls_tested assert "Box" in cls_tested assert "Pixel" in cls_tested + assert "InterpolatedImage" in cls_tested @pytest.mark.parametrize( diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py new file mode 100644 index 00000000..14bfe100 --- /dev/null +++ b/tests/jax/test_image_wrapping.py @@ -0,0 +1,144 @@ +import jax +import numpy as np +from galsim_test_helpers import timer + +import jax_galsim as galsim +from jax_galsim.core.wrap_image import ( + contract_hermitian_x, + contract_hermitian_y, + expand_hermitian_x, + expand_hermitian_y, +) + + +@timer +def test_image_wrapping_expand_contract(): + # For complex images (in particular k-space images), we often want the image to be implicitly + # Hermitian, so we only need to keep around half of it. + M = 38 + N = 25 + im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian + im2 = galsim.ImageCD( + 2 * M + 1, N + 1, xmin=-M, ymin=0 + ) # Implicitly Hermitian across y axis + im3 = galsim.ImageCD( + M + 1, 2 * N + 1, xmin=0, ymin=-N + ) # Implicitly Hermitian across x axis + # print('im = ',im) + # print('im2 = ',im2) + # print('im3 = ',im3) + for i in range(-M, M + 1): + for j in range(-N, N + 1): + # An arbitrary, complicated Hermitian function. + val = ( + np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) + + ((2 + 3j * j) / (1.9 * N)) ** 3 + ) + # val = 2*(i-j)**2 + 3j*(i+j) + + im[i, j] = val + if j >= 0: + im2[i, j] = val + if i >= 0: + im3[i, j] = val + + # print("im = ",im.array) + + # Confirm that the image is Hermitian. + for i in range(-M, M + 1): + for j in range(-N, N + 1): + assert im(i, j) == im(-i, -j).conjugate() + + im_exp = expand_hermitian_x(im3.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_x(im.array) + np.testing.assert_allclose( + im_cnt, + im3.array, + err_msg="contract_hermitian_x() did not match expectation", + ) + + im_exp = expand_hermitian_y(im2.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_y(im.array) + np.testing.assert_allclose( + im_cnt, + im2.array, + err_msg="contract_hermitian_x() did not match expectation", + ) + + +@timer +def test_image_wrapping_autodiff(): + # For complex images (in particular k-space images), we often want the image to be implicitly + # Hermitian, so we only need to keep around half of it. + M = 38 + N = 25 + K = 8 + L = 5 + im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian + im2 = galsim.ImageCD( + 2 * M + 1, N + 1, xmin=-M, ymin=0 + ) # Implicitly Hermitian across y axis + im3 = galsim.ImageCD( + M + 1, 2 * N + 1, xmin=0, ymin=-N + ) # Implicitly Hermitian across x axis + # print('im = ',im) + # print('im2 = ',im2) + # print('im3 = ',im3) + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + im_test = galsim.ImageCD(b, init_value=0) + for i in range(-M, M + 1): + for j in range(-N, N + 1): + # An arbitrary, complicated Hermitian function. + val = ( + np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) + + ((2 + 3j * j) / (1.9 * N)) ** 3 + ) + # val = 2*(i-j)**2 + 3j*(i+j) + + im[i, j] = val + if j >= 0: + im2[i, j] = val + if i >= 0: + im3[i, j] = val + + ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin + jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin + im_test.addValue(ii, jj, val) + # print("im = ",im.array) + + # Confirm that the image is Hermitian. + for i in range(-M, M + 1): + for j in range(-N, N + 1): + assert im(i, j) == im(-i, -j).conjugate() + + @jax.jit + def _wrapit(im): + b3 = galsim.BoundsI(0, K, -L + 1, L) + return im.wrap(b3) + + # make sure this runs + p, grad = jax.vjp(_wrapit, im3) + grad = jax.jit(grad) + grad(p) + jax.jvp(_wrapit, (im3,), (im3 * 2,)) + + def _wrapit(im): + b3 = galsim.BoundsI(0, K, -L + 1, L) + return im.wrap(b3) + + # make sure this runs + p, grad = jax.vjp(_wrapit, im3) + grad(p) + jax.jvp(_wrapit, (im3,), (im3 * 2,)) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py new file mode 100644 index 00000000..c977eb46 --- /dev/null +++ b/tests/jax/test_interpolatedimage_utils.py @@ -0,0 +1,317 @@ +import hashlib + +import galsim as _galsim +import jax.numpy as jnp +import numpy as np +import pytest + +import jax_galsim +from jax_galsim.interpolant import ( # SincInterpolant, + Cubic, + Lanczos, + Linear, + Nearest, + Quintic, +) +from jax_galsim.interpolatedimage import ( + _draw_with_interpolant_kval, + _draw_with_interpolant_xval, +) + + +@pytest.mark.parametrize( + "interp", + [ + Nearest(), + Linear(), + # this is really slow right now and I am not sure why will fix later + # SincInterpolant(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=False), + Lanczos(5, conserve_dc=True), + ], +) +def test_interpolatedimage_utils_draw_with_interpolant_xval(interp): + zp = jnp.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + ] + ) + for xmin in [-3, 0, 2]: + for ymin in [-5, 0, 1]: + for x in range(4): + for y in range(4): + np.testing.assert_allclose( + _draw_with_interpolant_xval( + jnp.array([x + xmin], dtype=float), + jnp.array([y + ymin], dtype=float), + xmin, + ymin, + zp, + interp, + ), + zp[y, x], + ) + + +@pytest.mark.parametrize( + "interp", + [ + Nearest(), + Linear(), + # this is really slow right now and I am not sure why will fix later + # SincInterpolant(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=False), + Lanczos(5, conserve_dc=True), + ], +) +def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): + zp = jnp.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + ] + ) + kim = jax_galsim.Image(zp, scale=1.0).calculate_fft() + nherm = kim.array.shape[0] + minherm = kim.bounds.ymin + kimherm = jax_galsim.Image( + jnp.zeros((kim.array.shape[0], kim.array.shape[0]), dtype=complex), + xmin=minherm, + ymin=minherm, + ) + for y in range(kimherm.bounds.ymin, kimherm.bounds.ymax + 1): + for x in range(kimherm.bounds.xmin, kimherm.bounds.xmax + 1): + if x >= 0: + kimherm[x, y] = kim[x, y] + else: + if y == minherm: + kimherm[x, y] = kim[-x, y].conj() + else: + kimherm[x, y] = kim[-x, -y].conj() + for x in range(nherm): + for y in range(nherm): + np.testing.assert_allclose( + _draw_with_interpolant_kval( + jnp.array([x + minherm], dtype=float), + jnp.array([y + minherm], dtype=float), + minherm, + minherm, + kim.array, + interp, + ), + kimherm(x + minherm, y + minherm), + ) + + +def test_interpolatedimage_utils_stepk_maxk(): + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + + ref_array = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=53, + ny=53, + scale=scale, + ) + .array.astype(np.float64) + ) + + gimage_in = _galsim.Image(ref_array) + jgimage_in = jax_galsim.Image(ref_array) + gii = _galsim.InterpolatedImage(gimage_in, scale=scale) + jgii = jax_galsim.InterpolatedImage(jgimage_in, scale=scale) + + rtol = 1e-1 + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=rtol, atol=0) + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=rtol, atol=0) + + +@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) +@pytest.mark.parametrize("normalization", ["sb", "flux"]) +@pytest.mark.parametrize("use_true_center", [True, False]) +@pytest.mark.parametrize( + "wcs", + [ + _galsim.PixelScale(0.2), + _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), + _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), + ], +) +@pytest.mark.parametrize( + "offset_x", + [ + -4.35, + -0.45, + 0.0, + 0.67, + 3.78, + ], +) +@pytest.mark.parametrize( + "offset_y", + [ + -2.12, + -0.33, + 0.0, + 0.12, + 1.45, + ], +) +@pytest.mark.parametrize( + "ref_array", + [ + _galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array, + _galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array, + ], +) +@pytest.mark.parametrize("method", ["kValue", "xValue"]) +def test_interpolatedimage_utils_comp_to_galsim( + method, + ref_array, + offset_x, + offset_y, + wcs, + use_true_center, + normalization, + x_interp, +): + seed = max( + abs( + int( + hashlib.sha1( + f"{method}{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( + "utf-8" + ) + ).hexdigest(), + 16, + ) + ) + % (10**7), + 1, + ) + + rng = np.random.RandomState(seed=seed) + if rng.uniform() < 0.75: + pytest.skip( + "Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time." + ) + + gimage_in = _galsim.Image(ref_array, scale=0.2) + jgimage_in = jax_galsim.Image(ref_array, scale=0.2) + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + ) + + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) + kxvals = [ + (0, 0), + (-5, -5), + (-10, 10), + (1, 1), + (1, -2), + (-1, 0), + (0, -1), + (-1, -1), + (-2, 2), + (-5, 0), + (3, -4), + (-3, 4), + ] + for x, y in kxvals: + if method == "kValue": + dk = jgii._original._kim.scale * rng.uniform(low=0.5, high=1.5) + np.testing.assert_allclose( + gii.kValue(x * dk, y * dk), + jgii.kValue(x * dk, y * dk), + err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", + ) + else: + dx = jnp.sqrt(jgii._original._wcs.pixelArea()) * rng.uniform( + low=0.5, high=1.5 + ) + np.testing.assert_allclose( + gii.xValue(x * dx, y * dx), + jgii.xValue(x * dx, y * dx), + err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", + ) + + +def _compute_fft_with_numpy_jax_galsim(im): + import numpy as np + + from jax_galsim import BoundsI, Image + + No2 = max(-im.bounds.xmin, im.bounds.xmax + 1, -im.bounds.ymin, im.bounds.ymax + 1) + + full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) + if im.bounds == full_bounds: + # Then the image is already in the shape we need. + ximage = im + else: + # Then we pad out with zeros + ximage = Image(full_bounds, dtype=im.dtype, init_value=0) + ximage[im.bounds] = im[im.bounds] + + dx = im.scale + # dk = 2pi / (N dk) + dk = np.pi / (No2 * dx) + + out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) + out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) + out *= dx * dx + out.setOrigin(0, -No2) + return out + + +@pytest.mark.parametrize("n", [5, 4]) +def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): + rng = np.random.RandomState(42) + arr = rng.normal(size=(n, n)) + im = jax_galsim.Image(arr, scale=1) + kim = im.calculate_fft() + xkim = kim.calculate_inverse_fft() + np.testing.assert_allclose(im.array, xkim[im.bounds].array) + + np_kim = _compute_fft_with_numpy_jax_galsim(im) + np.testing.assert_allclose(kim.array, np_kim.array) + + rng = np.random.RandomState(42) + arr = rng.normal(size=(n, n)) + gim = jax_galsim.Image(arr, scale=1) + gkim = gim.calculate_fft() + gxkim = gkim.calculate_inverse_fft() + np.testing.assert_allclose(gim.array, gxkim[gim.bounds].array) + np.testing.assert_allclose(gim.array, im.array) + np.testing.assert_allclose(gkim.array, kim.array) + np.testing.assert_allclose(gxkim.array, xkim.array) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py new file mode 100644 index 00000000..0b21a072 --- /dev/null +++ b/tests/jax/test_metacal.py @@ -0,0 +1,407 @@ +import time +from functools import partial + +import galsim as _galsim +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import jax_galsim + + +def _metacal_galsim( + im, psf, nse_im, scale, target_fwhm, g1, iim_kwargs, ipsf_kwargs, inse_kwargs, nk +): + iim = _galsim.InterpolatedImage( + _galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + **iim_kwargs, + ) + ipsf = _galsim.InterpolatedImage( + _galsim.ImageD(psf), + scale=scale, + x_interpolant="lanczos15", + **ipsf_kwargs, + ) + inse = _galsim.InterpolatedImage( + _galsim.ImageD(np.rot90(nse_im, 1)), + scale=scale, + x_interpolant="lanczos15", + **inse_kwargs, + ) + + ppsf_iim = _galsim.Convolve(iim, _galsim.Deconvolve(ipsf)) + ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) + + prof = _galsim.Convolve( + ppsf_iim, + _galsim.Gaussian(fwhm=target_fwhm), + gsparams=_galsim.GSParams(minimum_fft_size=nk), + ) + + sim = prof.drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ).array.astype(np.float64) + + ppsf_inse = _galsim.Convolve(inse, _galsim.Deconvolve(ipsf)) + ppsf_inse = ppsf_inse.shear(g1=g1, g2=0.0) + snse = ( + _galsim.Convolve( + ppsf_inse, + _galsim.Gaussian(fwhm=target_fwhm), + gsparams=_galsim.GSParams(minimum_fft_size=nk), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + return sim + np.rot90(snse, 3) + + +@partial(jax.jit, static_argnames=("nk",)) +def _metacal_jax_galsim_render(im, psf, g1, target_psf, scale, nk): + prepsf_im = jax_galsim.Convolve(im, jax_galsim.Deconvolve(psf)) + prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) + + prof = jax_galsim.Convolve( + prepsf_im, + target_psf, + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + + return prof.drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ).array.astype(np.float64) + + +def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), scale=scale, x_interpolant="lanczos15" + ) + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + inse = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(jnp.rot90(nse_im, 1)), scale=scale, x_interpolant="lanczos15" + ) + + target_psf = jax_galsim.Gaussian(fwhm=target_fwhm) + + sim = _metacal_jax_galsim_render(iim, ipsf, g1, target_psf, scale, nk) + + snse = _metacal_jax_galsim_render(inse, ipsf, g1, target_psf, scale, nk) + + return sim + jnp.rot90(snse, 3) + + +@pytest.mark.parametrize("nse", [1e-3, 1e-10]) +def test_metacal_comp_to_galsim(nse): + seed = 42 + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + g1 = 0.01 + target_fwhm = 1.0 + + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + psf = ( + _galsim.Gaussian(fwhm=fwhm) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + nse_im = rng.normal(size=im.shape, scale=nse) + im += rng.normal(size=im.shape, scale=nse) + + # jax galsim and galsim set stepk and maxk differently due to slight + # algorithmic differences. We force them to be the same here for this + # test so it passes. + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=128), + ) + iim_kwargs = { + "_force_stepk": iim.stepk.item(), + "_force_maxk": iim.maxk.item(), + } + inse = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(jnp.rot90(nse_im, 1)), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=128), + ) + inse_kwargs = { + "_force_stepk": inse.stepk.item(), + "_force_maxk": inse.maxk.item(), + } + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + ipsf_kwargs = { + "_force_stepk": ipsf.stepk.item(), + "_force_maxk": ipsf.maxk.item(), + } + + gt0 = time.time() + gres = _metacal_galsim( + im.copy(), + psf.copy(), + nse_im.copy(), + scale, + target_fwhm, + g1, + iim_kwargs, + ipsf_kwargs, + inse_kwargs, + 128, + ) + gt0 = time.time() - gt0 + + print("galsim time: ", gt0 * 1e3, " [ms]") + + for i in range(2): + if i == 0: + msg = "jit warmup" + elif i == 1: + msg = "jit" + jgt0 = time.time() + jgres = _metacal_jax_galsim( + im.copy(), + psf.copy(), + nse_im.copy(), + scale, + target_fwhm, + g1, + 128, + ) + jgres = jax.block_until_ready(jgres) + jgt0 = time.time() - jgt0 + print("jax-galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") + + gim = gres + jgim = jgres + + atol = 1e-8 + if not np.allclose(gim, jgim, rtol=0, atol=atol): + import proplot as pplt + + fig, axs = pplt.subplots(ncols=3, nrows=1) + + axs[0].imshow(np.arcsinh(gres / nse)) + axs[1].imshow(np.arcsinh(jgres / nse)) + m = axs[2].imshow(jgres - gres) + axs[2].colorbar(m, loc="r") + + fig.show() + + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) + + +def test_metacal_vmap(): + start_seed = 42 + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nse = 1e-3 + g1 = 0.01 + target_fwhm = 1.0 + + ims = [] + nse_ims = [] + psfs = [] + for _seed in range(10): + seed = _seed + start_seed + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + psf = ( + _galsim.Gaussian(fwhm=fwhm) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + nse_im = rng.normal(size=im.shape) * nse + im += rng.normal(size=im.shape) * nse + + ims.append(im) + psfs.append(psf) + nse_ims.append(nse_im) + + ims = np.stack(ims) + psfs = np.stack(psfs) + nse_ims = np.stack(nse_ims) + + gt0 = time.time() + for im, psf, nse_im in zip(ims, psfs, nse_ims): + _metacal_galsim( + im.copy(), + psf.copy(), + nse_im.copy(), + scale, + target_fwhm, + g1, + {}, + {}, + {}, + 128, + ) + gt0 = time.time() - gt0 + print("Galsim time: ", gt0 * 1e3, " [ms]") + + vmap_mcal = jax.vmap( + _metacal_jax_galsim, + in_axes=(0, 0, 0, None, None, None, None), + ) + + for i in range(2): + if i == 0: + msg = "jit warmup" + elif i == 1: + msg = "jit" + + jgt0 = time.time() + vmap_mcal( + ims, + psfs, + nse_ims, + scale, + target_fwhm, + g1, + 128, + ) + jgt0 = time.time() - jgt0 + print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") + + +@pytest.mark.parametrize( + "draw_method", + [ + "no_pixel", + "auto", + ], +) +@pytest.mark.parametrize( + "nse", + [ + 4e-3, + 1e-3, + 1e-10, + ], +) +def test_metacal_iimage_with_noise(nse, draw_method): + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nk = 128 + seed = 42 + + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + im += rng.normal(size=im.shape) * nse + + jgiim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk), + ) + + giim = _galsim.InterpolatedImage( + _galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=_galsim.GSParams(minimum_fft_size=nk), + _force_stepk=jgiim.stepk.item(), + _force_maxk=jgiim.maxk.item(), + ) + + def _plot_real(gim, jgim): + import proplot as pplt + + fig, axs = pplt.subplots(ncols=3, nrows=1) + + axs[0].imshow(gim) + axs[1].imshow(jgim) + m = axs[0, 2].imshow((jgim - gim)) + axs[2].colorbar(m, loc="r") + + fig.show() + + atol = 1e-8 + np.testing.assert_allclose(giim.maxk, jgiim.maxk) + np.testing.assert_allclose(giim.maxk, jgiim.maxk) + + if draw_method == "no_pixel": + gim = giim.drawImage(nx=33, ny=33, scale=scale, method="no_pixel").array + jgim = jgiim.drawImage(nx=33, ny=33, scale=scale, method="no_pixel").array + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + _plot_real(gim, jgim) + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) + elif draw_method == "auto": + gim = giim.drawImage(nx=33, ny=33, scale=scale).array + jgim = jgiim.drawImage(nx=33, ny=33, scale=scale).array + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + _plot_real(gim, jgim) + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) diff --git a/tests/jax/test_temporary_tests.py b/tests/jax/test_temporary_tests.py index 79c24f14..252b4803 100644 --- a/tests/jax/test_temporary_tests.py +++ b/tests/jax/test_temporary_tests.py @@ -82,6 +82,9 @@ def func(galsim): jax_galsim.Gaussian(fwhm=1.0), jax_galsim.Pixel(scale=1.0), jax_galsim.Exponential(scale_radius=1.0), + jax_galsim.Exponential(half_light_radius=1.0), + jax_galsim.Moffat(fwhm=1.0, beta=3), + jax_galsim.Moffat(scale_radius=1.0, beta=3), jax_galsim.Shear(g1=0.1, g2=0.2), jax_galsim.PositionD(x=0.1, y=0.2), jax_galsim.BoundsI(xmin=0, xmax=1, ymin=0, ymax=1),