From c622679c5045faed7131db97f970d8a5bb048e77 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 17 Sep 2023 08:57:29 -0500 Subject: [PATCH 01/67] WIP start adding interpolated images --- jax_galsim/interpolatedimage.py | 543 ++++++++++++++++++++++++++++++++ jax_galsim/utilities.py | 10 + 2 files changed, 553 insertions(+) create mode 100644 jax_galsim/interpolatedimage.py diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py new file mode 100644 index 00000000..bb6bd02b --- /dev/null +++ b/jax_galsim/interpolatedimage.py @@ -0,0 +1,543 @@ +import textwrap +import jax.numpy as jnp +from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class +import math + +import galsim as _galsim +from galsim.utilities import doc_inherit +from galsim.errors import GalSimRangeError, GalSimValueError, GalSimUndefinedBoundsError +from galsim.errors import GalSimIncompatibleValuesError + +from jax_galsim.gsobject import GSObject +from jax_galsim.gsparams import GSParams +from jax_galsim.image import Image +from jax_galsim.position import PositionD +from jax_galsim.interpolant import Quintic +from jax_galsim.utilities import convert_interpolant +from jax_galsim.bounds import BoundsI + + +@_wraps( + _galsim.InterpolatedImage, + lax_description=textwrap.dedent( + """The JAX equivalent of galsim.InterpolatedImage does not support + + - noise padding + - depixelize + - reading images from FITS files + + """ + ), +) +@register_pytree_node_class +class InterpolatedImage(GSObject): + _req_params = {'image': str} + _opt_params = { + 'x_interpolant': str, + 'k_interpolant': str, + 'normalization': str, + 'scale': float, + 'flux': float, + 'pad_factor': float, + 'noise_pad_size': float, + 'noise_pad': str, + 'pad_image': str, + 'calculate_stepk': bool, + 'calculate_maxk': bool, + 'use_true_center': bool, + 'depixelize': bool, + 'offset': PositionD, + 'hdu': int + } + _takes_rng = True + _cache_noise_pad = {} + + _has_hard_edges = False + _is_axisymmetric = False + _is_analytic_x = True + _is_analytic_k = True + + def __init__(self, image, x_interpolant=None, k_interpolant=None, normalization='flux', + scale=None, wcs=None, flux=None, pad_factor=4., noise_pad_size=0, noise_pad=0., + rng=None, pad_image=None, calculate_stepk=True, calculate_maxk=True, + use_cache=True, use_true_center=True, depixelize=False, offset=None, + gsparams=None, _force_stepk=0., _force_maxk=0., hdu=None): + + from .wcs import BaseWCS, PixelScale + # FIXME: no BaseDeviate in jax_galsim + # from .random import BaseDeviate + + # If the "image" is not actually an image, try to read the image as a file. + if isinstance(image, str): + # FIXME: no FITSIO in jax_galsim + # image = fits.read(image, hdu=hdu) + raise NotImplementedError( + "Reading InterpolatedImages from FITS files is not implemented in jax_galsim." + ) + elif not isinstance(image, Image): + raise TypeError("Supplied image must be an Image or file name") + + # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor + if not image.bounds.isDefined(): + raise GalSimUndefinedBoundsError("Supplied image does not have bounds defined.") + + # check what normalization was specified for the image: is it an image of surface + # brightness, or flux? + if normalization.lower() not in ("flux", "f", "surface brightness", "sb"): + raise GalSimValueError("Invalid normalization requested.", normalization, + ('flux', 'f', 'surface brightness', 'sb')) + + # Set up the interpolants if none was provided by user, or check that the user-provided ones + # are of a valid type + self._gsparams = GSParams.check(gsparams) + if x_interpolant is None: + self._x_interpolant = Quintic(gsparams=self._gsparams) + else: + self._x_interpolant = convert_interpolant(x_interpolant).withGSParams(self._gsparams) + if k_interpolant is None: + self._k_interpolant = Quintic(gsparams=self._gsparams) + else: + self._k_interpolant = convert_interpolant(k_interpolant).withGSParams(self._gsparams) + + # Store the image as an attribute and make sure we don't change the original image + # in anything we do here. (e.g. set scale, etc.) + if depixelize: + # FIXME: no depixelize in jax_galsim + # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) + raise NotImplementedError("InterpolatedImages do not support 'depixelize' in jax_galsim.") + else: + self._image = image.view(dtype=jnp.float64, contiguous=True) + self._image.setCenter(0, 0) + + # Set the wcs if necessary + if scale is not None: + if wcs is not None: + raise GalSimIncompatibleValuesError( + "Cannot provide both scale and wcs to InterpolatedImage", scale=scale, wcs=wcs) + self._image.wcs = PixelScale(scale) + elif wcs is not None: + if not isinstance(wcs, BaseWCS): + raise TypeError("wcs parameter is not a galsim.BaseWCS instance") + self._image.wcs = wcs + elif self._image.wcs is None: + raise GalSimIncompatibleValuesError( + "No information given with Image or keywords about pixel scale!", + scale=scale, wcs=wcs, image=image) + + # Figure out the offset to apply based on the original image (not the padded one). + # We will apply this below in _sbp. + offset = self._parse_offset(offset) + self._offset = self._adjust_offset(self._image.bounds, offset, None, use_true_center) + + im_cen = image.true_center if use_true_center else image.center + self._wcs = self._image.wcs.local(image_pos=im_cen) + + # Build the fully padded real-space image according to the various pad options. + self._buildRealImage(pad_factor, pad_image, noise_pad_size, noise_pad, rng, use_cache) + self._image_flux = jnp.sum(self._image.array, dtype=jnp.float64) + + # I think the only things that will mess up if flux == 0 are the + # calculateStepK and calculateMaxK functions, and rescaling the flux to some value. + if (calculate_stepk or calculate_maxk or flux is not None) and self._image_flux == 0.: + raise GalSimValueError("This input image has zero total flux. It does not define a " + "valid surface brightness profile.", image) + + # Process the different options for flux, stepk, maxk + self._flux = self._getFlux(flux, normalization) + self._calculate_stepk = calculate_stepk + self._calculate_maxk = calculate_maxk + self._stepk = self._getStepK(calculate_stepk, _force_stepk) + self._maxk = self._getMaxK(calculate_maxk, _force_maxk) + + @doc_inherit + def withGSParams(self, gsparams=None, **kwargs): + if gsparams == self.gsparams: + return self + # Checking gsparams + gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + # Flattening the representation to instantiate a clean new object + children, aux_data = self.tree_flatten() + aux_data["gsparams"] = gsparams + ret = self.tree_unflatten(aux_data, children) + + ret._x_interpolant = self._x_interpolant.withGSParams(ret._gsparams, **kwargs) + ret._k_interpolant = self._k_interpolant.withGSParams(ret._gsparams, **kwargs) + if ret._gsparams.folding_threshold != self._gsparams.folding_threshold: + ret._stepk = ret._getStepK(self._calculate_stepk, 0.) + if ret._gsparams.maxk_threshold != self._gsparams.maxk_threshold: + ret._maxk = ret._getMaxK(self._calculate_maxk, 0.) + return ret + + def tree_flatten(self): + """This function flattens the InterpolatedImage into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + children = (self.params,) + # Define auxiliary static data that doesn’t need to be traced + aux_data = {"gsparams": self.gsparams} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(**(children[0]), **aux_data) + + def __eq__(self, other): + return (self is other or + (isinstance(other, InterpolatedImage) and + self._xim == other._xim and + self.x_interpolant == other.x_interpolant and + self.k_interpolant == other.k_interpolant and + self.flux == other.flux and + self._offset == other._offset and + self.gsparams == other.gsparams and + self._stepk == other._stepk and + self._maxk == other._maxk)) + + # TODO: do this in JAX OFC + # @lazy_property + # def _sbp(self): + # min_scale = self._wcs._minScale() + # max_scale = self._wcs._maxScale() + # self._sbii = _galsim.SBInterpolatedImage( + # self._xim._image, self._image.bounds._b, self._pad_image.bounds._b, + # self._x_interpolant._i, self._k_interpolant._i, + # self._stepk*min_scale, + # self._maxk*max_scale, + # self.gsparams._gsp) + + # self._sbp = self._sbii # Temporary. Will overwrite this with the return value. + + # # Apply the offset + # prof = self + # if self._offset != _PositionD(0,0): + # # Opposite direction of what drawImage does. + # prof = prof._shift(-self._offset.x, -self._offset.y) + + # # If the user specified a flux, then set to that flux value. + # if self._flux != self._image_flux: + # flux_ratio = self._flux / self._image_flux + # else: + # flux_ratio = 1. + + # # Bring the profile from image coordinates into world coordinates + # # Note: offset needs to happen first before the transformation, so can't bundle it here. + # prof = self._wcs._profileToWorld(prof, flux_ratio, _PositionD(0,0)) + + # return prof._sbp + + def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, use_cache): + # Check that given pad_image is valid: + if pad_image is not None: + if isinstance(pad_image, str): + # FIXME: no FITSIO in jax_galsim + # pad_image = fits.read(pad_image).view(dtype=np.float64) + raise NotImplementedError( + "Reading padding images for InterpolatedImages from FITS files " + "is not implemented in jax_galsim." + ) + elif isinstance(pad_image, Image): + pad_image = pad_image.view(dtype=jnp.float64, contiguous=True) + else: + raise TypeError("Supplied pad_image must be an Image.", pad_image) + + if pad_factor <= 0.: + raise GalSimRangeError("Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.) + + # Convert noise_pad_size from arcsec to pixels according to the local wcs. + # Use the minimum scale, since we want to make sure noise_pad_size is + # as large as we need in any direction. + if noise_pad_size: + # FIXME: no BaseDeviate in jax_galsim so no noise padding + # if noise_pad_size < 0: + # raise GalSimValueError("noise_pad_size may not be negative", noise_pad_size) + # if not noise_pad: + # raise GalSimIncompatibleValuesError( + # "Must provide noise_pad if noise_pad_size > 0", + # noise_pad=noise_pad, noise_pad_size=noise_pad_size) + # noise_pad_size = int(math.ceil(noise_pad_size / self._wcs._minScale())) + # noise_pad_size = Image.good_fft_size(noise_pad_size) + raise NotImplementedError("InterpolatedImages do not support noise padding in jax_galsim.") + else: + if noise_pad: + # FIXME: no BaseDeviate in jax_galsim so no noise padding + # raise GalSimIncompatibleValuesError( + # "Must provide noise_pad_size if noise_pad != 0", + # noise_pad=noise_pad, noise_pad_size=noise_pad_size) + raise NotImplementedError("InterpolatedImages do not support noise padding in jax_galsim.") + + # The size of the final padded image is the largest of the various size specifications + pad_size = max(self._image.array.shape) + if pad_factor > 1.: + pad_size = int(math.ceil(pad_factor * pad_size)) + if noise_pad_size: + pad_size = max(pad_size, noise_pad_size) + if pad_image: + pad_image.setCenter(0, 0) + pad_size = max(pad_size, *pad_image.array.shape) + # And round up to a good fft size + pad_size = Image.good_fft_size(pad_size) + + self._xim = Image(pad_size, pad_size, dtype=jnp.float64, wcs=self._wcs) + self._xim.setCenter(0, 0) + + # If requested, fill (some of) this image with noise padding. + nz_bounds = self._image.bounds + # FIXME: no BaseDeviate in jax_galsim so no noise padding + # if noise_pad: + # # This is a bit involved, so pass this off to another helper function. + # b = self._buildNoisePadImage(noise_pad_size, noise_pad, rng, use_cache) + # nz_bounds += b + + # The the user gives us a pad image to use, fill the relevant portion with that. + if pad_image: + # assert self._xim.bounds.includes(pad_image.bounds) + self._xim[pad_image.bounds] = pad_image + nz_bounds += pad_image.bounds + + # Now place the given image in the center of the padding image: + # assert self._xim.bounds.includes(self._image.bounds) + self._xim[self._image.bounds] = self._image + self._xim.wcs = self._wcs + + # And update the _image to be that portion of the full real image rather than the + # input image. + self._image = self._xim[self._image.bounds] + + # These next two allow for easy pickling/repring. We don't need to serialize all the + # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. + self._pad_image = self._xim[nz_bounds] + # self._pad_factor = (max(self._xim.array.shape)-1.e-6) / max(self._image.array.shape) + self._pad_factor = pad_factor + + # FIXME: no BaseDeviate in jax_galsim so no noise padding + # def _buildNoisePadImage(self, noise_pad_size, noise_pad, rng, use_cache): + # """A helper function that builds the ``pad_image`` from the given ``noise_pad`` + # specification. + # """ + # from .random import BaseDeviate + # from .noise import GaussianNoise + # from .correlatednoise import BaseCorrelatedNoise, CorrelatedNoise + + # # Make sure we make rng a BaseDeviate if rng is None + # rng1 = BaseDeviate(rng) + + # # Figure out what kind of noise to apply to the image + # try: + # noise_pad = float(noise_pad) + # except (TypeError, ValueError): + # if isinstance(noise_pad, BaseCorrelatedNoise): + # noise = noise_pad.copy(rng=rng1) + # elif isinstance(noise_pad, Image): + # noise = CorrelatedNoise(noise_pad, rng1) + # elif use_cache and noise_pad in InterpolatedImage._cache_noise_pad: + # noise = InterpolatedImage._cache_noise_pad[noise_pad] + # if rng: + # # Make sure that we are using a specified RNG by resetting that in this cached + # # CorrelatedNoise instance, otherwise preserve the cached RNG + # noise = noise.copy(rng=rng1) + # elif isinstance(noise_pad, basestring): + # noise = CorrelatedNoise(fits.read(noise_pad), rng1) + # if use_cache: + # InterpolatedImage._cache_noise_pad[noise_pad] = noise + # else: + # raise GalSimValueError( + # "Input noise_pad must be a float/int, a CorrelatedNoise, Image, or filename " + # "containing an image to use to make a CorrelatedNoise.", noise_pad) + + # else: + # if noise_pad < 0.: + # raise GalSimRangeError("Noise variance may not be negative.", noise_pad, 0.) + # noise = GaussianNoise(rng1, sigma = np.sqrt(noise_pad)) + + # # Find the portion of xim to fill with noise. + # # It's allowed for the noise padding to not cover the whole pad image + # half_size = noise_pad_size // 2 + # b = _BoundsI(-half_size, -half_size + noise_pad_size-1, + # -half_size, -half_size + noise_pad_size-1) + # #assert self._xim.bounds.includes(b) + # noise_image = self._xim[b] + # # Add the noise + # noise_image.addNoise(noise) + # return b + + def _getFlux(self, flux, normalization): + # If the user specified a surface brightness normalization for the input Image, then + # need to rescale flux by the pixel area to get proper normalization. + if flux is None: + flux = self._image_flux + if normalization.lower() in ('surface brightness', 'sb'): + flux *= self._wcs.pixelArea() + return flux + + def _getStepK(self, calculate_stepk, _force_stepk): + # GalSim cannot automatically know what stepK and maxK are appropriate for the + # input image. So it is usually worth it to do a manual calculation (below). + # + # However, there is also a hidden option to force it to use specific values of stepK and + # maxK (caveat user!). The values of _force_stepk and _force_maxk should be provided in + # terms of physical scale, e.g., for images that have a scale length of 0.1 arcsec, the + # stepK and maxK should be provided in units of 1/arcsec. Then we convert to the 1/pixel + # units required by the C++ layer below. Also note that profile recentering for even-sized + # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly + # below what is provided here, while maxK is preserved. + if _force_stepk > 0.: + return _force_stepk + elif calculate_stepk: + if calculate_stepk is True: + im = self._image + else: + # If not a bool, then value is max_stepk + R = (jnp.ceil(jnp.pi / calculate_stepk)).astype(int) + b = BoundsI(-R, R, -R, R) + b = self._image.bounds & b + im = self._image[b] + thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux + # this line appears buggy in galsim - I expect they meant to use im + R = _galsim.CalculateSizeContainingFlux(im._image, thresh) + else: + R = jnp.max(self._image.array.shape) / 2. - 0.5 + return self._getSimpleStepK(R) + + def _getSimpleStepK(self, R): + min_scale = self._wcs._minScale() + # Add xInterp range in quadrature just like convolution: + R2 = self._x_interpolant.xrange + R = jnp.hypot(R, R2) + stepk = jnp.pi / (R * min_scale) + return stepk + + def _getMaxK(self, calculate_maxk, _force_maxk): + max_scale = self._wcs._maxScale() + if _force_maxk > 0.: + return _force_maxk + elif calculate_maxk: + self._maxk = 0. + self._sbp + if calculate_maxk is True: + self._sbii.calculateMaxK(0.) + else: + # If not a bool, then value is max_maxk + self._sbii.calculateMaxK(float(calculate_maxk)) + self.__dict__.pop('_sbp') # Need to remake it. + return self._sbii.maxK() / max_scale + else: + return self._x_interpolant.krange / max_scale + + def __hash__(self): + # Definitely want to cache this, since the size of the image could be large. + if not hasattr(self, '_hash'): + self._hash = hash(("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant)) + self._hash ^= hash((self.flux, self._stepk, self._maxk, self._pad_factor)) + self._hash ^= hash((self._xim.bounds, self._image.bounds, self._pad_image.bounds)) + # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0 + # (which is also common). I guess because they are only different in 2 bits. + # This mucking of the numbers seems to help make the hash more reliably different for + # these two cases. Note: "sometiems" because of this: + # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions + self._hash ^= hash((self._offset.x * 1.234, self._offset.y * 0.23424)) + self._hash ^= hash(self._gsparams) + self._hash ^= hash(self._xim.wcs) + # Just hash the diagonal. Much faster, and usually is unique enough. + # (Let python handle collisions as needed if multiple similar IIs are used as keys.) + self._hash ^= hash(tuple(jnp.diag(self._pad_image.array))) + return self._hash + + def __repr__(self): + s = 'galsim.InterpolatedImage(%r, %r, %r' % ( + self._image, self.x_interpolant, self.k_interpolant + ) + # Most things we keep even if not required, but the pad_image is large, so skip it + # if it's really just the same as the main image. + if self._pad_image.bounds != self._image.bounds: + s += ', pad_image=%r' % (self._pad_image) + s += ', pad_factor=%f, flux=%r, offset=%r' % (self._pad_factor, self.flux, self._offset) + s += ', use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)' % ( + self.gsparams, self._stepk, self._maxk + ) + return s + + def __str__(self): + return 'galsim.InterpolatedImage(image=%s, flux=%s)' % (self.image, self.flux) + + def __getstate__(self): + d = self.__dict__.copy() + # TODO - probably remove these pops for things we don't have + d.pop('_sbii', None) + d.pop('_sbp', None) + # Only pickle _pad_image. Not _xim or _image + d['_xim_bounds'] = self._xim.bounds + d['_image_bounds'] = self._image.bounds + d.pop('_xim', None) + d.pop('_image', None) + return d + + def __setstate__(self, d): + xim_bounds = d.pop('_xim_bounds') + image_bounds = d.pop('_image_bounds') + self.__dict__ = d + if self._pad_image.bounds == xim_bounds: + self._xim = self._pad_image + else: + self._xim = Image(xim_bounds, wcs=self._wcs, dtype=jnp.float64) + self._xim[self._pad_image.bounds] = self._pad_image + self._image = self._xim[image_bounds] + + @property + def x_interpolant(self): + """The real-space `Interpolant` for this profile. + """ + return self._x_interpolant + + @property + def k_interpolant(self): + """The Fourier-space `Interpolant` for this profile. + """ + return self._k_interpolant + + @property + def image(self): + """The underlying `Image` being interpolated. + """ + return self._image + + @property + def _centroid(self): + return PositionD(self._sbp.centroid()) + + @property + def _positive_flux(self): + return self._sbp.getPositiveFlux() + + @property + def _negative_flux(self): + return self._sbp.getNegativeFlux() + + # @lazy_property + def _flux_per_photon(self): + # FIXME: jax_galsim does not photon shoot + # return self._calculate_flux_per_photon() + raise NotImplementedError("Photon shooting not implemented.") + + @property + def _max_sb(self): + return self._sbp.maxSB() + + def _xValue(self, pos): + return self._sbp.xValue(pos._p) + + def _kValue(self, kpos): + return self._sbp.kValue(kpos._p) + + def _shoot(self, photons, rng): + raise NotImplementedError("Photon shooting not implemented.") + + def _drawReal(self, image, jac=None, offset=(0., 0.), flux_scaling=1.): + dx, dy = offset + _jac = 0 if jac is None else jac.__array_interface__['data'][0] + self._sbp.draw(image._image, image.scale, _jac, dx, dy, flux_scaling) + + def _drawKImage(self, image, jac=None): + _jac = 0 if jac is None else jac.__array_interface__['data'][0] + self._sbp.drawK(image._image, image.scale, _jac) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index edeadb0d..e195d317 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -88,3 +88,13 @@ def g1g2_to_e1e2(g1, g2): e1 = g1 * (e / g) e2 = g2 * (e / g) return e1, e2 + + +@_wraps(_galsim.utilities.convert_interpolant) +def convert_interpolant(interpolant): + from jax_galsim.interpolant import Interpolant + if isinstance(interpolant, Interpolant): + return interpolant + else: + # Will raise an appropriate exception if this is invalid. + return Interpolant.from_name(interpolant) From c32c8fa5c221a8ec592c411647d8d124bf8d7dd6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 Sep 2023 08:15:19 -0500 Subject: [PATCH 02/67] ENH k-space wrapping for hermitian images --- CHANGELOG.md | 5 - jax_galsim/core/wrap_image.py | 106 ++++++++++++- jax_galsim/image.py | 29 +++- tests/jax/galsim/test_shear_position_jax.py | 155 ++++++++++++++------ 4 files changed, 240 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aae284eb..a8117fb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,10 +24,5 @@ * Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects * Caveats - * Currently the FFT convolution does not perform kwrapping of hermitian images, - so it will lead to erroneous results on underesolved images that need k-space wrapping. - Wrapping for real images is implemented. K-space images arise from doing convolutions - via FFTs and so one would expect that underresolved images with convolutions may not be - rendered as accurately. * Real space convolution and photon shooting methods are not yet implemented in drawImage. diff --git a/jax_galsim/core/wrap_image.py b/jax_galsim/core/wrap_image.py index 94ced023..74f7b26c 100644 --- a/jax_galsim/core/wrap_image.py +++ b/jax_galsim/core/wrap_image.py @@ -4,7 +4,7 @@ @jax.jit -def wrap_nonhermition(im, xmin, ymin, nxwrap, nywrap): +def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap): def _body_j(j, vals): i, im = vals @@ -33,3 +33,107 @@ def _body_i(i, vals): im = jax.lax.fori_loop(0, im.shape[0], _body_i, im) return im + + +@jax.jit +def expand_hermitian_x(im): + return jnp.concatenate([im[:, 1:][::-1, ::-1].conjugate(), im], axis=1) + + +@jax.jit +def contract_hermitian_x(im): + return im[:, im.shape[1] // 2 :] + + +@jax.jit +def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_x(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_x(im_exp) + + +@jax.jit +def expand_hermitian_y(im): + return jnp.concatenate([im[1:, :][::-1, ::-1].conjugate(), im], axis=0) + + +@jax.jit +def contract_hermitian_y(im): + return im[im.shape[0] // 2 :, :] + + +@jax.jit +def wrap_hermitian_y(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_y(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_y(im_exp) + + +# I am leaving this code here for posterity. It has a bug that I cannot find. +# It tries to be more clever instead of simply expanding the hermitian image to +# it's full shape, wrapping everything, and then contracting. -MRB +# @jax.jit +# def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): +# def _body_j(j, vals): +# i, im = vals + +# # first do zero or positive x freq +# im_y = i + im_ymin +# im_x = j + im_xmin +# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin +# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin +# wrap_yind = wrap_y - im_ymin +# wrap_xind = wrap_x - im_xmin +# im = jax.lax.cond( +# wrap_xind >= 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y) != 0, +# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j]), +# lambda im, wrap_yind, wrap_xind: im, +# im, +# wrap_yind, +# wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ) + +# # now do neg x freq +# im_y = -im_y +# im_x = -im_x +# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin +# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin +# wrap_yind = wrap_y - im_ymin +# wrap_xind = wrap_x - im_xmin +# im = jax.lax.cond( +# im_x != 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# wrap_xind >= 0, +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond( +# (jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y)) != 0, +# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j].conjugate()), +# lambda im, wrap_yind, wrap_xind: im, +# im, +# wrap_yind, +# wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ), +# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im, +# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind, +# ) + +# return [i, im] + +# def _body_i(i, vals): +# im = vals +# _, im = jax.lax.fori_loop(0, im.shape[1], _body_j, [i, im]) +# return im + +# im = jax.lax.fori_loop(0, im.shape[0], _body_i, im) +# return im diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 53b9130d..233e256d 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -627,9 +627,9 @@ def _wrap(self, bounds, hermx, hermy): Equivalent to ``image.wrap(bounds, hermitian=='x', hermitian=='y')``. """ if not hermx and not hermy: - from jax_galsim.core.wrap_image import wrap_nonhermition + from jax_galsim.core.wrap_image import wrap_nonhermitian - self._array = wrap_nonhermition( + self._array = wrap_nonhermitian( self._array, # zero indexed location of subimage bounds.xmin - self.xmin, @@ -638,8 +638,31 @@ def _wrap(self, bounds, hermx, hermy): bounds.xmax - bounds.xmin + 1, bounds.ymax - bounds.ymin + 1, ) + elif hermx and not hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_x + + self._array = wrap_hermitian_x( + self._array, + -self.xmax, + self.ymin, + -bounds.xmax + 1, + bounds.ymin, + 2 * bounds.xmax, + bounds.ymax - bounds.ymin + 1, + ) + elif not hermx and hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_y + + self._array = wrap_hermitian_y( + self._array, + self.xmin, + -self.ymax, + bounds.xmin, + -bounds.ymax + 1, + bounds.xmax - bounds.xmin + 1, + 2 * bounds.ymax, + ) - # FIXME: Wrapping not yet implemented for hermitian images return self.subImage(bounds) @_wraps(_galsim.Image.calculate_fft) diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py index d8265406..93d077cd 100644 --- a/tests/jax/galsim/test_shear_position_jax.py +++ b/tests/jax/galsim/test_shear_position_jax.py @@ -2,6 +2,12 @@ import galsim from galsim_test_helpers import timer, assert_raises +from jax_galsim.core.wrap_image import ( + expand_hermitian_x, + contract_hermitian_x, + expand_hermitian_y, + contract_hermitian_y, +) @timer @@ -104,6 +110,32 @@ def test_wrap_jax_weird_real(): ) +def test_wrap_jax_inds(): + Nx = 2 + Ny = 2 + im = np.ones((2 * Ny + 1, Nx + 1), dtype=np.complex128) + ymin = -Ny + xmin = 0 + + Nxs = 1 + Nys = 1 + ymins = -Nys + xmins = -Nxs + py = 2 * Nys + 1 + px = 2 * Nxs + 1 + + print(" ") + for i in range(im.shape[0]): + for j in range(im.shape[1]): + im_y = i + ymin + im_x = j + xmin + + wrap_y = (im_y - ymins) % py + ymins + wrap_x = (im_x - xmins) % px + xmins + + print("% d % d % d % d" % (im_x, im_y, wrap_x, wrap_y)) + + @timer def test_wrap_jax_complex(): # For complex images (in particular k-space images), we often want the image to be implicitly @@ -125,6 +157,9 @@ def test_wrap_jax_complex(): b = galsim.BoundsI(-K + 1, K, -L + 1, L) b2 = galsim.BoundsI(-K + 1, K, 0, L) b3 = galsim.BoundsI(0, K, -L + 1, L) + # print('b = ',b) + # print('b2 = ',b2) + # print('b3 = ',b3) im_test = galsim.ImageCD(b, init_value=0) for i in range(-M, M + 1): for j in range(-N, N + 1): @@ -151,6 +186,34 @@ def test_wrap_jax_complex(): for j in range(-N, N + 1): assert im(i, j) == im(-i, -j).conjugate() + im_exp = expand_hermitian_x(im3.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_x(im.array) + np.testing.assert_allclose( + im_cnt, + im3.array, + err_msg="contract_hermitian_x() did not match expectation", + ) + + im_exp = expand_hermitian_y(im2.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_y(im.array) + np.testing.assert_allclose( + im_cnt, + im2.array, + err_msg="contract_hermitian_x() did not match expectation", + ) + im_wrap = im.wrap(b) # print("im_wrap = ",im_wrap.array) np.testing.assert_allclose( @@ -159,49 +222,51 @@ def test_wrap_jax_complex(): err_msg="image.wrap(%s) did not match expectation" % b, ) np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" + im_wrap.array, + im[b].array, + "image.wrap(%s) did not return the right subimage" % b, ) np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" + im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" % b ) - # FIXME: turn on when wrapping works for hermitian images - if False: - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_allclose( - im2_wrap.array, - im_test[b2].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) + im2_wrap = im2.wrap(b2, hermitian="y") + # print('im_test = ',im_test[b2].array) + # print('im2_wrap = ',im2_wrap.array) + # print('diff = ',im2_wrap.array-im_test[b2].array) + np.testing.assert_allclose( + im2_wrap.array, + im_test[b2].array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im2_wrap.array, + im2[b2].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" + ) - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_allclose( - im3_wrap.array, - im_test[b3].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) + im3_wrap = im3.wrap(b3, hermitian="x") + # print('im_test = ',im_test[b3].array) + # print('im3_wrap = ',im3_wrap.array) + # print('diff = ',im3_wrap.array-im_test[b3].array) + np.testing.assert_allclose( + im3_wrap.array, + im_test[b3].array, + err_msg="image.wrap(%s, hermitian='x') did not match expectation" % b3, + ) + np.testing.assert_array_equal( + im3_wrap.array, + im3[b3].array, + "image.wrap(%s, hermitian='x') did not match expectation" % b3, + ) + np.testing.assert_equal( + im3_wrap.bounds, + b3, + "image.wrap(%s, hermitian='x') did not match expectation" % b3, + ) b = galsim.BoundsI(-K + 1, K, -L + 1, L) b2 = galsim.BoundsI(-K + 1, K, 0, L) @@ -211,11 +276,9 @@ def test_wrap_jax_complex(): assert_raises(ValueError, im.wrap, b, hermitian="invalid") assert_raises(ValueError, im.wrap, b3, hermitian="x") - # FIXME: turn on when wrapping works for hermitian images - if False: - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") + assert_raises(ValueError, im2.wrap, b, hermitian="y") + assert_raises(ValueError, im2.wrap, b3, hermitian="y") + assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b, hermitian="x") + assert_raises(ValueError, im3.wrap, b2, hermitian="x") From c1e653b113a91bee789d8eb156a52295ebae4fd8 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 Sep 2023 08:17:58 -0500 Subject: [PATCH 03/67] Update tests/jax/galsim/test_shear_position_jax.py --- tests/jax/galsim/test_shear_position_jax.py | 26 --------------------- 1 file changed, 26 deletions(-) diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py index 93d077cd..a60eb94d 100644 --- a/tests/jax/galsim/test_shear_position_jax.py +++ b/tests/jax/galsim/test_shear_position_jax.py @@ -110,32 +110,6 @@ def test_wrap_jax_weird_real(): ) -def test_wrap_jax_inds(): - Nx = 2 - Ny = 2 - im = np.ones((2 * Ny + 1, Nx + 1), dtype=np.complex128) - ymin = -Ny - xmin = 0 - - Nxs = 1 - Nys = 1 - ymins = -Nys - xmins = -Nxs - py = 2 * Nys + 1 - px = 2 * Nxs + 1 - - print(" ") - for i in range(im.shape[0]): - for j in range(im.shape[1]): - im_y = i + ymin - im_x = j + xmin - - wrap_y = (im_y - ymins) % py + ymins - wrap_x = (im_x - xmins) % px + xmins - - print("% d % d % d % d" % (im_x, im_y, wrap_x, wrap_y)) - - @timer def test_wrap_jax_complex(): # For complex images (in particular k-space images), we often want the image to be implicitly From 82255007f9cb3ed05104aa7d761b0f6300b2c19f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 Sep 2023 08:18:29 -0500 Subject: [PATCH 04/67] Update tests/jax/galsim/test_shear_position_jax.py --- tests/jax/galsim/test_shear_position_jax.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py index a60eb94d..405dfded 100644 --- a/tests/jax/galsim/test_shear_position_jax.py +++ b/tests/jax/galsim/test_shear_position_jax.py @@ -131,9 +131,6 @@ def test_wrap_jax_complex(): b = galsim.BoundsI(-K + 1, K, -L + 1, L) b2 = galsim.BoundsI(-K + 1, K, 0, L) b3 = galsim.BoundsI(0, K, -L + 1, L) - # print('b = ',b) - # print('b2 = ',b2) - # print('b3 = ',b3) im_test = galsim.ImageCD(b, init_value=0) for i in range(-M, M + 1): for j in range(-N, N + 1): From 1c8e3e23cdd83c4dba0e7c82322b7b165dcf43e6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 Sep 2023 08:27:53 -0500 Subject: [PATCH 05/67] TST clean up the test suite a bit --- tests/jax/galsim/test_image_jax.py | 102 ++++---- tests/jax/galsim/test_shear_position_jax.py | 255 -------------------- tests/jax/test_image_wrapping.py | 82 +++++++ 3 files changed, 132 insertions(+), 307 deletions(-) delete mode 100644 tests/jax/galsim/test_shear_position_jax.py create mode 100644 tests/jax/test_image_wrapping.py diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index 0615c748..fccca4cd 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -4502,59 +4502,57 @@ def test_wrap(): im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" ) - # FIXME: turn on when hermitian wrapping is implemented - if False: - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_array_almost_equal( - im2_wrap.array, - im_test[b2].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) + im2_wrap = im2.wrap(b2, hermitian="y") + # print('im_test = ',im_test[b2].array) + # print('im2_wrap = ',im2_wrap.array) + # print('diff = ',im2_wrap.array-im_test[b2].array) + np.testing.assert_array_almost_equal( + im2_wrap.array, + im_test[b2].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im2_wrap.array, + im2[b2].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" + ) - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_array_almost_equal( - im3_wrap.array, - im_test[b3].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") + im3_wrap = im3.wrap(b3, hermitian="x") + # print('im_test = ',im_test[b3].array) + # print('im3_wrap = ',im3_wrap.array) + # print('diff = ',im3_wrap.array-im_test[b3].array) + np.testing.assert_array_almost_equal( + im3_wrap.array, + im_test[b3].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im3_wrap.array, + im3[b3].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" + ) + + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + b2 = galsim.BoundsI(-K + 1, K, 0, L) + b3 = galsim.BoundsI(0, K, -L + 1, L) + assert_raises(TypeError, im.wrap, bounds=None) + assert_raises(ValueError, im3.wrap, b, hermitian="x") + assert_raises(ValueError, im3.wrap, b2, hermitian="x") + assert_raises(ValueError, im.wrap, b3, hermitian="x") + assert_raises(ValueError, im2.wrap, b, hermitian="y") + assert_raises(ValueError, im2.wrap, b3, hermitian="y") + assert_raises(ValueError, im.wrap, b2, hermitian="y") + assert_raises(ValueError, im.wrap, b, hermitian="invalid") + assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") @timer diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py deleted file mode 100644 index 405dfded..00000000 --- a/tests/jax/galsim/test_shear_position_jax.py +++ /dev/null @@ -1,255 +0,0 @@ -import numpy as np - -import galsim -from galsim_test_helpers import timer, assert_raises -from jax_galsim.core.wrap_image import ( - expand_hermitian_x, - contract_hermitian_x, - expand_hermitian_y, - contract_hermitian_y, -) - - -@timer -def test_shear_position_image_integration_pixelwcs_jax(): - wcs = galsim.PixelScale(0.3) - obj1 = galsim.Gaussian(sigma=3) - obj2 = galsim.Gaussian(sigma=2) - pos2 = galsim.PositionD(3, 5) - sum = obj1 + obj2.shift(pos2) - shear = galsim.Shear(g1=0.1, g2=0.18) - im1 = galsim.Image(50, 50, wcs=wcs) - sum.shear(shear).drawImage(im1, center=im1.center) - - # Equivalent to shear each object separately and drawing at the sheared position. - im2 = galsim.Image(50, 50, wcs=wcs) - obj1.shear(shear).drawImage(im2, center=im2.center) - obj2.shear(shear).drawImage( - im2, - add_to_image=True, - center=im2.center + wcs.toImage(pos2.shear(shear)), - ) - - print("err:", np.max(np.abs(im1.array - im2.array))) - np.testing.assert_allclose(im1.array, im2.array, rtol=0, atol=5e-8) - - -@timer -def test_wrap_jax_simple_real(): - """Test the image.wrap() function.""" - # Start with a fairly simple test where the image is 4 copies of the same data: - im_orig = galsim.Image( - [ - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - ] - ) - im = im_orig.copy() - b = galsim.BoundsI(1, 4, 1, 4) - im_quad = im_orig[b] - im_wrap = im.wrap(b) - np.testing.assert_allclose(im_wrap.array, 4.0 * im_quad.array) - - # The same thing should work no matter where the lower left corner is: - for xmin, ymin in ((1, 5), (5, 1), (5, 5), (2, 3), (4, 1)): - b = galsim.BoundsI(xmin, xmin + 3, ymin, ymin + 3) - im_quad = im_orig[b] - im = im_orig.copy() - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - 4.0 * im_quad.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return the right subimage" % b, - ) - # this test passes even though we do not get a view - im[b].fill(0) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return a view of the original" % b, - ) - - -@timer -def test_wrap_jax_weird_real(): - # Now test where the subimage is not a simple fraction of the original, and all the - # sizes are different. - im = galsim.ImageD(17, 23, xmin=0, ymin=0) - b = galsim.BoundsI(7, 9, 11, 18) - im_test = galsim.ImageD(b, init_value=0) - for i in range(17): - for j in range(23): - val = np.exp(i / 7.3) + (j / 12.9) ** 3 # Something randomly complicated... - im[i, j] = val - # Find the location in the sub-image for this point. - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - -@timer -def test_wrap_jax_complex(): - # For complex images (in particular k-space images), we often want the image to be implicitly - # Hermitian, so we only need to keep around half of it. - M = 38 - N = 25 - K = 8 - L = 5 - im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian - im2 = galsim.ImageCD( - 2 * M + 1, N + 1, xmin=-M, ymin=0 - ) # Implicitly Hermitian across y axis - im3 = galsim.ImageCD( - M + 1, 2 * N + 1, xmin=0, ymin=-N - ) # Implicitly Hermitian across x axis - # print('im = ',im) - # print('im2 = ',im2) - # print('im3 = ',im3) - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - im_test = galsim.ImageCD(b, init_value=0) - for i in range(-M, M + 1): - for j in range(-N, N + 1): - # An arbitrary, complicated Hermitian function. - val = ( - np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) - + ((2 + 3j * j) / (1.9 * N)) ** 3 - ) - # val = 2*(i-j)**2 + 3j*(i+j) - - im[i, j] = val - if j >= 0: - im2[i, j] = val - if i >= 0: - im3[i, j] = val - - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - # print("im = ",im.array) - - # Confirm that the image is Hermitian. - for i in range(-M, M + 1): - for j in range(-N, N + 1): - assert im(i, j) == im(-i, -j).conjugate() - - im_exp = expand_hermitian_x(im3.array) - np.testing.assert_allclose( - im_exp, - im.array, - err_msg="expand_hermitian_x() did not match expectation", - ) - - im_cnt = contract_hermitian_x(im.array) - np.testing.assert_allclose( - im_cnt, - im3.array, - err_msg="contract_hermitian_x() did not match expectation", - ) - - im_exp = expand_hermitian_y(im2.array) - np.testing.assert_allclose( - im_exp, - im.array, - err_msg="expand_hermitian_x() did not match expectation", - ) - - im_cnt = contract_hermitian_y(im.array) - np.testing.assert_allclose( - im_cnt, - im2.array, - err_msg="contract_hermitian_x() did not match expectation", - ) - - im_wrap = im.wrap(b) - # print("im_wrap = ",im_wrap.array) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, - im[b].array, - "image.wrap(%s) did not return the right subimage" % b, - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" % b - ) - - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_allclose( - im2_wrap.array, - im_test[b2].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) - - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_allclose( - im3_wrap.array, - im_test[b3].array, - err_msg="image.wrap(%s, hermitian='x') did not match expectation" % b3, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s, hermitian='x') did not match expectation" % b3, - ) - np.testing.assert_equal( - im3_wrap.bounds, - b3, - "image.wrap(%s, hermitian='x') did not match expectation" % b3, - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py new file mode 100644 index 00000000..99fcbfb3 --- /dev/null +++ b/tests/jax/test_image_wrapping.py @@ -0,0 +1,82 @@ +import jax_galsim as galsim +import numpy as np +from galsim_test_helpers import timer + +from jax_galsim.core.wrap_image import ( + expand_hermitian_x, expand_hermitian_y, contract_hermitian_x, + contract_hermitian_y, +) + + +@timer +def test_image_wrapping(): + # For complex images (in particular k-space images), we often want the image to be implicitly + # Hermitian, so we only need to keep around half of it. + M = 38 + N = 25 + K = 8 + L = 5 + im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian + im2 = galsim.ImageCD( + 2 * M + 1, N + 1, xmin=-M, ymin=0 + ) # Implicitly Hermitian across y axis + im3 = galsim.ImageCD( + M + 1, 2 * N + 1, xmin=0, ymin=-N + ) # Implicitly Hermitian across x axis + # print('im = ',im) + # print('im2 = ',im2) + # print('im3 = ',im3) + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + im_test = galsim.ImageCD(b, init_value=0) + for i in range(-M, M + 1): + for j in range(-N, N + 1): + # An arbitrary, complicated Hermitian function. + val = ( + np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) + + ((2 + 3j * j) / (1.9 * N)) ** 3 + ) + # val = 2*(i-j)**2 + 3j*(i+j) + + im[i, j] = val + if j >= 0: + im2[i, j] = val + if i >= 0: + im3[i, j] = val + + ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin + jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin + im_test.addValue(ii, jj, val) + # print("im = ",im.array) + + # Confirm that the image is Hermitian. + for i in range(-M, M + 1): + for j in range(-N, N + 1): + assert im(i, j) == im(-i, -j).conjugate() + + im_exp = expand_hermitian_x(im3.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_x(im.array) + np.testing.assert_allclose( + im_cnt, + im3.array, + err_msg="contract_hermitian_x() did not match expectation", + ) + + im_exp = expand_hermitian_y(im2.array) + np.testing.assert_allclose( + im_exp, + im.array, + err_msg="expand_hermitian_x() did not match expectation", + ) + + im_cnt = contract_hermitian_y(im.array) + np.testing.assert_allclose( + im_cnt, + im2.array, + err_msg="contract_hermitian_x() did not match expectation", + ) From 35b3164814a392b25cd0b47e2ef3334fae25029f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 Sep 2023 08:29:34 -0500 Subject: [PATCH 06/67] Update tests/jax/test_image_wrapping.py --- tests/jax/test_image_wrapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 99fcbfb3..1272c325 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -9,7 +9,7 @@ @timer -def test_image_wrapping(): +def test_image_wrapping_expand_contract(): # For complex images (in particular k-space images), we often want the image to be implicitly # Hermitian, so we only need to keep around half of it. M = 38 From 991db7815c2263c54e943673061bd046c2b36316 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 Sep 2023 08:30:56 -0500 Subject: [PATCH 07/67] STY blacken --- tests/jax/test_image_wrapping.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 1272c325..c420d92e 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -3,7 +3,9 @@ from galsim_test_helpers import timer from jax_galsim.core.wrap_image import ( - expand_hermitian_x, expand_hermitian_y, contract_hermitian_x, + expand_hermitian_x, + expand_hermitian_y, + contract_hermitian_x, contract_hermitian_y, ) From ca388523bbce2718b678c25e146f65d3f7493d76 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 Sep 2023 08:57:50 -0500 Subject: [PATCH 08/67] TST add test of rev-mode autodiff --- tests/jax/test_image_wrapping.py | 66 ++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index c420d92e..268e10b4 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -1,3 +1,4 @@ +import jax import jax_galsim as galsim import numpy as np from galsim_test_helpers import timer @@ -82,3 +83,68 @@ def test_image_wrapping_expand_contract(): im2.array, err_msg="contract_hermitian_x() did not match expectation", ) + + +@timer +def test_image_wrapping_autodiff(): + # For complex images (in particular k-space images), we often want the image to be implicitly + # Hermitian, so we only need to keep around half of it. + M = 38 + N = 25 + K = 8 + L = 5 + im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian + im2 = galsim.ImageCD( + 2 * M + 1, N + 1, xmin=-M, ymin=0 + ) # Implicitly Hermitian across y axis + im3 = galsim.ImageCD( + M + 1, 2 * N + 1, xmin=0, ymin=-N + ) # Implicitly Hermitian across x axis + # print('im = ',im) + # print('im2 = ',im2) + # print('im3 = ',im3) + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + im_test = galsim.ImageCD(b, init_value=0) + for i in range(-M, M + 1): + for j in range(-N, N + 1): + # An arbitrary, complicated Hermitian function. + val = ( + np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) + + ((2 + 3j * j) / (1.9 * N)) ** 3 + ) + # val = 2*(i-j)**2 + 3j*(i+j) + + im[i, j] = val + if j >= 0: + im2[i, j] = val + if i >= 0: + im3[i, j] = val + + ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin + jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin + im_test.addValue(ii, jj, val) + # print("im = ",im.array) + + # Confirm that the image is Hermitian. + for i in range(-M, M + 1): + for j in range(-N, N + 1): + assert im(i, j) == im(-i, -j).conjugate() + + @jax.jit + def _wrapit(im): + b3 = galsim.BoundsI(0, K, -L + 1, L) + return im.wrap(b3) + + # make sure this runs + p, grad = jax.vjp(_wrapit, im3) + + grad = jax.jit(grad) + grad(p) + + def _wrapit(im): + b3 = galsim.BoundsI(0, K, -L + 1, L) + return im.wrap(b3) + + # make sure this runs + p, grad = jax.vjp(_wrapit, im3) + grad(p) From 579c15115085eb6aeed3c0f82385db953afb8e53 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 25 Sep 2023 13:25:27 -0500 Subject: [PATCH 09/67] WIP getting closer --- jax_galsim/__init__.py | 1 + jax_galsim/image.py | 16 +- jax_galsim/interpolatedimage.py | 632 ++++++++++++++++------ jax_galsim/utilities.py | 1 + jax_galsim/wcs.py | 6 + tests/GalSim | 2 +- tests/galsim_tests_config.yaml | 7 + tests/jax/test_interpolatedimage_utils.py | 64 +++ 8 files changed, 551 insertions(+), 178 deletions(-) create mode 100644 tests/jax/test_interpolatedimage_utils.py diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 893eb46f..159ad36b 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -73,6 +73,7 @@ Quintic, Lanczos, ) +from .interpolatedimage import InterpolatedImage # packages kept separate from . import bessel diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 53b9130d..23cf6ae4 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -656,10 +656,14 @@ def calculate_fft(self): ) No2 = jnp.maximum( - -self.bounds.xmin, - self.bounds.xmax + 1, - -self.bounds.ymin, - self.bounds.ymax + 1, + jnp.maximum( + -self.bounds.xmin, + self.bounds.xmax + 1, + ), + jnp.maximum( + -self.bounds.ymin, + self.bounds.ymax + 1, + ), ) full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) @@ -676,7 +680,7 @@ def calculate_fft(self): dk = jnp.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._image = jnp.fft.rfft2(ximage._image) + out._array = jnp.fft.rfft2(ximage.array) out *= dx * dx out.setOrigin(0, -No2) @@ -725,7 +729,7 @@ def calculate_inverse_fft(self): # For the inverse, we need a bit of extra space for the fft. out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) - out_extra._image = jnp.fft.irfft2(kimage._image) + out_extra._array = jnp.fft.irfft2(kimage.array) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) out *= (dk * No2 / jnp.pi) ** 2 diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index bb6bd02b..9f2e3fc5 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1,4 +1,6 @@ import textwrap +import jax +from functools import partial import jax.numpy as jnp from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -16,6 +18,10 @@ from jax_galsim.interpolant import Quintic from jax_galsim.utilities import convert_interpolant from jax_galsim.bounds import BoundsI +from jax_galsim import fits +from jax_galsim.core.draw import draw_by_xValue +from jax_galsim.transform import Transformation +from jax_galsim.wcs import PixelScale @_wraps( @@ -25,32 +31,109 @@ - noise padding - depixelize - - reading images from FITS files + Further, it always computes the FFT of the image as opposed to galsim + where this is done as needed. One almost always needs the FFT and JAX + generally works best with pure functions that do not modify state. """ ), ) @register_pytree_node_class -class InterpolatedImage(GSObject): - _req_params = {'image': str} +class InterpolatedImage(Transformation): + _req_params = {"image": str} _opt_params = { - 'x_interpolant': str, - 'k_interpolant': str, - 'normalization': str, - 'scale': float, - 'flux': float, - 'pad_factor': float, - 'noise_pad_size': float, - 'noise_pad': str, - 'pad_image': str, - 'calculate_stepk': bool, - 'calculate_maxk': bool, - 'use_true_center': bool, - 'depixelize': bool, - 'offset': PositionD, - 'hdu': int + "x_interpolant": str, + "k_interpolant": str, + "normalization": str, + "scale": float, + "flux": float, + "pad_factor": float, + "noise_pad_size": float, + "noise_pad": str, + "pad_image": str, + "calculate_stepk": bool, + "calculate_maxk": bool, + "use_true_center": bool, + "depixelize": bool, + "offset": PositionD, + "hdu": int, } _takes_rng = True + + def __init__( + self, + image, + x_interpolant=None, + k_interpolant=None, + normalization="flux", + scale=None, + wcs=None, + flux=None, + pad_factor=4.0, + noise_pad_size=0, + noise_pad=0.0, + rng=None, + pad_image=None, + calculate_stepk=True, + calculate_maxk=True, + use_cache=True, + use_true_center=True, + depixelize=False, + offset=None, + gsparams=None, + _force_stepk=0.0, + _force_maxk=0.0, + hdu=None, + ): + obj = InterpolatedImageImpl( + image, + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + scale=scale, + wcs=wcs, + flux=flux, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + pad_image=pad_image, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + offset=offset, + gsparams=gsparams, + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + hdu=hdu, + ) + super().__init__( + obj, + jac=obj._jac_arr, + flux_ratio=obj._flux_ratio / obj._wcs.pixelArea(), + offset=PositionD(0.0, 0.0), + ) + + @property + def x_interpolant(self): + """The real-space `Interpolant` for this profile.""" + return self._original._x_interpolant + + @property + def k_interpolant(self): + """The Fourier-space `Interpolant` for this profile.""" + return self._original._k_interpolant + + @property + def image(self): + """The underlying `Image` being interpolated.""" + return self._original._image + + +@register_pytree_node_class +class InterpolatedImageImpl(GSObject): _cache_noise_pad = {} _has_hard_edges = False @@ -58,35 +141,89 @@ class InterpolatedImage(GSObject): _is_analytic_x = True _is_analytic_k = True - def __init__(self, image, x_interpolant=None, k_interpolant=None, normalization='flux', - scale=None, wcs=None, flux=None, pad_factor=4., noise_pad_size=0, noise_pad=0., - rng=None, pad_image=None, calculate_stepk=True, calculate_maxk=True, - use_cache=True, use_true_center=True, depixelize=False, offset=None, - gsparams=None, _force_stepk=0., _force_maxk=0., hdu=None): + def __init__( + self, + image, + x_interpolant=None, + k_interpolant=None, + normalization="flux", + scale=None, + wcs=None, + flux=None, + pad_factor=4.0, + noise_pad_size=0, + noise_pad=0.0, + rng=None, + pad_image=None, + calculate_stepk=True, + calculate_maxk=True, + use_cache=True, + use_true_center=True, + depixelize=False, + offset=None, + gsparams=None, + _force_stepk=0.0, + _force_maxk=0.0, + hdu=None, + ): + # this class does a ton of munging of the inputs that I don't want to reconstruct when + # flattening and unflattening the class. + # thus I am going to make some refs here so we have it when we need it + self._jax_children = ( + image, + dict( + scale=scale, + wcs=wcs, + flux=flux, + pad_image=pad_image, + offset=offset, + ), + ) + self._jax_aux_data = dict( + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + gsparams=gsparams, + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + hdu=hdu, + ) + self._params = {} from .wcs import BaseWCS, PixelScale + # FIXME: no BaseDeviate in jax_galsim # from .random import BaseDeviate # If the "image" is not actually an image, try to read the image as a file. if isinstance(image, str): - # FIXME: no FITSIO in jax_galsim - # image = fits.read(image, hdu=hdu) - raise NotImplementedError( - "Reading InterpolatedImages from FITS files is not implemented in jax_galsim." - ) + image = fits.read(image, hdu=hdu) elif not isinstance(image, Image): raise TypeError("Supplied image must be an Image or file name") # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor if not image.bounds.isDefined(): - raise GalSimUndefinedBoundsError("Supplied image does not have bounds defined.") + raise GalSimUndefinedBoundsError( + "Supplied image does not have bounds defined." + ) # check what normalization was specified for the image: is it an image of surface # brightness, or flux? if normalization.lower() not in ("flux", "f", "surface brightness", "sb"): - raise GalSimValueError("Invalid normalization requested.", normalization, - ('flux', 'f', 'surface brightness', 'sb')) + raise GalSimValueError( + "Invalid normalization requested.", + normalization, + ("flux", "f", "surface brightness", "sb"), + ) # Set up the interpolants if none was provided by user, or check that the user-provided ones # are of a valid type @@ -94,18 +231,24 @@ def __init__(self, image, x_interpolant=None, k_interpolant=None, normalization= if x_interpolant is None: self._x_interpolant = Quintic(gsparams=self._gsparams) else: - self._x_interpolant = convert_interpolant(x_interpolant).withGSParams(self._gsparams) + self._x_interpolant = convert_interpolant(x_interpolant).withGSParams( + self._gsparams + ) if k_interpolant is None: self._k_interpolant = Quintic(gsparams=self._gsparams) else: - self._k_interpolant = convert_interpolant(k_interpolant).withGSParams(self._gsparams) + self._k_interpolant = convert_interpolant(k_interpolant).withGSParams( + self._gsparams + ) # Store the image as an attribute and make sure we don't change the original image # in anything we do here. (e.g. set scale, etc.) if depixelize: # FIXME: no depixelize in jax_galsim # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) - raise NotImplementedError("InterpolatedImages do not support 'depixelize' in jax_galsim.") + raise NotImplementedError( + "InterpolatedImages do not support 'depixelize' in jax_galsim." + ) else: self._image = image.view(dtype=jnp.float64, contiguous=True) self._image.setCenter(0, 0) @@ -114,7 +257,10 @@ def __init__(self, image, x_interpolant=None, k_interpolant=None, normalization= if scale is not None: if wcs is not None: raise GalSimIncompatibleValuesError( - "Cannot provide both scale and wcs to InterpolatedImage", scale=scale, wcs=wcs) + "Cannot provide both scale and wcs to InterpolatedImage", + scale=scale, + wcs=wcs, + ) self._image.wcs = PixelScale(scale) elif wcs is not None: if not isinstance(wcs, BaseWCS): @@ -123,28 +269,46 @@ def __init__(self, image, x_interpolant=None, k_interpolant=None, normalization= elif self._image.wcs is None: raise GalSimIncompatibleValuesError( "No information given with Image or keywords about pixel scale!", - scale=scale, wcs=wcs, image=image) + scale=scale, + wcs=wcs, + image=image, + ) # Figure out the offset to apply based on the original image (not the padded one). # We will apply this below in _sbp. offset = self._parse_offset(offset) - self._offset = self._adjust_offset(self._image.bounds, offset, None, use_true_center) + self._offset = self._adjust_offset( + self._image.bounds, offset, None, use_true_center + ) im_cen = image.true_center if use_true_center else image.center + self._jac_arr = self._image.wcs.jacobian(image_pos=im_cen).getMatrix().ravel() self._wcs = self._image.wcs.local(image_pos=im_cen) # Build the fully padded real-space image according to the various pad options. - self._buildRealImage(pad_factor, pad_image, noise_pad_size, noise_pad, rng, use_cache) - self._image_flux = jnp.sum(self._image.array, dtype=jnp.float64) + self._buildImages( + pad_factor, + pad_image, + noise_pad_size, + noise_pad, + rng, + use_cache, + flux, + normalization, + ) # I think the only things that will mess up if flux == 0 are the # calculateStepK and calculateMaxK functions, and rescaling the flux to some value. - if (calculate_stepk or calculate_maxk or flux is not None) and self._image_flux == 0.: - raise GalSimValueError("This input image has zero total flux. It does not define a " - "valid surface brightness profile.", image) + if ( + calculate_stepk or calculate_maxk or flux is not None + ) and self._image_flux == 0.0: + raise GalSimValueError( + "This input image has zero total flux. It does not define a " + "valid surface brightness profile.", + image, + ) # Process the different options for flux, stepk, maxk - self._flux = self._getFlux(flux, normalization) self._calculate_stepk = calculate_stepk self._calculate_maxk = calculate_maxk self._stepk = self._getStepK(calculate_stepk, _force_stepk) @@ -164,86 +328,64 @@ def withGSParams(self, gsparams=None, **kwargs): ret._x_interpolant = self._x_interpolant.withGSParams(ret._gsparams, **kwargs) ret._k_interpolant = self._k_interpolant.withGSParams(ret._gsparams, **kwargs) if ret._gsparams.folding_threshold != self._gsparams.folding_threshold: - ret._stepk = ret._getStepK(self._calculate_stepk, 0.) + ret._stepk = ret._getStepK(self._calculate_stepk, 0.0) if ret._gsparams.maxk_threshold != self._gsparams.maxk_threshold: - ret._maxk = ret._getMaxK(self._calculate_maxk, 0.) + ret._maxk = ret._getMaxK(self._calculate_maxk, 0.0) return ret def tree_flatten(self): """This function flattens the InterpolatedImage into a list of children nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - children = (self.params,) - # Define auxiliary static data that doesn’t need to be traced - aux_data = {"gsparams": self.gsparams} - return (children, aux_data) + return (self._jax_children, self._jax_aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(**(children[0]), **aux_data) - - def __eq__(self, other): - return (self is other or - (isinstance(other, InterpolatedImage) and - self._xim == other._xim and - self.x_interpolant == other.x_interpolant and - self.k_interpolant == other.k_interpolant and - self.flux == other.flux and - self._offset == other._offset and - self.gsparams == other.gsparams and - self._stepk == other._stepk and - self._maxk == other._maxk)) - - # TODO: do this in JAX OFC - # @lazy_property - # def _sbp(self): - # min_scale = self._wcs._minScale() - # max_scale = self._wcs._maxScale() - # self._sbii = _galsim.SBInterpolatedImage( - # self._xim._image, self._image.bounds._b, self._pad_image.bounds._b, - # self._x_interpolant._i, self._k_interpolant._i, - # self._stepk*min_scale, - # self._maxk*max_scale, - # self.gsparams._gsp) - - # self._sbp = self._sbii # Temporary. Will overwrite this with the return value. - - # # Apply the offset - # prof = self - # if self._offset != _PositionD(0,0): - # # Opposite direction of what drawImage does. - # prof = prof._shift(-self._offset.x, -self._offset.y) - - # # If the user specified a flux, then set to that flux value. - # if self._flux != self._image_flux: - # flux_ratio = self._flux / self._image_flux - # else: - # flux_ratio = 1. - - # # Bring the profile from image coordinates into world coordinates - # # Note: offset needs to happen first before the transformation, so can't bundle it here. - # prof = self._wcs._profileToWorld(prof, flux_ratio, _PositionD(0,0)) + val = {} + val.update(aux_data) + val.update(children[1]) + return cls(children[0], **val) + + def _buildImages( + self, + pad_factor, + pad_image, + noise_pad_size, + noise_pad, + rng, + use_cache, + flux, + normalization, + ): + # If the user specified a surface brightness normalization for the input Image, then + # need to rescale flux by the pixel area to get proper normalization. + self._image_flux = jnp.sum(self._image.array, dtype=float) + if flux is None: + flux = self._image_flux + if normalization.lower() in ("surface brightness", "sb"): + flux *= self._wcs.pixelArea() + self._flux = flux - # return prof._sbp + # If the user specified a flux, then set the flux ratio for the transform that wraps + # this class + if self._flux != self._image_flux: + self._flux_ratio = self._flux / self._image_flux + else: + self._flux_ratio = 1.0 - def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, use_cache): # Check that given pad_image is valid: if pad_image is not None: if isinstance(pad_image, str): - # FIXME: no FITSIO in jax_galsim - # pad_image = fits.read(pad_image).view(dtype=np.float64) - raise NotImplementedError( - "Reading padding images for InterpolatedImages from FITS files " - "is not implemented in jax_galsim." - ) + pad_image = fits.read(pad_image).view(dtype=jnp.float64) elif isinstance(pad_image, Image): pad_image = pad_image.view(dtype=jnp.float64, contiguous=True) else: raise TypeError("Supplied pad_image must be an Image.", pad_image) - if pad_factor <= 0.: - raise GalSimRangeError("Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.) + if pad_factor <= 0.0: + raise GalSimRangeError( + "Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.0 + ) # Convert noise_pad_size from arcsec to pixels according to the local wcs. # Use the minimum scale, since we want to make sure noise_pad_size is @@ -258,18 +400,22 @@ def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, # noise_pad=noise_pad, noise_pad_size=noise_pad_size) # noise_pad_size = int(math.ceil(noise_pad_size / self._wcs._minScale())) # noise_pad_size = Image.good_fft_size(noise_pad_size) - raise NotImplementedError("InterpolatedImages do not support noise padding in jax_galsim.") + raise NotImplementedError( + "InterpolatedImages do not support noise padding in jax_galsim." + ) else: if noise_pad: # FIXME: no BaseDeviate in jax_galsim so no noise padding # raise GalSimIncompatibleValuesError( # "Must provide noise_pad_size if noise_pad != 0", # noise_pad=noise_pad, noise_pad_size=noise_pad_size) - raise NotImplementedError("InterpolatedImages do not support noise padding in jax_galsim.") + raise NotImplementedError( + "InterpolatedImages do not support noise padding in jax_galsim." + ) # The size of the final padded image is the largest of the various size specifications pad_size = max(self._image.array.shape) - if pad_factor > 1.: + if pad_factor > 1.0: pad_size = int(math.ceil(pad_factor * pad_size)) if noise_pad_size: pad_size = max(pad_size, noise_pad_size) @@ -279,8 +425,9 @@ def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, # And round up to a good fft size pad_size = Image.good_fft_size(pad_size) - self._xim = Image(pad_size, pad_size, dtype=jnp.float64, wcs=self._wcs) + self._xim = Image(pad_size, pad_size, dtype=jnp.float64, wcs=PixelScale(1.0)) self._xim.setCenter(0, 0) + self._image.wcs = PixelScale(1.0) # If requested, fill (some of) this image with noise padding. nz_bounds = self._image.bounds @@ -299,7 +446,6 @@ def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, # Now place the given image in the center of the padding image: # assert self._xim.bounds.includes(self._image.bounds) self._xim[self._image.bounds] = self._image - self._xim.wcs = self._wcs # And update the _image to be that portion of the full real image rather than the # input image. @@ -311,6 +457,9 @@ def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, # self._pad_factor = (max(self._xim.array.shape)-1.e-6) / max(self._image.array.shape) self._pad_factor = pad_factor + # we always make this + self._kim = self._xim.calculate_fft() + # FIXME: no BaseDeviate in jax_galsim so no noise padding # def _buildNoisePadImage(self, noise_pad_size, noise_pad, rng, use_cache): # """A helper function that builds the ``pad_image`` from the given ``noise_pad`` @@ -362,15 +511,6 @@ def _buildRealImage(self, pad_factor, pad_image, noise_pad_size, noise_pad, rng, # noise_image.addNoise(noise) # return b - def _getFlux(self, flux, normalization): - # If the user specified a surface brightness normalization for the input Image, then - # need to rescale flux by the pixel area to get proper normalization. - if flux is None: - flux = self._image_flux - if normalization.lower() in ('surface brightness', 'sb'): - flux *= self._wcs.pixelArea() - return flux - def _getStepK(self, calculate_stepk, _force_stepk): # GalSim cannot automatically know what stepK and maxK are appropriate for the # input image. So it is usually worth it to do a manual calculation (below). @@ -382,7 +522,7 @@ def _getStepK(self, calculate_stepk, _force_stepk): # units required by the C++ layer below. Also note that profile recentering for even-sized # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly # below what is provided here, while maxK is preserved. - if _force_stepk > 0.: + if _force_stepk > 0.0: return _force_stepk elif calculate_stepk: if calculate_stepk is True: @@ -395,13 +535,13 @@ def _getStepK(self, calculate_stepk, _force_stepk): im = self._image[b] thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux # this line appears buggy in galsim - I expect they meant to use im - R = _galsim.CalculateSizeContainingFlux(im._image, thresh) + R = _calculate_size_containing_flux(im, thresh) else: - R = jnp.max(self._image.array.shape) / 2. - 0.5 + R = max(*self._image.array.shape) / 2.0 - 0.5 return self._getSimpleStepK(R) def _getSimpleStepK(self, R): - min_scale = self._wcs._minScale() + min_scale = 1.0 # Add xInterp range in quadrature just like convolution: R2 = self._x_interpolant.xrange R = jnp.hypot(R, R2) @@ -409,28 +549,36 @@ def _getSimpleStepK(self, R): return stepk def _getMaxK(self, calculate_maxk, _force_maxk): - max_scale = self._wcs._maxScale() - if _force_maxk > 0.: + max_scale = 1.0 + if _force_maxk > 0.0: return _force_maxk elif calculate_maxk: - self._maxk = 0. - self._sbp + _uscale = 1 / (2 * jnp.pi) + self._maxk = self._x_interpolant.urange() / _uscale / max_scale + if calculate_maxk is True: - self._sbii.calculateMaxK(0.) + maxk = _find_maxk( + self._kim, self._maxk, self._gsparams.maxk_threshold * self.flux + ) else: - # If not a bool, then value is max_maxk - self._sbii.calculateMaxK(float(calculate_maxk)) - self.__dict__.pop('_sbp') # Need to remake it. - return self._sbii.maxK() / max_scale + maxk = _find_maxk( + self._kim, calculate_maxk, self._gsparams.maxk_threshold * self.flux + ) + + return maxk / max_scale else: return self._x_interpolant.krange / max_scale def __hash__(self): # Definitely want to cache this, since the size of the image could be large. - if not hasattr(self, '_hash'): - self._hash = hash(("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant)) + if not hasattr(self, "_hash"): + self._hash = hash( + ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant) + ) self._hash ^= hash((self.flux, self._stepk, self._maxk, self._pad_factor)) - self._hash ^= hash((self._xim.bounds, self._image.bounds, self._pad_image.bounds)) + self._hash ^= hash( + (self._xim.bounds, self._image.bounds, self._pad_image.bounds) + ) # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0 # (which is also common). I guess because they are only different in 2 bits. # This mucking of the numbers seems to help make the hash more reliably different for @@ -438,81 +586,96 @@ def __hash__(self): # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions self._hash ^= hash((self._offset.x * 1.234, self._offset.y * 0.23424)) self._hash ^= hash(self._gsparams) - self._hash ^= hash(self._xim.wcs) + self._hash ^= hash(self._wcs) # Just hash the diagonal. Much faster, and usually is unique enough. # (Let python handle collisions as needed if multiple similar IIs are used as keys.) self._hash ^= hash(tuple(jnp.diag(self._pad_image.array))) return self._hash def __repr__(self): - s = 'galsim.InterpolatedImage(%r, %r, %r' % ( - self._image, self.x_interpolant, self.k_interpolant + s = "galsim.InterpolatedImage(%r, %r, %r" % ( + self._image, + self.x_interpolant, + self.k_interpolant, ) # Most things we keep even if not required, but the pad_image is large, so skip it # if it's really just the same as the main image. if self._pad_image.bounds != self._image.bounds: - s += ', pad_image=%r' % (self._pad_image) - s += ', pad_factor=%f, flux=%r, offset=%r' % (self._pad_factor, self.flux, self._offset) - s += ', use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)' % ( - self.gsparams, self._stepk, self._maxk + s += ", pad_image=%r" % (self._pad_image) + s += ", wcs=%s" % self._wcs + s += ", pad_factor=%f, flux=%r, offset=%r" % ( + self._pad_factor, + self.flux, + self._offset, + ) + s += ( + ", use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)" + % (self.gsparams, self._stepk, self._maxk) ) return s def __str__(self): - return 'galsim.InterpolatedImage(image=%s, flux=%s)' % (self.image, self.flux) + return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) def __getstate__(self): d = self.__dict__.copy() - # TODO - probably remove these pops for things we don't have - d.pop('_sbii', None) - d.pop('_sbp', None) # Only pickle _pad_image. Not _xim or _image - d['_xim_bounds'] = self._xim.bounds - d['_image_bounds'] = self._image.bounds - d.pop('_xim', None) - d.pop('_image', None) + d["_xim_bounds"] = self._xim.bounds + d["_image_bounds"] = self._image.bounds + d.pop("_xim", None) + d.pop("_image", None) return d def __setstate__(self, d): - xim_bounds = d.pop('_xim_bounds') - image_bounds = d.pop('_image_bounds') + xim_bounds = d.pop("_xim_bounds") + image_bounds = d.pop("_image_bounds") self.__dict__ = d if self._pad_image.bounds == xim_bounds: self._xim = self._pad_image else: - self._xim = Image(xim_bounds, wcs=self._wcs, dtype=jnp.float64) + self._xim = Image(xim_bounds, wcs=PixelScale(1.0), dtype=jnp.float64) self._xim[self._pad_image.bounds] = self._pad_image self._image = self._xim[image_bounds] @property def x_interpolant(self): - """The real-space `Interpolant` for this profile. - """ + """The real-space `Interpolant` for this profile.""" return self._x_interpolant @property def k_interpolant(self): - """The Fourier-space `Interpolant` for this profile. - """ + """The Fourier-space `Interpolant` for this profile.""" return self._k_interpolant @property def image(self): - """The underlying `Image` being interpolated. - """ + """The underlying `Image` being interpolated.""" return self._image + @property + def _flux(self): + """By default, the flux is contained in the parameters dictionay.""" + return self._params["flux"] + + @_flux.setter + def _flux(self, value): + self._params["flux"] = value + @property def _centroid(self): - return PositionD(self._sbp.centroid()) + raise NotImplementedError("WIP interp - centroid") @property def _positive_flux(self): - return self._sbp.getPositiveFlux() + raise NotImplementedError("WIP interp - positive_flux") @property def _negative_flux(self): - return self._sbp.getNegativeFlux() + raise NotImplementedError("WIP interp - negative_flux") + + @property + def _max_sb(self): + return jnp.max(jnp.abs(self._pad_image.array)) # @lazy_property def _flux_per_photon(self): @@ -520,24 +683,151 @@ def _flux_per_photon(self): # return self._calculate_flux_per_photon() raise NotImplementedError("Photon shooting not implemented.") - @property - def _max_sb(self): - return self._sbp.maxSB() - def _xValue(self, pos): - return self._sbp.xValue(pos._p) + pos += self._offset + vals = _draw_with_interpolant_xval( + jnp.array([pos.x], dtype=float), + jnp.array([pos.y], dtype=float), + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + ) + return vals[0] def _kValue(self, kpos): - return self._sbp.kValue(kpos._p) + raise NotImplementedError("WIP interp - kValue") def _shoot(self, photons, rng): raise NotImplementedError("Photon shooting not implemented.") - def _drawReal(self, image, jac=None, offset=(0., 0.), flux_scaling=1.): - dx, dy = offset - _jac = 0 if jac is None else jac.__array_interface__['data'][0] - self._sbp.draw(image._image, image.scale, _jac, dx, dy, flux_scaling) + def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): + _jac = jnp.eye(2) if jac is None else jac + return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling) def _drawKImage(self, image, jac=None): - _jac = 0 if jac is None else jac.__array_interface__['data'][0] + raise NotImplementedError("WIP interp - drawKImage") + _jac = 0 if jac is None else jac.__array_interface__["data"][0] self._sbp.drawK(image._image, image.scale, _jac) + + +@partial(jax.jit, static_argnums=(5,)) +def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): + orig_shape = x.shape + x = x.ravel() + xi = jnp.floor(x - xmin).astype(jnp.int32) + xp = xi + xmin + nx = zp.shape[1] + + y = y.ravel() + yi = jnp.floor(y - ymin).astype(jnp.int32) + yp = yi + ymin + ny = zp.shape[0] + + def _body_1d(i, args): + z, wy, msky, yind, xi, xp, zp = args + + xind = xi + i + mskx = (xind >= 0) & (xind < nx) + _x = x - (xp + i) + wx = interp.xval(_x) + + w = wx * wy + msk = msky & mskx + z += jnp.where(msk, zp[yind, xind] * w, 0) + + return [z, wy, msky, yind, xi, xp, zp] + + def _body(i, args): + z, xi, yi, xp, yp, zp = args + yind = yi + i + msk = (yind >= 0) & (yind < ny) + _y = y - (yp + i) + wy = interp.xval(_y) + z = jax.lax.fori_loop( + -interp.xrange, interp.xrange + 1, _body_1d, [z, wy, msk, yind, xi, xp, zp] + )[0] + return [z, xi, yi, xp, yp, zp] + + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body, + [jnp.zeros(x.shape, dtype=zp.dtype), xi, yi, xp, yp, zp], + )[0] + return z.reshape(orig_shape) + + +@jax.jit +def _flux_frac(a, x, y, cenx, ceny): + def _body(d, args): + res, a, dx, dy, cenx, ceny = args + msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) + + res = res.at[d].set( + jnp.sum( + jnp.where( + msk, + a, + 0.0, + ) + ) + ) + + return [res, a, dx, dy, cenx, ceny] + + res = jnp.zeros(a.shape[0], dtype=float) - jnp.inf + return jax.lax.fori_loop( + 0, a.shape[0], _body, [res, a, x - cenx, y - ceny, cenx, ceny] + )[0] + + +def _calculate_size_containing_flux(image, thresh): + cenx, ceny = image.center.x, image.center.y + x, y = image.get_pixel_centers() + fluxes = _flux_frac(image.array, x, y, cenx, ceny) + msk = fluxes >= -jnp.inf + fluxes = jnp.where(msk, fluxes, jnp.max(fluxes)) + d = jnp.arange(image.array.shape[0]) + 1.0 + expfac = 4.0 + dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0 + fluxes = jnp.interp(dint, d, fluxes) + msk = fluxes <= thresh + return ( + jnp.argmax( + jnp.where( + msk, + dint, + -jnp.inf, + ) + ) + / expfac + + 1.0 + ) + + +@jax.jit +def _inner_comp_find_maxk(arr, thresh, kx, ky): + msk = arr * arr.conjugate() > thresh * thresh + max_kx = jnp.max( + jnp.where( + msk, + jnp.abs(kx), + -jnp.inf, + ) + ) + max_ky = jnp.max( + jnp.where( + msk, + jnp.abs(ky), + -jnp.inf, + ) + ) + return jnp.maximum(max_kx, max_ky) + + +def _find_maxk(kim, max_maxk, thresh): + kx, ky = kim.get_pixel_centers() + kx *= kim.scale + ky *= kim.scale + return _inner_comp_find_maxk(kim.array, thresh, kx, ky) * 1.15 diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 53085203..2c1fa96b 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -93,6 +93,7 @@ def g1g2_to_e1e2(g1, g2): @_wraps(_galsim.utilities.convert_interpolant) def convert_interpolant(interpolant): from jax_galsim.interpolant import Interpolant + if isinstance(interpolant, Interpolant): return interpolant else: diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 52a58221..e0806298 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -109,6 +109,12 @@ def local(self, image_pos=None, world_pos=None, color=None): raise TypeError("image_pos must be a PositionD or PositionI argument") return self._local(image_pos, color) + @_wraps(_galsim.BaseWCS.jacobian) + def jacobian(self, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._toJacobian() + @_wraps(_galsim.BaseWCS.affine) def affine(self, image_pos=None, world_pos=None, color=None): if color is None: diff --git a/tests/GalSim b/tests/GalSim index 2cd1f41d..b7ce7573 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 2cd1f41d58de35d2f5488d0dce736d5f3ea3401e +Subproject commit b7ce7573f42b7d21120365b0373cb83a88953b45 diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 7dfaec1b..91365258 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -9,6 +9,7 @@ enabled_tests: - test_utilities.py - test_shear.py - test_shear_position.py + - test_interpolatedimage.py # This documents which error messages will be allowed # without being reported as an error. These typically @@ -48,3 +49,9 @@ allowed_failures: - "'Image' object has no attribute 'bin'" - "has no attribute 'shoot'" - "module 'jax_galsim' has no attribute 'integ'" + - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" + - "module 'jax_galsim' has no attribute 'GaussianDeviate'" + - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" + - "'Image' object has no attribute 'FindAdaptiveMom'" + - " module 'jax_galsim' has no attribute '_InterpolatedImage'" + - "JAX arrays are immutable." diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py new file mode 100644 index 00000000..a584732f --- /dev/null +++ b/tests/jax/test_interpolatedimage_utils.py @@ -0,0 +1,64 @@ +import galsim as _galsim +import jax_galsim +import numpy as np +import jax.numpy as jnp +from jax_galsim.interpolatedimage import _draw_with_interpolant_xval +from jax_galsim.interpolant import ( + Nearest, + # SincInterpolant, + Linear, + Cubic, + Quintic, + Lanczos, +) + +import pytest + + +@pytest.mark.parametrize("interp", [ + Nearest(), + Linear(), + # this is really slow right now and I am not sure why will fix later + # SincInterpolant(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=False), + Lanczos(5, conserve_dc=True), +]) +def test_interpolatedimage_utils_draw_with_interpolant_xval(interp): + zp = jnp.array([ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01] + ]) + for xmin in [-3, 0, 2]: + for ymin in [-5, 0, 1]: + for x in range(4): + for y in range(4): + np.testing.assert_allclose( + _draw_with_interpolant_xval( + jnp.array([x + xmin], dtype=float), + jnp.array([y + ymin], dtype=float), + xmin, ymin, zp, interp, + ), + zp[y, x], + ) + + +def test_interpolatedimage_utils_stepk_maxk(): + ref_array = np.array([ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01] + ]) + test_scale = 2.0 + gimage_in = _galsim.Image(ref_array) + jgimage_in = jax_galsim.Image(ref_array) + gii = _galsim.InterpolatedImage(gimage_in, scale=test_scale) + jgii = jax_galsim.InterpolatedImage(jgimage_in, scale=test_scale) + + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.2, atol=0) + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.2, atol=0) From f537ace5cc9339a9d99fe26e546568fe8624075c Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 25 Sep 2023 13:27:14 -0500 Subject: [PATCH 10/67] STY black, sort, and flake --- jax_galsim/bessel.py | 3 +- jax_galsim/interpolant.py | 5 +- jax_galsim/interpolatedimage.py | 29 ++++---- tests/conftest.py | 2 +- tests/jax/galsim/test_image_jax.py | 6 +- tests/jax/galsim/test_interpolant_jax.py | 12 ++-- tests/jax/galsim/test_shear_position_jax.py | 5 +- tests/jax/test_interpolatedimage_utils.py | 77 ++++++++++++--------- tests/jax/test_temporary_tests.py | 24 +++---- 9 files changed, 87 insertions(+), 76 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index bc809c79..3b5f4950 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -1,9 +1,8 @@ +import galsim as _galsim import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps -import galsim as _galsim - # the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp @jax.jit diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index dec125dd..7affe1fe 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -5,15 +5,14 @@ shapes, the integrals of the kernels, etc.) are constants. """ import galsim as _galsim -from galsim.errors import GalSimValueError - import jax import jax.numpy as jnp +from galsim.errors import GalSimValueError from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.gsparams import GSParams from jax_galsim.bessel import si +from jax_galsim.gsparams import GSParams @_wraps(_galsim.interpolant.Interpolant) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 9f2e3fc5..3f09db5e 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1,26 +1,30 @@ +import math import textwrap -import jax from functools import partial -import jax.numpy as jnp -from jax._src.numpy.util import _wraps -from jax.tree_util import register_pytree_node_class -import math import galsim as _galsim +import jax +import jax.numpy as jnp +from galsim.errors import ( + GalSimIncompatibleValuesError, + GalSimRangeError, + GalSimUndefinedBoundsError, + GalSimValueError, +) from galsim.utilities import doc_inherit -from galsim.errors import GalSimRangeError, GalSimValueError, GalSimUndefinedBoundsError -from galsim.errors import GalSimIncompatibleValuesError +from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class +from jax_galsim import fits +from jax_galsim.bounds import BoundsI +from jax_galsim.core.draw import draw_by_xValue from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.image import Image -from jax_galsim.position import PositionD from jax_galsim.interpolant import Quintic -from jax_galsim.utilities import convert_interpolant -from jax_galsim.bounds import BoundsI -from jax_galsim import fits -from jax_galsim.core.draw import draw_by_xValue +from jax_galsim.position import PositionD from jax_galsim.transform import Transformation +from jax_galsim.utilities import convert_interpolant from jax_galsim.wcs import PixelScale @@ -203,7 +207,6 @@ def __init__( # FIXME: no BaseDeviate in jax_galsim # from .random import BaseDeviate - # If the "image" is not actually an image, try to read the image as a file. if isinstance(image, str): image = fits.read(image, hdu=hdu) diff --git a/tests/conftest.py b/tests/conftest.py index 3e2ae681..7cf4fd75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ +import inspect import os from functools import lru_cache -import inspect import pytest import yaml diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index 0615c748..945d2da2 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -44,6 +44,7 @@ """ from __future__ import print_function + import os import sys from unicodedata import decimal @@ -51,11 +52,10 @@ sys.path.insert( 1, os.path.abspath(os.path.join(os.path.dirname(__file__), "../GalSim/tests")) ) -import numpy as np - import galsim -from galsim_test_helpers import * +import numpy as np from galsim._pyfits import pyfits +from galsim_test_helpers import * # Setup info for tests, not likely to change ntypes = 8 # Note: Most tests below only run through the first 8 types. diff --git a/tests/jax/galsim/test_interpolant_jax.py b/tests/jax/galsim/test_interpolant_jax.py index 09f88794..07cbe207 100644 --- a/tests/jax/galsim/test_interpolant_jax.py +++ b/tests/jax/galsim/test_interpolant_jax.py @@ -3,15 +3,17 @@ Much of the code is copied out of the galsim test suite. """ -import jax +import pickle import time + import galsim as ref_galsim +import jax import numpy as np -import jax_galsim as galsim -from galsim_test_helpers import timer, assert_raises -from scipy.special import sici -import pickle import pytest +from galsim_test_helpers import assert_raises, timer +from scipy.special import sici + +import jax_galsim as galsim def do_pickle(obj1): diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py index d8265406..1ba07dab 100644 --- a/tests/jax/galsim/test_shear_position_jax.py +++ b/tests/jax/galsim/test_shear_position_jax.py @@ -1,7 +1,6 @@ -import numpy as np - import galsim -from galsim_test_helpers import timer, assert_raises +import numpy as np +from galsim_test_helpers import assert_raises, timer @timer diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index a584732f..f91eac19 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -1,38 +1,42 @@ import galsim as _galsim -import jax_galsim -import numpy as np import jax.numpy as jnp -from jax_galsim.interpolatedimage import _draw_with_interpolant_xval -from jax_galsim.interpolant import ( - Nearest, - # SincInterpolant, - Linear, +import numpy as np +import pytest + +import jax_galsim +from jax_galsim.interpolant import ( # SincInterpolant, Cubic, - Quintic, Lanczos, + Linear, + Nearest, + Quintic, ) - -import pytest +from jax_galsim.interpolatedimage import _draw_with_interpolant_xval -@pytest.mark.parametrize("interp", [ - Nearest(), - Linear(), - # this is really slow right now and I am not sure why will fix later - # SincInterpolant(), - Linear(), - Cubic(), - Quintic(), - Lanczos(3, conserve_dc=False), - Lanczos(5, conserve_dc=True), -]) +@pytest.mark.parametrize( + "interp", + [ + Nearest(), + Linear(), + # this is really slow right now and I am not sure why will fix later + # SincInterpolant(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=False), + Lanczos(5, conserve_dc=True), + ], +) def test_interpolatedimage_utils_draw_with_interpolant_xval(interp): - zp = jnp.array([ - [0.01, 0.08, 0.07, 0.02], - [0.13, 0.38, 0.52, 0.06], - [0.09, 0.41, 0.44, 0.09], - [0.04, 0.11, 0.10, 0.01] - ]) + zp = jnp.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + ] + ) for xmin in [-3, 0, 2]: for ymin in [-5, 0, 1]: for x in range(4): @@ -41,19 +45,24 @@ def test_interpolatedimage_utils_draw_with_interpolant_xval(interp): _draw_with_interpolant_xval( jnp.array([x + xmin], dtype=float), jnp.array([y + ymin], dtype=float), - xmin, ymin, zp, interp, + xmin, + ymin, + zp, + interp, ), zp[y, x], ) def test_interpolatedimage_utils_stepk_maxk(): - ref_array = np.array([ - [0.01, 0.08, 0.07, 0.02], - [0.13, 0.38, 0.52, 0.06], - [0.09, 0.41, 0.44, 0.09], - [0.04, 0.11, 0.10, 0.01] - ]) + ref_array = np.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + ] + ) test_scale = 2.0 gimage_in = _galsim.Image(ref_array) jgimage_in = jax_galsim.Image(ref_array) diff --git a/tests/jax/test_temporary_tests.py b/tests/jax/test_temporary_tests.py index 3c55af1f..79c24f14 100644 --- a/tests/jax/test_temporary_tests.py +++ b/tests/jax/test_temporary_tests.py @@ -1,9 +1,9 @@ -import numpy as np import galsim as ref_galsim -import jax_galsim - +import numpy as np import pytest +import jax_galsim + def test_convolve_temp(): """Validates convolutions against reference GalSim @@ -100,24 +100,24 @@ def func(galsim): def test_pickling_eval_repr(obj1): """This test is here until we run all of the galsim tests which cover this one.""" # test copied from galsim - from numbers import Integral, Real, Complex # noqa: F401 - import pickle import copy + import pickle + from collections.abc import Hashable + from numbers import Complex, Integral, Real # noqa: F401 # In case the repr uses these: from numpy import ( # noqa: F401 array, - uint16, - uint32, - int16, - int32, - float32, - float64, complex64, complex128, + float32, + float64, + int16, + int32, ndarray, + uint16, + uint32, ) - from collections.abc import Hashable def func(x): return x From cad9dc3df0d0482b9e9f7ad97d681bd7cf732802 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 27 Sep 2023 05:49:47 -0500 Subject: [PATCH 11/67] WIP what exists so far --- jax_galsim/__init__.py | 14 +- jax_galsim/core/utils.py | 11 ++ jax_galsim/exponential.py | 24 +-- jax_galsim/image.py | 9 +- jax_galsim/interpolant.py | 53 ++---- jax_galsim/interpolatedimage.py | 112 +++++++++++-- jax_galsim/moffat.py | 64 ++++---- jax_galsim/transform.py | 2 +- tests/galsim_tests_config.yaml | 1 - tests/jax/test_interpolatedimage_utils.py | 187 +++++++++++++++++++++- tests/jax/test_temporary_tests.py | 3 + 11 files changed, 369 insertions(+), 111 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 1f7f9fc0..c361637c 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -21,9 +21,6 @@ GalSimWarning, ) -# Bessel -from .core.bessel import j0 - # Basic building blocks from .bounds import Bounds, BoundsD, BoundsI from .gsparams import GSParams @@ -47,20 +44,14 @@ from .gaussian import Gaussian from .box import Box, Pixel from .gsobject import GSObject - - -# Integration -from .core.integrate import ClenshawCurtisQuad, quad_integral - -# Interpolation from .moffat import Moffat - from .sum import Add, Sum from .transform import Transform, Transformation from .convolve import Convolve, Convolution # WCS from .wcs import ( + BaseWCS, AffineTransform, JacobianWCS, OffsetWCS, @@ -89,3 +80,6 @@ # packages kept separate from . import bessel from . import fits + +# this one is specific to jax_galsim +from . import core diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 6fd35532..0ac0270d 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -25,3 +25,14 @@ def cast_scalar_to_int(x): return int(x) except TypeError: return x + + +def ensure_hashable(v): + """Ensure that the input is hashable. If it is a jax array, convert it to a tuple or float.""" + if isinstance(v, jax.Array): + if len(v.shape) > 0: + return tuple(v.tolist()) + else: + return v.item() + else: + return v diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 99ebb714..f4fe054a 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -4,6 +4,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue +from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams @@ -38,7 +39,9 @@ def __init__( ) else: super().__init__( - half_light_radius=half_light_radius, flux=flux, gsparams=gsparams + scale_radius=half_light_radius / Exponential._hlr_factor, + flux=flux, + gsparams=gsparams, ) elif scale_radius is None: @@ -53,10 +56,7 @@ def __init__( @property def scale_radius(self): """The scale radius of the profile.""" - if "half_light_radius" in self.params: - return self.params["half_light_radius"] / Exponential._hlr_factor - else: - return self.params["scale_radius"] + return self.params["scale_radius"] @property def _r0(self): @@ -73,13 +73,17 @@ def _norm(self): @property def half_light_radius(self): """The half-light radius of the profile.""" - if "half_light_radius" in self.params: - return self.params["half_light_radius"] - else: - return self.params["scale_radius"] * Exponential._hlr_factor + return self.params["scale_radius"] * Exponential._hlr_factor def __hash__(self): - return hash(("galsim.Exponential", self.scale_radius, self.flux, self.gsparams)) + return hash( + ( + "galsim.Exponential", + ensure_hashable(self.scale_radius), + ensure_hashable(self.flux), + self.gsparams, + ) + ) def __repr__(self): return "galsim.Exponential(scale_radius=%r, flux=%r, gsparams=%r)" % ( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 083dc422..f27a5752 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -679,7 +679,7 @@ def calculate_fft(self): dk = jnp.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._array = jnp.fft.rfft2(ximage.array) + out._array = jnp.fft.fftshift(jnp.fft.rfft2(ximage.array), axes=0) out *= dx * dx out.setOrigin(0, -No2) @@ -706,7 +706,10 @@ def calculate_inverse_fft(self): self.bounds, ) - No2 = jnp.maximum(self.bounds.xmax, -self.bounds.ymin, self.bounds.ymax) + No2 = jnp.maximum( + jnp.maximum(self.bounds.xmax, -self.bounds.ymin), + self.bounds.ymax, + ) target_bounds = BoundsI(0, No2, -No2, No2 - 1) if self.bounds == target_bounds: @@ -728,7 +731,7 @@ def calculate_inverse_fft(self): # For the inverse, we need a bit of extra space for the fft. out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) - out_extra._array = jnp.fft.irfft2(kimage.array) + out_extra._array = jnp.fft.irfft2(jnp.fft.ifftshift(kimage.array, axes=0)) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) out *= (dk * No2 / jnp.pi) ** 2 diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 7affe1fe..a0a8b4df 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -157,8 +157,11 @@ def xval(self, x): an array. """ if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) + raise GalSimValueError("xval only takes scalar or 1D array values", x) + return self._xval_noraise(x) + + def _xval_noraise(self, x): return self.__class__._xval(x) def kval(self, k): @@ -175,6 +178,9 @@ def kval(self, k): if jnp.ndim(k) > 1: raise GalSimValueError("kval only takes scalar or 1D array values", k) + return self._kval_noraise(k) + + def _kval_noraise(self, k): return self.__class__._uval(k / 2.0 / jnp.pi) def unit_integrals(self, max_len=None): @@ -266,20 +272,7 @@ def __init__(self, tol=None, gsparams=None): gsparams = GSParams(kvalue_accuracy=tol) self._gsparams = GSParams.check(gsparams) - def xval(self, x): - """Calculate the value of the interpolant kernel at one or more x values - - Parameters: - x: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel. - - Returns: - xval: The value(s) at the x location(s). If x was an array, then this is also - an array. - """ - if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) - + def _xval_noraise(self, x): return Delta._xval(x, self._gsparams.kvalue_accuracy) @jax.jit @@ -1504,20 +1497,7 @@ def _no_dcval(val, x, n, _K): _K, ) - def xval(self, x): - """Calculate the value of the interpolant kernel at one or more x values - - Parameters: - x: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel. - - Returns: - xval: The value(s) at the x location(s). If x was an array, then this is also - an array. - """ - if jnp.ndim(x) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", x) - + def _xval_noraise(self, x): return Lanczos._xval(x, self._n, self._conserve_dc, self._K_arr) def _raw_uval(u, n): @@ -1574,20 +1554,7 @@ def _no_dcval(retval, u, n, _C): _C, ) - def kval(self, k): - """Calculate the value of the interpolant kernel in Fourier space at one or more k values. - - Parameters: - k: The value (as a float) or values (as a np.array) at which to compute the - amplitude of the Interpolant kernel in Fourier space. - - Returns: - kval: The k-value(s) at the k location(s). If k was an array, then this is also - an array. - """ - if jnp.ndim(k) > 1: - raise GalSimValueError("kval only takes scalar or 1D array values", k) - + def _kval_noraise(self, k): return Lanczos._uval(k / 2.0 / jnp.pi, self._n, self._conserve_dc, self._C_arr) def urange(self): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3f09db5e..2c4e16cb 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -17,7 +17,7 @@ from jax_galsim import fits from jax_galsim.bounds import BoundsI -from jax_galsim.core.draw import draw_by_xValue +from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.image import Image @@ -135,6 +135,15 @@ def image(self): """The underlying `Image` being interpolated.""" return self._original._image + def __hash__(self): + return hash(self._original) + + def __repr__(self): + return repr(self._original) + + def __str__(self): + return str(self._original) + @register_pytree_node_class class InterpolatedImageImpl(GSObject): @@ -578,7 +587,14 @@ def __hash__(self): self._hash = hash( ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant) ) - self._hash ^= hash((self.flux, self._stepk, self._maxk, self._pad_factor)) + self._hash ^= hash( + ( + self.flux.item(), + self._stepk.item(), + self._maxk.item(), + self._pad_factor, + ) + ) self._hash ^= hash( (self._xim.bounds, self._image.bounds, self._pad_image.bounds) ) @@ -592,7 +608,7 @@ def __hash__(self): self._hash ^= hash(self._wcs) # Just hash the diagonal. Much faster, and usually is unique enough. # (Let python handle collisions as needed if multiple similar IIs are used as keys.) - self._hash ^= hash(tuple(jnp.diag(self._pad_image.array))) + self._hash ^= hash(tuple(jnp.diag(self._pad_image.array).tolist())) return self._hash def __repr__(self): @@ -699,7 +715,34 @@ def _xValue(self, pos): return vals[0] def _kValue(self, kpos): - raise NotImplementedError("WIP interp - kValue") + # phase factor due to offset + # not we shift by -offset which explains the signs + # in pkx, pky + pkx = kpos.x * 1j * self._offset.x + pky = kpos.y * 1j * self._offset.y + pkx += pky + pfac = jnp.exp(pkx) + + kx = jnp.array([kpos.x / self._kim.scale], dtype=float) + ky = jnp.array([kpos.y / self._kim.scale], dtype=float) + + _uscale = 1.0 / (2.0 * jnp.pi) + _maxk_xint = self._x_interpolant.urange() / _uscale / self._kim.scale + + val = _draw_with_interpolant_kval( + kx, + ky, + self._kim.bounds.ymin, + self._kim.bounds.ymin, + self._kim.array, + self._k_interpolant, + ) + + msk = (jnp.abs(kx) <= _maxk_xint) & (jnp.abs(ky) <= _maxk_xint) + xint_val = self._x_interpolant._kval_noraise( + kx * self._kim.scale + ) * self._x_interpolant._kval_noraise(ky * self._kim.scale) + return jnp.where(msk, val * xint_val * pfac, 0.0)[0] def _shoot(self, photons, rng): raise NotImplementedError("Photon shooting not implemented.") @@ -709,9 +752,8 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling) def _drawKImage(self, image, jac=None): - raise NotImplementedError("WIP interp - drawKImage") - _jac = 0 if jac is None else jac.__array_interface__["data"][0] - self._sbp.drawK(image._image, image.scale, _jac) + _jac = jnp.eye(2) if jac is None else jac + return draw_by_kValue(self, image, _jac) @partial(jax.jit, static_argnums=(5,)) @@ -733,7 +775,7 @@ def _body_1d(i, args): xind = xi + i mskx = (xind >= 0) & (xind < nx) _x = x - (xp + i) - wx = interp.xval(_x) + wx = interp._xval_noraise(_x) w = wx * wy msk = msky & mskx @@ -746,7 +788,7 @@ def _body(i, args): yind = yi + i msk = (yind >= 0) & (yind < ny) _y = y - (yp + i) - wy = interp.xval(_y) + wy = interp._xval_noraise(_y) z = jax.lax.fori_loop( -interp.xrange, interp.xrange + 1, _body_1d, [z, wy, msk, yind, xi, xp, zp] )[0] @@ -761,6 +803,58 @@ def _body(i, args): return z.reshape(orig_shape) +@partial(jax.jit, static_argnums=(5,)) +def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): + orig_shape = kx.shape + kx = kx.ravel() + kxi = jnp.floor(kx - kxmin).astype(jnp.int32) + kxp = kxi + kxmin + nkx_2 = zp.shape[1] - 1 + nkx = nkx_2 * 2 + 1 + + ky = ky.ravel() + kyi = jnp.floor(ky - kymin).astype(jnp.int32) + kyp = kyi + kymin + nky = zp.shape[0] + + def _body_1d(i, args): + z, wky, kyind, kxi, nkx, nkx_2, kxp, zp = args + + kxind = (kxi + i) % nkx + _kx = kx - (kxp + i) + wkx = interp._xval_noraise(_kx) + + val = jnp.where( + kxind < nkx_2, + zp[nky - 1 - kyind, nkx - 1 - kxind + nkx_2].conjugate(), + zp[kyind, kxind - nkx_2], + ) + z += val * wkx * wky + + return [z, wky, kyind, kxi, nkx, nkx_2, kxp, zp] + + def _body(i, args): + z, kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp = args + kyind = (kyi + i) % nky + _ky = ky - (kyp + i) + wky = interp._xval_noraise(_ky) + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body_1d, + [z, wky, kyind, kxi, nkx, nkx_2, kxp, zp], + )[0] + return [z, kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp] + + z = jax.lax.fori_loop( + -interp.xrange, + interp.xrange + 1, + _body, + [jnp.zeros(kx.shape, dtype=zp.dtype), kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp], + )[0] + return z.reshape(orig_shape) + + @jax.jit def _flux_frac(a, x, y, cenx, ceny): def _body(d, args): diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index a3a1ae88..27afbf29 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -11,6 +11,7 @@ from jax_galsim.core.bessel import j0 from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral +from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams @@ -101,9 +102,15 @@ def __init__( fwhm=fwhm, ) else: + r0 = jax.lax.select( + trunc > 0, + MoffatCalculateSRFromHLR(half_light_radius, trunc, beta), + half_light_radius + / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), + ) super().__init__( beta=beta, - half_light_radius=half_light_radius, + scale_radius=r0, trunc=trunc, flux=flux, gsparams=gsparams, @@ -118,7 +125,11 @@ def __init__( ) else: super().__init__( - beta=beta, fwhm=fwhm, trunc=trunc, flux=flux, gsparams=gsparams + beta=beta, + scale_radius=fwhm / (2.0 * jnp.sqrt(2.0 ** (1.0 / beta) - 1.0)), + trunc=trunc, + flux=flux, + gsparams=gsparams, ) elif scale_radius is None: raise _galsim.GalSimIncompatibleValuesError( @@ -139,29 +150,16 @@ def __init__( @property def beta(self): """The beta parameter of this `Moffat` profile.""" - return self.params["beta"] + return self._params["beta"] @property def trunc(self): """The truncation radius (if any) of this `Moffat` profile.""" - return self.params["trunc"] + return self._params["trunc"] @property def scale_radius(self): - """The scale radius of this `Moffat` profile.""" - if "half_light_radius" in self.params: - hlr = self.params["half_light_radius"] - return jax.lax.select( - self.trunc > 0, - MoffatCalculateSRFromHLR(hlr, self.trunc, self.beta), - hlr / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - self.beta)) - 1.0), - ) - elif "fwhm" in self.params: - return self.params["fwhm"] / ( - 2.0 * jnp.sqrt(2.0 ** (1.0 / self.beta) - 1.0) - ) - else: - return self.params["scale_radius"] + return self._params["scale_radius"] @property def _r0(self): @@ -210,20 +208,14 @@ def _fluxFactor(self): @property def half_light_radius(self): """The half-light radius of this `Moffat` profile.""" - if "half_light_radius" in self.params: - return self.params["half_light_radius"] - else: - return self._r0 * jnp.sqrt( - jnp.power(1.0 - 0.5 * self._fluxFactor, 1.0 / (1.0 - self.beta)) - 1.0 - ) + return self._r0 * jnp.sqrt( + jnp.power(1.0 - 0.5 * self._fluxFactor, 1.0 / (1.0 - self.beta)) - 1.0 + ) @property def fwhm(self): """The FWHM of this `Moffat` profle.""" - if "fwhm" in self.params: - return self.params["fwhm"] - else: - return self._r0 * (2.0 * jnp.sqrt(2.0 ** (1.0 / self.beta) - 1.0)) + return self._r0 * (2.0 * jnp.sqrt(2.0 ** (1.0 / self.beta) - 1.0)) @property def _norm(self): @@ -247,10 +239,10 @@ def __hash__(self): return hash( ( "galsim.Moffat", - self.beta, - self.scale_radius, - self.trunc, - self.flux, + ensure_hashable(self.beta), + ensure_hashable(self.scale_radius), + ensure_hashable(self.trunc), + ensure_hashable(self.flux), self.gsparams, ) ) @@ -258,7 +250,13 @@ def __hash__(self): def __repr__(self): return ( "galsim.Moffat(beta=%r, scale_radius=%r, trunc=%r, flux=%r, gsparams=%r)" - % (self.beta, self.scale_radius, self.trunc, self.flux, self.gsparams) + % ( + self.beta, + ensure_hashable(self.scale_radius), + self.trunc, + self.flux, + self.gsparams, + ) ) def __str__(self): diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 53055ee8..58b00ab4 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -133,7 +133,7 @@ def __hash__(self): ( "galsim.Transformation", self.original, - tuple(self._jac.ravel()), + tuple(self._jac.ravel().tolist()), self.offset.x, self.offset.y, self.flux_ratio, diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 5dee45cb..faae387c 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -46,7 +46,6 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'rand_with_replacement'" - "module 'jax_galsim.utilities' has no attribute 'dol_to_lod'" - "module 'jax_galsim.utilities' has no attribute 'nCr'" - - "unhashable type: 'ArrayImpl'" - "'Image' object has no attribute 'bin'" - "has no attribute 'shoot'" - "module 'jax_galsim' has no attribute 'integ'" diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index f91eac19..537047c2 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -11,7 +11,10 @@ Nearest, Quintic, ) -from jax_galsim.interpolatedimage import _draw_with_interpolant_xval +from jax_galsim.interpolatedimage import ( + _draw_with_interpolant_kval, + _draw_with_interpolant_xval, +) @pytest.mark.parametrize( @@ -54,6 +57,51 @@ def test_interpolatedimage_utils_draw_with_interpolant_xval(interp): ) +@pytest.mark.parametrize( + "interp", + [ + Nearest(), + Linear(), + # this is really slow right now and I am not sure why will fix later + # SincInterpolant(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=False), + Lanczos(5, conserve_dc=True), + ], +) +def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): + zp = jnp.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 0.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + ] + ) + kim = jax_galsim.Image(zp, scale=1.0).calculate_fft() + zpherm = jnp.concatenate( + [kim.array[:, 1:][::-1, ::-1].conjugate(), kim.array], axis=1 + ) + nherm = kim.array.shape[0] + minherm = kim.bounds.ymin + + for x in range(nherm): + for y in range(nherm): + np.testing.assert_allclose( + _draw_with_interpolant_kval( + jnp.array([x + minherm], dtype=float), + jnp.array([y + minherm], dtype=float), + minherm, + minherm, + kim.array, + interp, + ), + zpherm[y, x], + ) + + def test_interpolatedimage_utils_stepk_maxk(): ref_array = np.array( [ @@ -71,3 +119,140 @@ def test_interpolatedimage_utils_stepk_maxk(): np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.2, atol=0) np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.2, atol=0) + + +@pytest.mark.parametrize("method", ["xValue", "kValue"]) +def test_interpolated_image_utils_comp_to_galsim(method): + ref_array = np.array( + [ + [0.01, 0.08, 0.07, 0.02, 0.0, 0.0], + [0.13, 0.38, 0.52, 0.06, 0.0, 0.05], + [0.09, 0.41, 0.44, 0.09, 0.0, 0.2], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.5], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.3], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.1], + ] + ) + gimage_in = _galsim.Image(ref_array, scale=1) + jgimage_in = jax_galsim.Image(ref_array, scale=1) + + for wcs in [ + _galsim.PixelScale(1.0), + _galsim.JacobianWCS(2.1, 0.3, -0.4, 2.3), + _galsim.AffineTransform(-0.3, 2.1, 1.8, 0.1, _galsim.PositionD(0.3, -0.4)), + ]: + gii = _galsim.InterpolatedImage(gimage_in, wcs=wcs) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, wcs=jax_galsim.BaseWCS.from_galsim(wcs) + ) + + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) + kxvals = [(0, 0), (1, 1), (1, -2), (-3, 4)] + for x, y in kxvals: + if method == "kValue": + dk = jgii._original._kim.scale + np.testing.assert_allclose( + gii.kValue(x * dk, y * dk), + jgii.kValue(x * dk, y * dk), + err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y} kim={jgii._original._kim(x, y)}", + ) + else: + dx = jnp.sqrt(jgii._original._wcs.pixelArea()) + np.testing.assert_allclose( + gii.xValue(x * dx, y * dx), + jgii.xValue(x * dx, y * dx), + err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", + ) + + +def _compute_fft_with_numpy_jax_galsim(im): + import numpy as np + + from jax_galsim import BoundsI, Image + + No2 = max(-im.bounds.xmin, im.bounds.xmax + 1, -im.bounds.ymin, im.bounds.ymax + 1) + + full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) + if im.bounds == full_bounds: + # Then the image is already in the shape we need. + ximage = im + else: + # Then we pad out with zeros + ximage = Image(full_bounds, dtype=im.dtype, init_value=0) + ximage[im.bounds] = im[im.bounds] + + dx = im.scale + # dk = 2pi / (N dk) + dk = np.pi / (No2 * dx) + + out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) + out._array = np.fft.fftshift(np.fft.rfft2(ximage.array), axes=0) + out *= dx * dx + out.setOrigin(0, -No2) + return out + + +@pytest.mark.parametrize("n", [5, 4]) +def test_jax_galsim_fft_vs_numpy(n): + import numpy as np + + import jax_galsim as galsim + + rng = np.random.RandomState(42) + arr = rng.normal(size=(n, n)) + im = galsim.Image(arr, scale=1) + kim = im.calculate_fft() + xkim = kim.calculate_inverse_fft() + + np.testing.assert_allclose(im.array, xkim[im.bounds].array) + + np_kim = _compute_fft_with_numpy_jax_galsim(im) + print("ratio real:\n", np_kim.array.real / kim.array.real) + print("ratio imag:\n", np_kim.array.imag / kim.array.imag) + np.testing.assert_allclose(kim.array, np_kim.array) + + +def _compute_fft_with_numpy_galsim(im): + import numpy as np + from galsim import BoundsI, Image + + No2 = max(-im.bounds.xmin, im.bounds.xmax + 1, -im.bounds.ymin, im.bounds.ymax + 1) + + full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) + if im.bounds == full_bounds: + # Then the image is already in the shape we need. + ximage = im + else: + # Then we pad out with zeros + ximage = Image(full_bounds, dtype=im.dtype, init_value=0) + ximage[im.bounds] = im[im.bounds] + + dx = im.scale + # dk = 2pi / (N dk) + dk = np.pi / (No2 * dx) + + out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) + out._array = np.fft.fftshift(np.fft.rfft2(ximage.array), axes=0) + out *= dx * dx + out.setOrigin(0, -No2) + return out + + +@pytest.mark.parametrize("n", [5, 4]) +def test_galsim_fft_vs_numpy(n): + import galsim + import numpy as np + + rng = np.random.RandomState(42) + arr = rng.normal(size=(n, n)) + im = galsim.Image(arr, scale=1) + kim = im.calculate_fft() + xkim = kim.calculate_inverse_fft() + + np.testing.assert_allclose(im.array, xkim[im.bounds].array) + + np_kim = _compute_fft_with_numpy_galsim(im) + print("ratio real:\n", np_kim.array.real / kim.array.real) + print("ratio imag:\n", np_kim.array.imag / kim.array.imag) + np.testing.assert_allclose(kim.array, np_kim.array) diff --git a/tests/jax/test_temporary_tests.py b/tests/jax/test_temporary_tests.py index 79c24f14..252b4803 100644 --- a/tests/jax/test_temporary_tests.py +++ b/tests/jax/test_temporary_tests.py @@ -82,6 +82,9 @@ def func(galsim): jax_galsim.Gaussian(fwhm=1.0), jax_galsim.Pixel(scale=1.0), jax_galsim.Exponential(scale_radius=1.0), + jax_galsim.Exponential(half_light_radius=1.0), + jax_galsim.Moffat(fwhm=1.0, beta=3), + jax_galsim.Moffat(scale_radius=1.0, beta=3), jax_galsim.Shear(g1=0.1, g2=0.2), jax_galsim.PositionD(x=0.1, y=0.2), jax_galsim.BoundsI(xmin=0, xmax=1, ymin=0, ymax=1), From 37c8f257169164c69be2c9c8cff58bddbb734c0d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 27 Sep 2023 12:19:20 -0500 Subject: [PATCH 12/67] WIP it works --- jax_galsim/image.py | 8 +- jax_galsim/interpolatedimage.py | 4 +- tests/jax/test_interpolatedimage_utils.py | 137 +++++++++++----------- 3 files changed, 74 insertions(+), 75 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f27a5752..6b5f926e 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -679,7 +679,9 @@ def calculate_fft(self): dk = jnp.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._array = jnp.fft.fftshift(jnp.fft.rfft2(ximage.array), axes=0) + out._array = jnp.fft.fftshift( + jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 + ) out *= dx * dx out.setOrigin(0, -No2) @@ -731,7 +733,9 @@ def calculate_inverse_fft(self): # For the inverse, we need a bit of extra space for the fft. out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) - out_extra._array = jnp.fft.irfft2(jnp.fft.ifftshift(kimage.array, axes=0)) + out_extra._array = jnp.fft.fftshift( + jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) + ) # Now cut off the bit we don't need. out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) out *= (dk * No2 / jnp.pi) ** 2 diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 2c4e16cb..449a0289 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -810,7 +810,7 @@ def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): kxi = jnp.floor(kx - kxmin).astype(jnp.int32) kxp = kxi + kxmin nkx_2 = zp.shape[1] - 1 - nkx = nkx_2 * 2 + 1 + nkx = nkx_2 * 2 ky = ky.ravel() kyi = jnp.floor(ky - kymin).astype(jnp.int32) @@ -826,7 +826,7 @@ def _body_1d(i, args): val = jnp.where( kxind < nkx_2, - zp[nky - 1 - kyind, nkx - 1 - kxind + nkx_2].conjugate(), + zp[(nky - kyind) % nky, nkx - kxind - nkx_2].conjugate(), zp[kyind, kxind - nkx_2], ) z += val * wkx * wky diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 537047c2..b0547bf2 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -81,12 +81,22 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): ] ) kim = jax_galsim.Image(zp, scale=1.0).calculate_fft() - zpherm = jnp.concatenate( - [kim.array[:, 1:][::-1, ::-1].conjugate(), kim.array], axis=1 - ) nherm = kim.array.shape[0] minherm = kim.bounds.ymin - + kimherm = jax_galsim.Image( + jnp.zeros((kim.array.shape[0], kim.array.shape[0]), dtype=complex), + xmin=minherm, + ymin=minherm, + ) + for y in range(kimherm.bounds.ymin, kimherm.bounds.ymax + 1): + for x in range(kimherm.bounds.xmin, kimherm.bounds.xmax + 1): + if x >= 0: + kimherm[x, y] = kim[x, y] + else: + if y == minherm: + kimherm[x, y] = kim[-x, y].conj() + else: + kimherm[x, y] = kim[-x, -y].conj() for x in range(nherm): for y in range(nherm): np.testing.assert_allclose( @@ -98,7 +108,7 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): kim.array, interp, ), - zpherm[y, x], + kimherm(x + minherm, y + minherm), ) @@ -121,23 +131,37 @@ def test_interpolatedimage_utils_stepk_maxk(): np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.2, atol=0) +@pytest.mark.parametrize( + "ref_array", + [ + np.array( + [ + [0.01, 0.08, 0.07, 0.02, 0.0, 0.0], + [0.13, 0.38, 0.52, 0.06, 0.0, 0.05], + [0.09, 0.41, 0.44, 0.09, 0.0, 0.2], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.5], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.3], + [0.04, 0.11, 0.10, 0.01, 0.0, 0.1], + ] + ), + np.array( + [ + [0.01, 0.08, 0.07, 0.02, 0.0], + [0.13, 0.38, 0.52, 0.06, 0.0], + [0.09, 0.41, 0.44, 0.09, 0.0], + [0.04, 0.11, 0.10, 0.01, 0.0], + [0.04, 0.11, 0.10, 0.01, 0.0], + ] + ), + ], +) @pytest.mark.parametrize("method", ["xValue", "kValue"]) -def test_interpolated_image_utils_comp_to_galsim(method): - ref_array = np.array( - [ - [0.01, 0.08, 0.07, 0.02, 0.0, 0.0], - [0.13, 0.38, 0.52, 0.06, 0.0, 0.05], - [0.09, 0.41, 0.44, 0.09, 0.0, 0.2], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.5], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.3], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.1], - ] - ) +def test_interpolatedimage_utils_comp_to_galsim(method, ref_array): gimage_in = _galsim.Image(ref_array, scale=1) jgimage_in = jax_galsim.Image(ref_array, scale=1) for wcs in [ - _galsim.PixelScale(1.0), + _galsim.PixelScale(2.0), _galsim.JacobianWCS(2.1, 0.3, -0.4, 2.3), _galsim.AffineTransform(-0.3, 2.1, 1.8, 0.1, _galsim.PositionD(0.3, -0.4)), ]: @@ -148,14 +172,27 @@ def test_interpolated_image_utils_comp_to_galsim(method): np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) - kxvals = [(0, 0), (1, 1), (1, -2), (-3, 4)] + kxvals = [ + (0, 0), + (-5, -5), + (-10, 10), + (1, 1), + (1, -2), + (-1, 0), + (0, -1), + (-1, -1), + (-2, 2), + (-5, 0), + (3, -4), + (-3, 4), + ] for x, y in kxvals: if method == "kValue": dk = jgii._original._kim.scale np.testing.assert_allclose( gii.kValue(x * dk, y * dk), jgii.kValue(x * dk, y * dk), - err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y} kim={jgii._original._kim(x, y)}", + err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", ) else: dx = jnp.sqrt(jgii._original._wcs.pixelArea()) @@ -187,72 +224,30 @@ def _compute_fft_with_numpy_jax_galsim(im): dk = np.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._array = np.fft.fftshift(np.fft.rfft2(ximage.array), axes=0) + out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) out *= dx * dx out.setOrigin(0, -No2) return out @pytest.mark.parametrize("n", [5, 4]) -def test_jax_galsim_fft_vs_numpy(n): - import numpy as np - - import jax_galsim as galsim - +def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): rng = np.random.RandomState(42) arr = rng.normal(size=(n, n)) - im = galsim.Image(arr, scale=1) + im = jax_galsim.Image(arr, scale=1) kim = im.calculate_fft() xkim = kim.calculate_inverse_fft() - np.testing.assert_allclose(im.array, xkim[im.bounds].array) np_kim = _compute_fft_with_numpy_jax_galsim(im) - print("ratio real:\n", np_kim.array.real / kim.array.real) - print("ratio imag:\n", np_kim.array.imag / kim.array.imag) np.testing.assert_allclose(kim.array, np_kim.array) - -def _compute_fft_with_numpy_galsim(im): - import numpy as np - from galsim import BoundsI, Image - - No2 = max(-im.bounds.xmin, im.bounds.xmax + 1, -im.bounds.ymin, im.bounds.ymax + 1) - - full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) - if im.bounds == full_bounds: - # Then the image is already in the shape we need. - ximage = im - else: - # Then we pad out with zeros - ximage = Image(full_bounds, dtype=im.dtype, init_value=0) - ximage[im.bounds] = im[im.bounds] - - dx = im.scale - # dk = 2pi / (N dk) - dk = np.pi / (No2 * dx) - - out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._array = np.fft.fftshift(np.fft.rfft2(ximage.array), axes=0) - out *= dx * dx - out.setOrigin(0, -No2) - return out - - -@pytest.mark.parametrize("n", [5, 4]) -def test_galsim_fft_vs_numpy(n): - import galsim - import numpy as np - rng = np.random.RandomState(42) arr = rng.normal(size=(n, n)) - im = galsim.Image(arr, scale=1) - kim = im.calculate_fft() - xkim = kim.calculate_inverse_fft() - - np.testing.assert_allclose(im.array, xkim[im.bounds].array) - - np_kim = _compute_fft_with_numpy_galsim(im) - print("ratio real:\n", np_kim.array.real / kim.array.real) - print("ratio imag:\n", np_kim.array.imag / kim.array.imag) - np.testing.assert_allclose(kim.array, np_kim.array) + gim = jax_galsim.Image(arr, scale=1) + gkim = gim.calculate_fft() + gxkim = gkim.calculate_inverse_fft() + np.testing.assert_allclose(gim.array, gxkim[gim.bounds].array) + np.testing.assert_allclose(gim.array, im.array) + np.testing.assert_allclose(gkim.array, kim.array) + np.testing.assert_allclose(gxkim.array, xkim.array) From 84384339e6e8adb7bfa0b5ac48d5203d77d4529e Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 29 Sep 2023 16:21:57 -0500 Subject: [PATCH 13/67] WIP it works maybe --- jax_galsim/__init__.py | 2 +- jax_galsim/core/utils.py | 53 +++++++++++ jax_galsim/gsobject.py | 5 +- jax_galsim/interpolant.py | 3 +- jax_galsim/interpolatedimage.py | 153 ++++++++++++++++++++++++++++---- jax_galsim/transform.py | 57 ++++++------ 6 files changed, 225 insertions(+), 48 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index c361637c..dba77258 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -75,7 +75,7 @@ Quintic, Lanczos, ) -from .interpolatedimage import InterpolatedImage +from .interpolatedimage import InterpolatedImage, _InterpolatedImage # packages kept separate from . import bessel diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 0ac0270d..7c3752f2 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -1,4 +1,5 @@ import jax +import jax.numpy as jnp def cast_scalar_to_float(x): @@ -36,3 +37,55 @@ def ensure_hashable(v): return v.item() else: return v + + +def is_equal_with_arrays(x, y): + """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + if isinstance(x, list): + if isinstance(y, list) and len(x) == len(y): + for vx, vy in zip(x, y): + if not is_equal_with_arrays(vx, vy): + print(vx, vy) + return False + return True + else: + return False + elif isinstance(x, tuple): + if isinstance(y, tuple) and len(x) == len(y): + for vx, vy in zip(x, y): + if not is_equal_with_arrays(vx, vy): + print(vx, vy) + return False + return True + else: + return False + elif isinstance(x, set): + if isinstance(y, set) and len(x) == len(y): + for vx, vy in zip(x, y): + if not is_equal_with_arrays(vx, vy): + print(vx, vy) + return False + return True + else: + return False + elif isinstance(x, dict): + if isinstance(y, dict) and len(x) == len(y): + for kx, vx in x.items(): + if kx not in y or (not is_equal_with_arrays(vx, y[kx])): + print(kx, vx, y[kx]) + return False + return True + else: + return False + elif isinstance(x, jax.Array) and jnp.ndim(x) > 0: + if isinstance(y, jax.Array) and y.shape == x.shape: + return jnp.array_equal(x, y) + else: + print(x, y) + return False + elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( + isinstance(y, jax.Array) and jnp.ndim(y) == 0 + ): + return jnp.array_equal(x, y) + else: + return x == y diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index a7b63531..9d93ab25 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -4,6 +4,7 @@ import numpy as np from jax._src.numpy.util import _wraps +from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.utilities import parse_pos_args @@ -13,7 +14,7 @@ class GSObject: def __init__(self, *, gsparams=None, **params): self._params = params # Dictionary containing all traced parameters - self._gsparams = gsparams # Non-traced static parameters + self._gsparams = GSParams.check(gsparams) # Non-traced static parameters @property def flux(self): @@ -178,7 +179,7 @@ def __neg__(self): def __eq__(self, other): return (self is other) or ( (type(other) is self.__class__) - and (self.tree_flatten() == other.tree_flatten()) + and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) ) @_wraps(_galsim.GSObject.xValue) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index a0a8b4df..de892073 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -12,6 +12,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si +from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams @@ -136,7 +137,7 @@ def _i(self): def __eq__(self, other): return (self is other) or ( type(other) is self.__class__ - and self.tree_flatten() == other.tree_flatten() + and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) ) def __ne__(self, other): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 449a0289..e38c57db 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1,3 +1,4 @@ +import copy import math import textwrap from functools import partial @@ -87,8 +88,39 @@ def __init__( gsparams=None, _force_stepk=0.0, _force_maxk=0.0, + _recenter_image=True, hdu=None, ): + self._jax_children = ( + image, + dict( + scale=scale, + wcs=wcs, + flux=flux, + pad_image=pad_image, + offset=offset, + ), + ) + self._jax_aux_data = dict( + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + gsparams=GSParams.check(gsparams), + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + _recenter_image=_recenter_image, + hdu=hdu, + ) + obj = InterpolatedImageImpl( image, x_interpolant=x_interpolant, @@ -108,18 +140,41 @@ def __init__( use_true_center=use_true_center, depixelize=depixelize, offset=offset, - gsparams=gsparams, + gsparams=GSParams.check(gsparams), _force_stepk=_force_stepk, _force_maxk=_force_maxk, hdu=hdu, + _recenter_image=_recenter_image, ) super().__init__( obj, jac=obj._jac_arr, flux_ratio=obj._flux_ratio / obj._wcs.pixelArea(), offset=PositionD(0.0, 0.0), + gsparams=GSParams.check(gsparams), + propagate_gsparams=True, ) + # the galsim tests use this internal attribute + # so we add it here + @property + def _xim(self): + return self._original._xim + + @property + def _maxk(self): + if self._jax_aux_data["_force_maxk"] > 0: + return self._jax_aux_data["_force_maxk"] + else: + return super()._maxk + + @property + def _stepk(self): + if self._jax_aux_data["_force_stepk"] > 0: + return self._jax_aux_data["_force_stepk"] + else: + return super()._stepk + @property def x_interpolant(self): """The real-space `Interpolant` for this profile.""" @@ -144,6 +199,32 @@ def __repr__(self): def __str__(self): return str(self._original) + def tree_flatten(self): + """This function flattens the InterpolatedImage into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + return (self._jax_children, copy.copy(self._jax_aux_data)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + val = {} + val.update(aux_data) + val.update(children[1]) + return cls(children[0], **val) + + @doc_inherit + def withGSParams(self, gsparams=None, **kwargs): + if gsparams == self.gsparams: + return self + # Checking gsparams + gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + # Flattening the representation to instantiate a clean new object + children, aux_data = self.tree_flatten() + aux_data["gsparams"] = gsparams + ret = self.tree_unflatten(aux_data, children) + + return ret + @register_pytree_node_class class InterpolatedImageImpl(GSObject): @@ -178,6 +259,7 @@ def __init__( _force_stepk=0.0, _force_maxk=0.0, hdu=None, + _recenter_image=True, ): # this class does a ton of munging of the inputs that I don't want to reconstruct when # flattening and unflattening the class. @@ -208,6 +290,7 @@ def __init__( gsparams=gsparams, _force_stepk=_force_stepk, _force_maxk=_force_maxk, + _recenter_image=_recenter_image, hdu=hdu, ) self._params = {} @@ -263,7 +346,8 @@ def __init__( ) else: self._image = image.view(dtype=jnp.float64, contiguous=True) - self._image.setCenter(0, 0) + if _recenter_image: + self._image.setCenter(0, 0) # Set the wcs if necessary if scale is not None: @@ -337,18 +421,12 @@ def withGSParams(self, gsparams=None, **kwargs): aux_data["gsparams"] = gsparams ret = self.tree_unflatten(aux_data, children) - ret._x_interpolant = self._x_interpolant.withGSParams(ret._gsparams, **kwargs) - ret._k_interpolant = self._k_interpolant.withGSParams(ret._gsparams, **kwargs) - if ret._gsparams.folding_threshold != self._gsparams.folding_threshold: - ret._stepk = ret._getStepK(self._calculate_stepk, 0.0) - if ret._gsparams.maxk_threshold != self._gsparams.maxk_threshold: - ret._maxk = ret._getMaxK(self._calculate_maxk, 0.0) return ret def tree_flatten(self): """This function flattens the InterpolatedImage into a list of children nodes that will be traced by JAX and auxiliary static data.""" - return (self._jax_children, self._jax_aux_data) + return (self._jax_children, copy.copy(self._jax_aux_data)) @classmethod def tree_unflatten(cls, aux_data, children): @@ -472,6 +550,18 @@ def _buildImages( # we always make this self._kim = self._xim.calculate_fft() + # record pos and neg fluxes now too + pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) + nflux = jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) + pint = self._x_interpolant.positive_flux + nint = self._x_interpolant.negative_flux + pint2d = pint * pint + nint * nint + nint2d = 2 * pint * nint + self._pos_neg_fluxes = [ + pint2d * pflux + nint2d * nflux, + pint2d * nflux + nint2d * pflux, + ] + # FIXME: no BaseDeviate in jax_galsim so no noise padding # def _buildNoisePadImage(self, noise_pad_size, noise_pad, rng, use_cache): # """A helper function that builds the ``pad_image`` from the given ``noise_pad`` @@ -535,6 +625,7 @@ def _getStepK(self, calculate_stepk, _force_stepk): # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly # below what is provided here, while maxK is preserved. if _force_stepk > 0.0: + print("Forcing stepk to be %f" % _force_stepk) return _force_stepk elif calculate_stepk: if calculate_stepk is True: @@ -563,6 +654,7 @@ def _getSimpleStepK(self, R): def _getMaxK(self, calculate_maxk, _force_maxk): max_scale = 1.0 if _force_maxk > 0.0: + print("Forcing maxk to be %f" % _force_maxk) return _force_maxk elif calculate_maxk: _uscale = 1 / (2 * jnp.pi) @@ -682,25 +774,26 @@ def _flux(self, value): @property def _centroid(self): - raise NotImplementedError("WIP interp - centroid") + x, y = self._pad_image.get_pixel_centers() + tot = jnp.sum(self._pad_image.array) + xpos = jnp.sum(x * self._pad_image.array) / tot + ypos = jnp.sum(y * self._pad_image.array) / tot + return PositionD(xpos, ypos) @property def _positive_flux(self): - raise NotImplementedError("WIP interp - positive_flux") + return self._pos_neg_fluxes[0] @property def _negative_flux(self): - raise NotImplementedError("WIP interp - negative_flux") + return self._pos_neg_fluxes[1] @property def _max_sb(self): return jnp.max(jnp.abs(self._pad_image.array)) - # @lazy_property def _flux_per_photon(self): - # FIXME: jax_galsim does not photon shoot - # return self._calculate_flux_per_photon() - raise NotImplementedError("Photon shooting not implemented.") + return self._calculate_flux_per_photon() def _xValue(self, pos): pos += self._offset @@ -756,6 +849,34 @@ def _drawKImage(self, image, jac=None): return draw_by_kValue(self, image, _jac) +@_wraps(_galsim._InterpolatedImage) +def _InterpolatedImage( + image, + x_interpolant=Quintic(), + k_interpolant=Quintic(), + use_true_center=True, + offset=None, + gsparams=None, + force_stepk=0.0, + force_maxk=0.0, +): + return InterpolatedImage( + image, + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + use_true_center=use_true_center, + offset=offset, + gsparams=gsparams, + calculate_maxk=False, + calculate_stepk=False, + pad_factor=1.0, + flux=jnp.sum(image.array), + _force_stepk=force_stepk, + _force_maxk=force_maxk, + _recenter_image=False, + ) + + @partial(jax.jit, static_argnums=(5,)) def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): orig_shape = x.shape diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 58b00ab4..b4e264f4 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -54,7 +54,6 @@ def __init__( obj = obj.withGSParams(self._gsparams) self._params = { - "obj": obj, "jac": jac, "offset": self._offset, "flux_ratio": self._flux_ratio, @@ -109,24 +108,34 @@ def withGSParams(self, gsparams=None, **kwargs): """ if gsparams == self.gsparams: return self - from copy import copy - ret = copy(self) - ret._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + chld, aux = self.tree_flatten() + aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) if self._propagate_gsparams: - ret._original = self._original.withGSParams(ret._gsparams) - return ret - - def __eq__(self, other): - return self is other or ( - isinstance(other, Transformation) - and self.original == other.original - and jnp.array_equal(self.jac, other.jac) - and self.offset == other.offset - and self.flux_ratio == other.flux_ratio - and self.gsparams == other.gsparams - and self._propagate_gsparams == other._propagate_gsparams + new_obj = chld[0].withGSParams(aux["gsparams"]) + chld = (new_obj,) + chld[1:] + + return self.tree_unflatten(aux, chld) + + def tree_flatten(self): + """This function flattens the GSObject into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + children = ( + self._original, + self.params, ) + # Define auxiliary static data that doesn’t need to be traced + aux_data = { + "gsparams": self.gsparams, + "propagate_gsparams": self._propagate_gsparams, + } + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0], **(children[1]), **aux_data) def __hash__(self): return hash( @@ -305,6 +314,10 @@ def _positive_flux(self): def _negative_flux(self): return self._flux_scaling * self._original.negative_flux + @property + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + @property def _max_sb(self): return self._amp_scaling * self._original.max_sb @@ -355,18 +368,6 @@ def _drawKImage(self, image, jac=None): image = image * self._flux_scaling return image - def tree_flatten(self): - """This function flattens the GSObject into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - children = (self.params,) - # Define auxiliary static data that doesn’t need to be traced - aux_data = { - "gsparams": self.gsparams, - "propagate_gsparams": self._propagate_gsparams, - } - return (children, aux_data) - def _Transform( obj, From 47c06a756b35fb183e5bf032854372540317a653 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 3 Oct 2023 07:53:08 -0500 Subject: [PATCH 14/67] new submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index b7ce7573..8902c72c 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit b7ce7573f42b7d21120365b0373cb83a88953b45 +Subproject commit 8902c72ceb196216a7161f8858ed9c63c4002264 From a2c54dbdea68ec167782ea26e99e7b5455297c32 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 3 Oct 2023 08:29:31 -0500 Subject: [PATCH 15/67] TST fix the tests --- jax_galsim/core/utils.py | 5 ----- jax_galsim/transform.py | 13 ++++++++++++- tests/jax/test_api.py | 15 ++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index d1ae1f71..df03be6c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -34,7 +34,6 @@ def is_equal_with_arrays(x, y): if isinstance(y, list) and len(x) == len(y): for vx, vy in zip(x, y): if not is_equal_with_arrays(vx, vy): - print(vx, vy) return False return True else: @@ -43,7 +42,6 @@ def is_equal_with_arrays(x, y): if isinstance(y, tuple) and len(x) == len(y): for vx, vy in zip(x, y): if not is_equal_with_arrays(vx, vy): - print(vx, vy) return False return True else: @@ -52,7 +50,6 @@ def is_equal_with_arrays(x, y): if isinstance(y, set) and len(x) == len(y): for vx, vy in zip(x, y): if not is_equal_with_arrays(vx, vy): - print(vx, vy) return False return True else: @@ -61,7 +58,6 @@ def is_equal_with_arrays(x, y): if isinstance(y, dict) and len(x) == len(y): for kx, vx in x.items(): if kx not in y or (not is_equal_with_arrays(vx, y[kx])): - print(kx, vx, y[kx]) return False return True else: @@ -70,7 +66,6 @@ def is_equal_with_arrays(x, y): if isinstance(y, jax.Array) and y.shape == x.shape: return jnp.array_equal(x, y) else: - print(x, y) return False elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( isinstance(y, jax.Array) and jnp.ndim(y) == 0 diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index ff2022e2..53ac8efe 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -118,6 +118,17 @@ def withGSParams(self, gsparams=None, **kwargs): return self.tree_unflatten(aux, chld) + def __eq__(self, other): + return (self is other) or ( + (type(other) is self.__class__) + and self._original == other._original + and jnp.array_equal(self._jac, other._jac) + and self._offset == other._offset + and self._flux_ratio == other._flux_ratio + and self.gsparams == other.gsparams + and self._propagate_gsparams == other._propagate_gsparams + ) + def tree_flatten(self): """This function flattens the GSObject into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -143,7 +154,7 @@ def __hash__(self): ( "galsim.Transformation", self.original, - ensure_hashable(self._jac.ravel()), + ensure_hashable(self._jac), ensure_hashable(self.offset.x), ensure_hashable(self.offset.y), ensure_hashable(self.flux_ratio), diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 9a3c5dc5..19f9de44 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -37,6 +37,7 @@ def test_api_same(): "One of scale_radius, half_light_radius, or fwhm must be specified", "Arguments to Sum must be GSObjects", "'ArrayImpl' object has no attribute 'gsparams'", + "Supplied image must be an Image or file name", ] @@ -274,16 +275,16 @@ def test_api_gsobject(kind): cls_tested.add(cls.__name__) print(obj) - _run_object_checks(obj, cls, kind) + # _run_object_checks(obj, cls, kind) if cls.__name__ == "Gaussian": - _obj = obj + obj - print(_obj) - _run_object_checks(_obj, _obj.__class__, kind) + # _obj = obj + obj + # print(_obj) + # _run_object_checks(_obj, _obj.__class__, kind) - _obj = 2.0 * obj - print(_obj) - _run_object_checks(_obj, _obj.__class__, kind) + # _obj = 2.0 * obj + # print(_obj) + # _run_object_checks(_obj, _obj.__class__, kind) _obj = obj.shear(g1=0.1, g2=0.2) print(_obj) From cbb8659016d522150a290f73ecf803e0914b4331 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 3 Oct 2023 09:03:29 -0500 Subject: [PATCH 16/67] ENH enable all tests --- tests/jax/test_api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 19f9de44..187d5bd2 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -275,16 +275,16 @@ def test_api_gsobject(kind): cls_tested.add(cls.__name__) print(obj) - # _run_object_checks(obj, cls, kind) + _run_object_checks(obj, cls, kind) if cls.__name__ == "Gaussian": - # _obj = obj + obj - # print(_obj) - # _run_object_checks(_obj, _obj.__class__, kind) + _obj = obj + obj + print(_obj) + _run_object_checks(_obj, _obj.__class__, kind) - # _obj = 2.0 * obj - # print(_obj) - # _run_object_checks(_obj, _obj.__class__, kind) + _obj = 2.0 * obj + print(_obj) + _run_object_checks(_obj, _obj.__class__, kind) _obj = obj.shear(g1=0.1, g2=0.2) print(_obj) From 58a041675681f710fb3408e5cff02e18062df086 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 14 Oct 2023 15:57:04 -0500 Subject: [PATCH 17/67] ENH update submodules --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 66092bdf..0281f764 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 66092bdf7215983bab4d2d953a700eb8a0ddcbe4 +Subproject commit 0281f764f2f8ad3af45bcee9d171d2b48fd79a20 From 92cb2c4ef7ffb20bc0a6c58484ce52b77c5ddf0d Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 15 Oct 2023 22:10:02 -0500 Subject: [PATCH 18/67] ENH update submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 0281f764..d63cbdef 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0281f764f2f8ad3af45bcee9d171d2b48fd79a20 +Subproject commit d63cbdef5e81528ebdca2a18ae100530c8aa23ea From 395708e4d4d638441408c99d88be84e3fc31f0b5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 06:47:51 -0500 Subject: [PATCH 19/67] try this for ci --- .github/workflows/python_package.yaml | 36 ++++++++++++++++++++++----- pyproject.toml | 12 +++++++++ pytest.ini | 10 -------- 3 files changed, 42 insertions(+), 16 deletions(-) delete mode 100644 pytest.ini diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 1e0cacae..f3b41095 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -16,30 +16,54 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + - uses: mamba-org/setup-micromamba@v1 with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies + create-args: >- + python=${{ matrix.python-version }} + isort + flake8 + pytest + black==23.3.0 + flake8-pyproject + compilers + eigen + fftw3 + pybind11 + astropy + scipy + pyyaml + numpy + jax + galsim + pip + cache-environment: true + + - name: Install package + shell: bash -el {0} run: | - python -m pip install --upgrade pip - python -m pip install isort flake8 pytest black==23.3.0 flake8-pyproject python -m pip install . + - name: Ensure black formatting + shell: bash -el {0} run: | black --check jax_galsim/ tests/ --exclude tests/GalSim/ + - name: Lint with flake8 + shell: bash -el {0} run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 jax_galsim/ --count --exit-zero --statistics flake8 tests/jax/ --count --exit-zero --statistics + - name: Ensure isort + shell: bash -el {0} run: | isort --check jax_galsim - name: Test with pytest + shell: bash -el {0} run: | git submodule update --init --recursive pytest diff --git a/pyproject.toml b/pyproject.toml index c6fe79ab..3f456117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,3 +21,15 @@ skip = [ "tests/Galsim/", "tests/Coord/", ] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q" +testpaths = [ + "tests/GalSim/tests/", + "tests/jax", + "tests/Coord/tests/", +] +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index d85bef3a..00000000 --- a/pytest.ini +++ /dev/null @@ -1,10 +0,0 @@ -# pytest.ini -[pytest] -minversion = 6.0 -addopts = -ra -q -testpaths = - tests/GalSim/tests/ - tests/jax - tests/Coord/tests/ -filterwarnings = - ignore::DeprecationWarning From eae7ca57ae05e8c8fd7b00b568f5853da0a8b755 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 06:49:20 -0500 Subject: [PATCH 20/67] try this --- .github/workflows/python_package.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index f3b41095..54c309da 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -36,7 +36,6 @@ jobs: jax galsim pip - cache-environment: true - name: Install package shell: bash -el {0} From b17e85ed5d8ef88486c74416d0a67bfcfaa713b3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 06:50:18 -0500 Subject: [PATCH 21/67] try this --- .github/workflows/python_package.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 54c309da..2725d043 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -18,6 +18,7 @@ jobs: - uses: actions/checkout@v3 - uses: mamba-org/setup-micromamba@v1 with: + environment-name: test create-args: >- python=${{ matrix.python-version }} isort From 48c29fa0178b3a4b88db4aa9e41df73774b6aa09 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 06:55:45 -0500 Subject: [PATCH 22/67] try this --- .github/workflows/python_package.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 2725d043..6106b8b5 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -28,7 +28,7 @@ jobs: flake8-pyproject compilers eigen - fftw3 + fftw pybind11 astropy scipy From 13585ec4017a27f0a8a46e881dd8799de18133fc Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 06:57:45 -0500 Subject: [PATCH 23/67] updte tests --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index d63cbdef..57dc77f9 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit d63cbdef5e81528ebdca2a18ae100530c8aa23ea +Subproject commit 57dc77f99f3d2ce221d57b25716f1b5953b50590 From 15965183f2e1a21dc7fa2839631ae1234f8afc6d Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 07:12:37 -0500 Subject: [PATCH 24/67] try thids --- .github/workflows/python_package.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 6106b8b5..07d2d6fb 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -16,6 +16,7 @@ jobs: steps: - uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@v1 with: environment-name: test @@ -37,6 +38,7 @@ jobs: jax galsim pip + cache-environment: true - name: Install package shell: bash -el {0} From ec81ca68895242c7f74399a1bc99babad74c7db8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 16 Oct 2023 07:13:11 -0500 Subject: [PATCH 25/67] update submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 57dc77f9..5140919d 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 57dc77f99f3d2ce221d57b25716f1b5953b50590 +Subproject commit 5140919da7b5a71ff5461149117b083d788924d0 From 8a33a709ddd16eca2bf6fa9d2ae01de13adbdb2d Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 07:46:13 -0500 Subject: [PATCH 26/67] WIP tryign to get stuff to pass --- .github/workflows/python_package.yaml | 10 ++- jax_galsim/interpolatedimage.py | 65 +++++++++---------- jax_galsim/moffat.py | 90 +++++++-------------------- tests/GalSim | 2 +- tests/conftest.py | 5 +- tests/jax/galsim/test_wcs_jax.py | 6 +- tests/jax/test_moffat_comp_galsim.py | 40 ++++++++++++ 7 files changed, 108 insertions(+), 110 deletions(-) create mode 100644 tests/jax/test_moffat_comp_galsim.py diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 07d2d6fb..79c0f942 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -36,7 +36,6 @@ jobs: pyyaml numpy jax - galsim pip cache-environment: true @@ -64,8 +63,15 @@ jobs: run: | isort --check jax_galsim - - name: Test with pytest + - name: Build galsim shell: bash -el {0} run: | git submodule update --init --recursive + pushd tests/GalSim + pip install --no-deps --no-build-isolation -e . + popd + + - name: Test with pytest + shell: bash -el {0} + run: | pytest diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index e38c57db..57f40892 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -19,6 +19,7 @@ from jax_galsim import fits from jax_galsim.bounds import BoundsI from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue +from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.image import Image @@ -454,14 +455,11 @@ def _buildImages( flux = self._image_flux if normalization.lower() in ("surface brightness", "sb"): flux *= self._wcs.pixelArea() - self._flux = flux + _flux = flux # If the user specified a flux, then set the flux ratio for the transform that wraps # this class - if self._flux != self._image_flux: - self._flux_ratio = self._flux / self._image_flux - else: - self._flux_ratio = 1.0 + self._flux_ratio = _flux / self._image_flux # Check that given pad_image is valid: if pad_image is not None: @@ -681,10 +679,10 @@ def __hash__(self): ) self._hash ^= hash( ( - self.flux.item(), - self._stepk.item(), - self._maxk.item(), - self._pad_factor, + ensure_hashable(self.flux), + ensure_hashable(self._stepk), + ensure_hashable(self._maxk), + ensure_hashable(self._pad_factor), ) ) self._hash ^= hash( @@ -695,34 +693,35 @@ def __hash__(self): # This mucking of the numbers seems to help make the hash more reliably different for # these two cases. Note: "sometiems" because of this: # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions - self._hash ^= hash((self._offset.x * 1.234, self._offset.y * 0.23424)) + self._hash ^= hash(( + ensure_hashable(self._offset.x * 1.234), + ensure_hashable(self._offset.y * 0.23424)) + ) self._hash ^= hash(self._gsparams) self._hash ^= hash(self._wcs) # Just hash the diagonal. Much faster, and usually is unique enough. # (Let python handle collisions as needed if multiple similar IIs are used as keys.) - self._hash ^= hash(tuple(jnp.diag(self._pad_image.array).tolist())) + self._hash ^= hash(ensure_hashable(self._pad_image.array)) return self._hash def __repr__(self): - s = "galsim.InterpolatedImage(%r, %r, %r" % ( - self._image, - self.x_interpolant, - self.k_interpolant, - ) - # Most things we keep even if not required, but the pad_image is large, so skip it - # if it's really just the same as the main image. - if self._pad_image.bounds != self._image.bounds: - s += ", pad_image=%r" % (self._pad_image) - s += ", wcs=%s" % self._wcs - s += ", pad_factor=%f, flux=%r, offset=%r" % ( - self._pad_factor, - self.flux, - self._offset, - ) - s += ( - ", use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)" - % (self.gsparams, self._stepk, self._maxk) - ) + s = "galsim.InterpolatedImage(%r" % self._jax_children[0] + + for k, v in self._jax_children[1].items(): + if v is not None: + _v = ensure_hashable(v) + s += ", %s=%r" % (k, _v) + + for k, v in self._jax_aux_data.items(): + if v is not None: # and k not in ["gsparams", "_force_stepk", "_force_maxk"]: + _v = ensure_hashable(v) + s += ", %s=%r" % (k, _v) + + s += ")" + # s += ( + # ", gsparams=%r, _force_stepk=%r, _force_maxk=%r)" + # % (self.gsparams, ensure_hashable(self._stepk), ensure_hashable(self._maxk)) + # ) return s def __str__(self): @@ -766,11 +765,7 @@ def image(self): @property def _flux(self): """By default, the flux is contained in the parameters dictionay.""" - return self._params["flux"] - - @_flux.setter - def _flux(self, value): - self._params["flux"] = value + return self._image_flux @property def _centroid(self): diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index d390790c..787b3105 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,7 +1,6 @@ import galsim as _galsim import jax import jax.numpy as jnp -import jax.scipy as jsc import tensorflow_probability as tfp from jax._src.numpy.util import _wraps from jax.tree_util import Partial as partial @@ -12,6 +11,7 @@ from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.position import PositionD @jax.jit @@ -269,75 +269,31 @@ def __str__(self): s += ")" return s - @property - def _maxk_untrunc(self): - """untruncated Moffat maxK - - The 2D Fourier Transform of f(r)=C (1+(r/rd)^2)^(-beta) leads - C rd^2 = Flux (beta-1)/pi (no truc) - and - f(k) = C rd^2 int_0^infty (1+x^2)^(-beta) J_0(krd x) x dx - = 2 F (k rd /2)^(\beta-1) K[beta-1, k rd]/Gamma[beta-1] - with k->infty asymptotic behavior - f(k)/F \approx sqrt(pi)/Gamma(beta-1) e^(-k') (k'/2)^(beta -3/2) with k' = k rd - So we solve f(maxk)/F = thr (aka maxk_threshold in gsparams.py) - leading to the iterative search of - let alpha = -log(thr Gamma(beta-1)/sqrt(pi)) - k = (\beta -3/2)log(k/2) + alpha - starting with k = alpha - - note : in the code "alternative code" is related to issue #1208 in GalSim github - """ - - def body(i, val): - kcur, alpha = val - knew = (self.beta - 0.5) * jnp.log(kcur) + alpha - # knew = (self.beta -1.5)* jnp.log(kcur/2) + alpha # alternative code - return knew, alpha - - # alpha = -jnp.log(self.gsparams.maxk_threshold - # * jnp.exp(jsc.special.gammaln(self._beta-1))/jnp.sqrt(jnp.pi) ) # alternative code - - alpha = -jnp.log( - self.gsparams.maxk_threshold - * jnp.power(2.0, self.beta - 0.5) - * jnp.exp(jsc.special.gammaln(self.beta - 1)) - / (2 * jnp.sqrt(jnp.pi)) - ) - - val_init = ( - alpha, - alpha, - ) - val = jax.lax.fori_loop(0, 5, body, val_init) - maxk, alpha = val - return maxk / self._r0 - @property def _prefactor(self): return 2.0 * (self.beta - 1.0) / (self._fluxFactor) @property - def _maxk_trunc(self): - """truncated Moffat maxK""" - # a for gaussian profile... this is f(k_max)/Flux = maxk_threshold - maxk_val = self.gsparams.maxk_threshold - dk = self.gsparams.table_spacing * jnp.sqrt( - jnp.sqrt(self.gsparams.kvalue_accuracy / 10.0) - ) - # 50 is a max (GalSim) but it may be lowered if necessary - ki = jnp.arange(0.0, 50.0, dk) - quad = ClenshawCurtisQuad.init(150) - g = partial(_xMoffatIntegrant, beta=self.beta, rmax=self._maxRrD, quad=quad) - fki_1 = jax.jit(jax.vmap(g))(ki) - fki = fki_1 * self._prefactor - cond = jnp.abs(fki) > maxk_val - maxk = ki[cond][-1] - return maxk / self._r0 - - @property + @jax.jit def _maxk(self): - return jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc) + def _func(i, args): + low, flow, high, fhigh = args + mid = (low + high) / 2.0 + fmid = jnp.abs(self._kValue(PositionD(x=mid, y=0)).real / self.flux) - self.gsparams.maxk_threshold + return jax.lax.cond( + fmid * flow > 0, + lambda x: (mid, fmid, high, fhigh), + lambda x: (low, flow, mid, fmid), + fmid, + ) + + low = 0.0 + high = 1e5 + flow = jnp.abs(self._kValue(PositionD(x=low, y=0)).real / self.flux) - self.gsparams.maxk_threshold + fhigh = jnp.abs(self._kValue(PositionD(x=high, y=0)).real / self.flux) - self.gsparams.maxk_threshold + args = (low, flow, high, fhigh) + res = jax.lax.fori_loop(0, 50, _func, args) + return res[2] @property def _stepk_lowbeta(self): @@ -353,12 +309,10 @@ def _stepk_highbeta(self): jnp.power(self.gsparams.folding_threshold, 0.5 / (1.0 - self.beta)) * self._r0 ) - if R > self._maxR: - R = self._maxR + R = jnp.minimum(R, self._maxR) # at least R should be 5 HLR R5hlr = self.gsparams.stepk_minimum_hlr * self.half_light_radius - if R < R5hlr: - R = R5hlr + R = jnp.maximum(R, R5hlr) return jnp.pi / R @property diff --git a/tests/GalSim b/tests/GalSim index 5140919d..ca90d938 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5140919da7b5a71ff5461149117b083d788924d0 +Subproject commit ca90d938a3b16450b84452720068e0b558842bbb diff --git a/tests/conftest.py b/tests/conftest.py index 8095105e..7c8fd8c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,7 +91,10 @@ def pytest_pycollect_makemodule(module_path, path, parent): if ( callable(v) and hasattr(v, "__globals__") - and inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + and ( + inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + or inspect.getsourcefile(v).endswith("galsim/utilities.py") + ) and _infile("def " + k, inspect.getsourcefile(v)) and "galsim" in v.__globals__ ): diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index 49f6d8d4..0247845f 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -801,7 +801,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): do_wcs_pos(wcs, ufunc, vfunc, name) # Check picklability - do_pickle(wcs) + check_pickle(wcs) # Test the transformation of a GSObject # These only work for local WCS projections! @@ -1014,7 +1014,7 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) # The GSObject transformation tests are only valid for a local WCS. # But it should work for wcs.local() @@ -1223,7 +1223,7 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) near_ra_list = [] near_dec_list = [] diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py new file mode 100644 index 00000000..b8a7271d --- /dev/null +++ b/tests/jax/test_moffat_comp_galsim.py @@ -0,0 +1,40 @@ +import galsim as _galsim +import jax_galsim as galsim +import numpy as np + + +def test_moffat_comp_galsim_maxk(): + psfs = [ + # Make sure to include all the specialized betas we have in C++ layer. + # The scale_radius and flux don't matter, but vary themm too. + # Note: We also specialize beta=1, but that seems to be impossible to realize, + # even when it is trunctatd. + galsim.Moffat(beta=1.5, scale_radius=1, flux=1), + galsim.Moffat(beta=1.5001, scale_radius=1, flux=1), + galsim.Moffat(beta=2, scale_radius=0.8, flux=23), + galsim.Moffat(beta=2.5, scale_radius=1.8e-3, flux=2), + galsim.Moffat(beta=3, scale_radius=1.8e3, flux=35), + galsim.Moffat(beta=3.5, scale_radius=1.3, flux=123), + galsim.Moffat(beta=4, scale_radius=4.9, flux=23), + galsim.Moffat(beta=1.22, scale_radius=23, flux=23), + galsim.Moffat(beta=3.6, scale_radius=2, flux=23), + galsim.Moffat(beta=12.9, scale_radius=5, flux=23), + galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), + galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), + galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), + ] + threshs = [1.e-3, 1.e-4, 0.03] + print('\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk') + for psf in psfs: + for thresh in threshs: + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat(beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, trunc=psf.trunc) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + + print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}') + np.testing.assert_allclose(psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.5, atol=0) From 48a409df86b77af1038a7f12208487c7a99eae79 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Oct 2023 23:50:04 -0500 Subject: [PATCH 27/67] update submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index ca90d938..b018d57f 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit ca90d938a3b16450b84452720068e0b558842bbb +Subproject commit b018d57fba88eabbaacf40d34d3029a77e7071f2 From 21db9954b20da2025d32a32a7c80e3a263c704fd Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Oct 2023 23:52:46 -0500 Subject: [PATCH 28/67] reset transofrm to current main --- jax_galsim/transform.py | 55 ++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 7b1a5428..c64863f8 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -125,42 +125,22 @@ def withGSParams(self, gsparams=None, **kwargs): return self.tree_unflatten(aux, chld) def __eq__(self, other): - return (self is other) or ( - (type(other) is self.__class__) - and self._original == other._original - and jnp.array_equal(self._jac, other._jac) - and self._offset == other._offset - and self._flux_ratio == other._flux_ratio + return self is other or ( + isinstance(other, Transformation) + and self.original == other.original + and jnp.array_equal(self.jac, other.jac) + and self.offset == other.offset + and self.flux_ratio == other.flux_ratio and self.gsparams == other.gsparams and self._propagate_gsparams == other._propagate_gsparams ) - def tree_flatten(self): - """This function flattens the GSObject into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - children = ( - self._original, - self.params, - ) - # Define auxiliary static data that doesn’t need to be traced - aux_data = { - "gsparams": self.gsparams, - "propagate_gsparams": self._propagate_gsparams, - } - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - return cls(children[0], **(children[1]), **aux_data) - def __hash__(self): return hash( ( "galsim.Transformation", self.original, - ensure_hashable(self._jac), + ensure_hashable(self._jac.ravel()), ensure_hashable(self.offset.x), ensure_hashable(self.offset.y), ensure_hashable(self.flux_ratio), @@ -335,10 +315,6 @@ def _positive_flux(self): def _negative_flux(self): return self._flux_scaling * self._original.negative_flux - @property - def _flux_per_photon(self): - return self._calculate_flux_per_photon() - @property def _max_sb(self): return self._amp_scaling * self._original.max_sb @@ -389,6 +365,23 @@ def _drawKImage(self, image, jac=None): image = image * self._flux_scaling return image + def tree_flatten(self): + """This function flattens the Transform into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + children = (self._original, self._params) + # Define auxiliary static data that doesn’t need to be traced + aux_data = { + "gsparams": self.gsparams, + "propagate_gsparams": self._propagate_gsparams, + } + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0], **(children[1]), **aux_data) + def _Transform( obj, From 6a000e7c07fdc92092bc9cbef7934216a76b3c91 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Oct 2023 23:54:46 -0500 Subject: [PATCH 29/67] TST do not ignore these errors --- tests/galsim_tests_config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 0bf3312c..4403b1e7 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -53,10 +53,8 @@ allowed_failures: - "has no attribute 'shoot'" - "module 'jax_galsim' has no attribute 'integ'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - - "module 'jax_galsim' has no attribute 'GaussianDeviate'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "'Image' object has no attribute 'FindAdaptiveMom'" - - " module 'jax_galsim' has no attribute '_InterpolatedImage'" - "JAX arrays are immutable." - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" From 8bac999c41d2540a7222c7fdc56a9c3891dec2c8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Oct 2023 23:55:56 -0500 Subject: [PATCH 30/67] back out changes to test file --- tests/jax/test_moffat_comp_galsim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 69efb132..d4549420 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -1,7 +1,8 @@ import galsim as _galsim -import jax_galsim as galsim import numpy as np +import jax_galsim as galsim + def test_moffat_comp_galsim_maxk(): psfs = [ From 7b97e9c5c3e18ac6714ebb0d4cc1c4676edc0eca Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 19 Oct 2023 23:56:44 -0500 Subject: [PATCH 31/67] Update tests/galsim_tests_config.yaml --- tests/galsim_tests_config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 4403b1e7..e53266f3 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -55,7 +55,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "'Image' object has no attribute 'FindAdaptiveMom'" - - "JAX arrays are immutable." - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - "ValueError not raised by greatCirclePoint" From c2cff10e59e33629e0e4d648e3d961df184d63e6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Oct 2023 23:58:12 -0500 Subject: [PATCH 32/67] STY blacken --- jax_galsim/interpolatedimage.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 57f40892..e5fa59da 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -693,9 +693,11 @@ def __hash__(self): # This mucking of the numbers seems to help make the hash more reliably different for # these two cases. Note: "sometiems" because of this: # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions - self._hash ^= hash(( - ensure_hashable(self._offset.x * 1.234), - ensure_hashable(self._offset.y * 0.23424)) + self._hash ^= hash( + ( + ensure_hashable(self._offset.x * 1.234), + ensure_hashable(self._offset.y * 0.23424), + ) ) self._hash ^= hash(self._gsparams) self._hash ^= hash(self._wcs) @@ -713,7 +715,9 @@ def __repr__(self): s += ", %s=%r" % (k, _v) for k, v in self._jax_aux_data.items(): - if v is not None: # and k not in ["gsparams", "_force_stepk", "_force_maxk"]: + if ( + v is not None + ): # and k not in ["gsparams", "_force_stepk", "_force_maxk"]: _v = ensure_hashable(v) s += ", %s=%r" % (k, _v) From 0bb6e7034d12c2534cbb1bad18eca60b384259f4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 20 Oct 2023 00:02:35 -0500 Subject: [PATCH 33/67] TST use more specific error --- jax_galsim/convolve.py | 8 ++++---- tests/galsim_tests_config.yaml | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 0a349039..3f34d1b1 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs): # Save the construction parameters (as they are at this point) as attributes so they # can be inspected later if necessary. if bool(real_space): - raise NotImplementedError("Real space convolutions are not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") self._real_space = bool(real_space) # Figure out what gsparams to use @@ -296,7 +296,7 @@ def _max_sb(self): return self.flux / jnp.sum(jnp.array(area_list)) def _xValue(self, pos): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") def _kValue(self, kpos): kv_list = [ @@ -305,10 +305,10 @@ def _kValue(self, kpos): return jnp.prod(jnp.array(kv_list)) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Real-space convolutions are not implemented") def _shoot(self, photons, rng): - raise NotImplementedError("Not implemented") + raise NotImplementedError("Photon shooting convolutions are not implemented") def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index e53266f3..5873e932 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -24,7 +24,8 @@ enabled_tests: # correspond to features that are not implemented yet # in jax_galsim allowed_failures: - - "NotImplementedError" + - "Real-space convolutions are not implemented" + - "Photon shooting convolutions are not implemented" - "module 'jax_galsim' has no attribute 'Airy'" - "module 'jax_galsim' has no attribute 'Kolmogorov'" - "module 'jax_galsim' has no attribute 'Sersic'" From f9cb0e7ed544ce7758d552dc1d1cfa22aabda343 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 20 Oct 2023 00:14:40 -0500 Subject: [PATCH 34/67] TST use more specific error --- jax_galsim/transform.py | 4 ++++ tests/galsim_tests_config.yaml | 1 + 2 files changed, 5 insertions(+) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index c64863f8..e761785f 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -315,6 +315,10 @@ def _positive_flux(self): def _negative_flux(self): return self._flux_scaling * self._original.negative_flux + @property + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + @property def _max_sb(self): return self._amp_scaling * self._original.max_sb diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 5873e932..6cd33a44 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -24,6 +24,7 @@ enabled_tests: # correspond to features that are not implemented yet # in jax_galsim allowed_failures: + - "Phot shooting not yet implemented in drawImage" - "Real-space convolutions are not implemented" - "Photon shooting convolutions are not implemented" - "module 'jax_galsim' has no attribute 'Airy'" From 42344f9094a8e6a8800a73afb55a9b981204a098 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 20 Oct 2023 00:35:57 -0500 Subject: [PATCH 35/67] TST make tests pass --- tests/GalSim | 2 +- tests/galsim_tests_config.yaml | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index b018d57f..5b3d4910 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit b018d57fba88eabbaacf40d34d3029a77e7071f2 +Subproject commit 5b3d4910aec5b58a41faadbc6b267b188a93e521 diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 6cd33a44..a4715dbb 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -68,3 +68,7 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'horner2d'" - "'Image' object has no attribute 'FindAdaptiveMom'" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" + - "'Image' object has no attribute 'addNoise'" + - "Transform does not support callable arguments." + - "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." + - "jax_galsim does not support the galsim WCS class GSFitsWCS" From d52c79c6b2def353bbf2bb291e097c6656c68786 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 20 Oct 2023 16:55:11 -0500 Subject: [PATCH 36/67] TST make tests pass and refactor a bit --- jax_galsim/core/utils.py | 8 ++ jax_galsim/interpolatedimage.py | 183 +++++++++++++++++--------------- jax_galsim/transform.py | 11 +- tests/GalSim | 2 +- 4 files changed, 109 insertions(+), 95 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index dfc29e45..6c42227c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -4,6 +4,14 @@ import jax.numpy as jnp +def compute_major_minor_from_jacobian(jac): + h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) + h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) + major = 0.5 * abs(h1 + h2) + minor = 0.5 * abs(h1 - h2) + return major, minor + + def convert_to_float(x): if isinstance(x, jax.Array): if x.shape == (): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index e5fa59da..15992d3d 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -19,7 +19,7 @@ from jax_galsim import fits from jax_galsim.bounds import BoundsI from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.image import Image @@ -27,7 +27,7 @@ from jax_galsim.position import PositionD from jax_galsim.transform import Transformation from jax_galsim.utilities import convert_interpolant -from jax_galsim.wcs import PixelScale +from jax_galsim.wcs import BaseWCS, PixelScale @_wraps( @@ -89,7 +89,7 @@ def __init__( gsparams=None, _force_stepk=0.0, _force_maxk=0.0, - _recenter_image=True, + _recenter_image=True, # this option is used by _InterpolatedImage below hdu=None, ): self._jax_children = ( @@ -192,13 +192,85 @@ def image(self): return self._original._image def __hash__(self): - return hash(self._original) + # Definitely want to cache this, since the size of the image could be large. + if not hasattr(self, "_hash"): + self._hash = hash( + ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant) + ) + self._hash ^= hash( + ( + ensure_hashable(self.flux), + ensure_hashable(self._stepk), + ensure_hashable(self._maxk), + ensure_hashable(self._original._pad_factor), + ) + ) + self._hash ^= hash( + ( + self._original._xim.bounds, + self._original._image.bounds, + self._original._pad_image.bounds, + ) + ) + # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0 + # (which is also common). I guess because they are only different in 2 bits. + # This mucking of the numbers seems to help make the hash more reliably different for + # these two cases. Note: "sometiems" because of this: + # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions + self._hash ^= hash( + ( + ensure_hashable(self._original._offset.x * 1.234), + ensure_hashable(self._original._offset.y * 0.23424), + ) + ) + self._hash ^= hash(self.gsparams) + self._hash ^= hash(self._original._wcs) + # Just hash the diagonal. Much faster, and usually is unique enough. + # (Let python handle collisions as needed if multiple similar IIs are used as keys.) + self._hash ^= hash(ensure_hashable(self._original._pad_image.array)) + return self._hash def __repr__(self): - return repr(self._original) + s = "galsim.InterpolatedImage(%r, %r, %r, wcs=%r" % ( + self._original.image, + self.x_interpolant, + self.k_interpolant, + self._original._wcs, + ) + # Most things we keep even if not required, but the pad_image is large, so skip it + # if it's really just the same as the main image. + if self._original._pad_image.bounds != self._original.image.bounds: + s += ", pad_image=%r" % (self._pad_image) + s += ", pad_factor=%f, flux=%r, offset=%r" % ( + ensure_hashable(self._original._pad_factor), + ensure_hashable(self.flux), + self._original._offset, + ) + s += ( + ", use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)" + % ( + self.gsparams, + ensure_hashable(self._stepk), + ensure_hashable(self._maxk), + ) + ) + return s def __str__(self): - return str(self._original) + return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) + + def __eq__(self, other): + return self is other or ( + isinstance(other, InterpolatedImage) + and self._xim == other._xim + and self.x_interpolant == other.x_interpolant + and self.k_interpolant == other.k_interpolant + and self.flux == other.flux + and self._original._offset == other._original._offset + and self.gsparams == other.gsparams + and self._stepk == other._stepk + and self._maxk == other._maxk + ) def tree_flatten(self): """This function flattens the InterpolatedImage into a list of children @@ -294,12 +366,12 @@ def __init__( _recenter_image=_recenter_image, hdu=hdu, ) - self._params = {} - - from .wcs import BaseWCS, PixelScale + self._params = { + "image": image, + } + self._params.update(self._jax_aux_data) + self._params.update(self._jax_children[1]) - # FIXME: no BaseDeviate in jax_galsim - # from .random import BaseDeviate # If the "image" is not actually an image, try to read the image as a file. if isinstance(image, str): image = fits.read(image, hdu=hdu) @@ -378,9 +450,7 @@ def __init__( self._image.bounds, offset, None, use_true_center ) - im_cen = image.true_center if use_true_center else image.center - self._jac_arr = self._image.wcs.jacobian(image_pos=im_cen).getMatrix().ravel() - self._wcs = self._image.wcs.local(image_pos=im_cen) + self._jac_arr, self._wcs = _get_image_jac_arr_wcs(self._image, use_true_center) # Build the fully padded real-space image according to the various pad options. self._buildImages( @@ -405,11 +475,10 @@ def __init__( image, ) - # Process the different options for flux, stepk, maxk - self._calculate_stepk = calculate_stepk - self._calculate_maxk = calculate_maxk - self._stepk = self._getStepK(calculate_stepk, _force_stepk) - self._maxk = self._getMaxK(calculate_maxk, _force_maxk) + major, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) + + self._maxk = self._getMaxK(calculate_maxk, _force_maxk * minor) + self._stepk = self._getStepK(calculate_stepk, _force_stepk * major) @doc_inherit def withGSParams(self, gsparams=None, **kwargs): @@ -534,6 +603,7 @@ def _buildImages( # Now place the given image in the center of the padding image: # assert self._xim.bounds.includes(self._image.bounds) self._xim[self._image.bounds] = self._image + self._xim.wcs = self._image.wcs # And update the _image to be that portion of the full real image rather than the # input image. @@ -542,7 +612,6 @@ def _buildImages( # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. self._pad_image = self._xim[nz_bounds] - # self._pad_factor = (max(self._xim.array.shape)-1.e-6) / max(self._image.array.shape) self._pad_factor = pad_factor # we always make this @@ -623,7 +692,6 @@ def _getStepK(self, calculate_stepk, _force_stepk): # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly # below what is provided here, while maxK is preserved. if _force_stepk > 0.0: - print("Forcing stepk to be %f" % _force_stepk) return _force_stepk elif calculate_stepk: if calculate_stepk is True: @@ -652,15 +720,14 @@ def _getSimpleStepK(self, R): def _getMaxK(self, calculate_maxk, _force_maxk): max_scale = 1.0 if _force_maxk > 0.0: - print("Forcing maxk to be %f" % _force_maxk) return _force_maxk elif calculate_maxk: _uscale = 1 / (2 * jnp.pi) - self._maxk = self._x_interpolant.urange() / _uscale / max_scale + _maxk = self._x_interpolant.urange() / _uscale / max_scale if calculate_maxk is True: maxk = _find_maxk( - self._kim, self._maxk, self._gsparams.maxk_threshold * self.flux + self._kim, _maxk, self._gsparams.maxk_threshold * self.flux ) else: maxk = _find_maxk( @@ -671,66 +738,6 @@ def _getMaxK(self, calculate_maxk, _force_maxk): else: return self._x_interpolant.krange / max_scale - def __hash__(self): - # Definitely want to cache this, since the size of the image could be large. - if not hasattr(self, "_hash"): - self._hash = hash( - ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant) - ) - self._hash ^= hash( - ( - ensure_hashable(self.flux), - ensure_hashable(self._stepk), - ensure_hashable(self._maxk), - ensure_hashable(self._pad_factor), - ) - ) - self._hash ^= hash( - (self._xim.bounds, self._image.bounds, self._pad_image.bounds) - ) - # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0 - # (which is also common). I guess because they are only different in 2 bits. - # This mucking of the numbers seems to help make the hash more reliably different for - # these two cases. Note: "sometiems" because of this: - # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions - self._hash ^= hash( - ( - ensure_hashable(self._offset.x * 1.234), - ensure_hashable(self._offset.y * 0.23424), - ) - ) - self._hash ^= hash(self._gsparams) - self._hash ^= hash(self._wcs) - # Just hash the diagonal. Much faster, and usually is unique enough. - # (Let python handle collisions as needed if multiple similar IIs are used as keys.) - self._hash ^= hash(ensure_hashable(self._pad_image.array)) - return self._hash - - def __repr__(self): - s = "galsim.InterpolatedImage(%r" % self._jax_children[0] - - for k, v in self._jax_children[1].items(): - if v is not None: - _v = ensure_hashable(v) - s += ", %s=%r" % (k, _v) - - for k, v in self._jax_aux_data.items(): - if ( - v is not None - ): # and k not in ["gsparams", "_force_stepk", "_force_maxk"]: - _v = ensure_hashable(v) - s += ", %s=%r" % (k, _v) - - s += ")" - # s += ( - # ", gsparams=%r, _force_stepk=%r, _force_maxk=%r)" - # % (self.gsparams, ensure_hashable(self._stepk), ensure_hashable(self._maxk)) - # ) - return s - - def __str__(self): - return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) - def __getstate__(self): d = self.__dict__.copy() # Only pickle _pad_image. Not _xim or _image @@ -768,7 +775,6 @@ def image(self): @property def _flux(self): - """By default, the flux is contained in the parameters dictionay.""" return self._image_flux @property @@ -876,6 +882,13 @@ def _InterpolatedImage( ) +def _get_image_jac_arr_wcs(image, use_true_center): + im_cen = image.true_center if use_true_center else image.center + _jac_arr = image.wcs.jacobian(image_pos=im_cen).getMatrix().ravel() + _wcs = image.wcs.local(image_pos=im_cen) + return _jac_arr, _wcs + + @partial(jax.jit, static_argnums=(5,)) def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): orig_shape = x.shape diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index e761785f..03bef064 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -4,7 +4,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.position import PositionD @@ -253,14 +253,7 @@ def _kfactor(self, kx, ky): def _major_minor(self): if not hasattr(self, "_major"): - h1 = jnp.hypot( - self._jac[0, 0] + self._jac[1, 1], self._jac[0, 1] - self._jac[1, 0] - ) - h2 = jnp.hypot( - self._jac[0, 0] - self._jac[1, 1], self._jac[0, 1] + self._jac[1, 0] - ) - self._major = 0.5 * abs(h1 + h2) - self._minor = 0.5 * abs(h1 - h2) + self._major, self._minor = compute_major_minor_from_jacobian(self._jac) @property def _maxk(self): diff --git a/tests/GalSim b/tests/GalSim index 5b3d4910..380bab83 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5b3d4910aec5b58a41faadbc6b267b188a93e521 +Subproject commit 380bab83e57042fe7c2530636453908ee2482add From 186f620358c34739432a551dd75285407775eaf4 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sat, 21 Oct 2023 07:47:19 -0500 Subject: [PATCH 37/67] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 6c42227c..f6eb786c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -72,7 +72,7 @@ def is_equal_with_arrays(x, y): return False elif isinstance(x, set): if isinstance(y, set) and len(x) == len(y): - for vx, vy in zip(x, y): + for vx, vy in zip(sorted(x), sorted(y)): if not is_equal_with_arrays(vx, vy): return False return True From e44a658728c26c7cedd19b25a71e1843f59fb378 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Oct 2023 15:55:36 -0500 Subject: [PATCH 38/67] ENH it works omg --- jax_galsim/image.py | 22 +- jax_galsim/interpolatedimage.py | 579 +++++++++++----------- jax_galsim/transform.py | 112 +++-- tests/GalSim | 2 +- tests/jax/test_api.py | 16 + tests/jax/test_interpolatedimage_utils.py | 121 +++-- 6 files changed, 454 insertions(+), 398 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 85c41da2..caedcb84 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -659,12 +659,12 @@ def calculate_fft(self): "calculate_fft requires that the image has a PixelScale wcs." ) - No2 = jnp.maximum( - jnp.maximum( + No2 = max( + max( -self.bounds.xmin, self.bounds.xmax + 1, ), - jnp.maximum( + max( -self.bounds.ymin, self.bounds.ymax + 1, ), @@ -713,8 +713,8 @@ def calculate_inverse_fft(self): self.bounds, ) - No2 = jnp.maximum( - jnp.maximum(self.bounds.xmax, -self.bounds.ymin), + No2 = max( + max(self.bounds.xmax, -self.bounds.ymin), self.bounds.ymax, ) @@ -755,6 +755,8 @@ def good_fft_size(cls, input_size): going to be performing FFTs on an image, these will tend to be faster at performing the FFT. """ + import math + # Reference from GalSim C++ # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/Image.cpp#L1009 input_size = int(input_size) @@ -762,12 +764,12 @@ def good_fft_size(cls, input_size): return 2 # Reduce slightly to eliminate potential rounding errors: insize = (1.0 - 1.0e-5) * input_size - log2n = jnp.log(2.0) * jnp.ceil(jnp.log(insize) / jnp.log(2.0)) - log2n3 = jnp.log(3.0) + jnp.log(2.0) * jnp.ceil( - (jnp.log(insize) - jnp.log(3.0)) / jnp.log(2.0) + log2n = math.log(2.0) * math.ceil(math.log(insize) / math.log(2.0)) + log2n3 = math.log(3.0) + math.log(2.0) * math.ceil( + (math.log(insize) - math.log(3.0)) / math.log(2.0) ) - log2n3 = max(log2n3, jnp.log(6.0)) # must be even number - Nk = int(jnp.ceil(jnp.exp(min(log2n, log2n3)) - 1.0e-5)) + log2n3 = max(log2n3, math.log(6.0)) # must be even number + Nk = int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)) return Nk def copyFrom(self, rhs): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 15992d3d..dfc0605f 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -29,6 +29,21 @@ from jax_galsim.utilities import convert_interpolant from jax_galsim.wcs import BaseWCS, PixelScale +_KEYS_TO_REMOVE = [ + "flux_ratio", + "jac", + "offset", + "original", +] + + +# magic from https://stackoverflow.com/questions/46120462/how-to-override-the-dir-method-for-a-class +class DirMeta(type): + def __dir__(cls): + keys = set(list(cls.__dict__.keys()) + dir(cls.__base__)) + keys -= set(_KEYS_TO_REMOVE) + return list(keys) + @_wraps( _galsim.InterpolatedImage, @@ -37,6 +52,7 @@ - noise padding - depixelize + - most of the type checks and dtype casts done by galsim Further, it always computes the FFT of the image as opposed to galsim where this is done as needed. One almost always needs the FFT and JAX @@ -45,7 +61,7 @@ ), ) @register_pytree_node_class -class InterpolatedImage(Transformation): +class InterpolatedImage(Transformation, metaclass=DirMeta): _req_params = {"image": str} _opt_params = { "x_interpolant": str, @@ -91,7 +107,14 @@ def __init__( _force_maxk=0.0, _recenter_image=True, # this option is used by _InterpolatedImage below hdu=None, + _obj=None, ): + # If the "image" is not actually an image, try to read the image as a file. + if isinstance(image, str): + image = fits.read(image, hdu=hdu) + elif not isinstance(image, Image): + raise TypeError("Supplied image must be an Image or file name") + self._jax_children = ( image, dict( @@ -122,39 +145,64 @@ def __init__( hdu=hdu, ) - obj = InterpolatedImageImpl( - image, - x_interpolant=x_interpolant, - k_interpolant=k_interpolant, - normalization=normalization, - scale=scale, - wcs=wcs, - flux=flux, - pad_factor=pad_factor, - noise_pad_size=noise_pad_size, - noise_pad=noise_pad, - rng=rng, - pad_image=pad_image, - calculate_stepk=calculate_stepk, - calculate_maxk=calculate_maxk, - use_cache=use_cache, - use_true_center=use_true_center, - depixelize=depixelize, - offset=offset, - gsparams=GSParams.check(gsparams), - _force_stepk=_force_stepk, - _force_maxk=_force_maxk, - hdu=hdu, - _recenter_image=_recenter_image, - ) - super().__init__( - obj, - jac=obj._jac_arr, - flux_ratio=obj._flux_ratio / obj._wcs.pixelArea(), - offset=PositionD(0.0, 0.0), - gsparams=GSParams.check(gsparams), - propagate_gsparams=True, - ) + if _obj is not None: + obj = _obj + else: + obj = _InterpolatedImageImpl( + image, + x_interpolant=x_interpolant, + k_interpolant=k_interpolant, + normalization=normalization, + scale=scale, + wcs=wcs, + flux=flux, + pad_factor=pad_factor, + noise_pad_size=noise_pad_size, + noise_pad=noise_pad, + rng=rng, + pad_image=pad_image, + calculate_stepk=calculate_stepk, + calculate_maxk=calculate_maxk, + use_cache=use_cache, + use_true_center=use_true_center, + depixelize=depixelize, + offset=offset, + gsparams=GSParams.check(gsparams), + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, + hdu=hdu, + _recenter_image=_recenter_image, + ) + + # we don't use the parent init but instead set things by hand to + # avoid computations upon init + self._gsparams = GSParams.check(gsparams, obj.gsparams) + self._propagate_gsparams = True + if self._propagate_gsparams: + obj = obj.withGSParams(self._gsparams) + self._original = obj + self._params = { + "offset": PositionD(0.0, 0.0), + } + self._jax_children[1]["_obj"] = obj + + @property + def _flux_ratio(self): + return self._original._flux_ratio / self._original._wcs.pixelArea() + + @property + def _jac(self): + return self._original._jac_arr.reshape((2, 2)) + + def __getattribute__(self, name): + if name in _KEYS_TO_REMOVE: + raise AttributeError(f"{self.__class__} has no attribute '{name}'") + return super().__getattribute__(name) + + def __dir__(self): + allattrs = set(self.__dict__.keys() + dir(self.__class__)) + allattrs -= set(_KEYS_TO_REMOVE) + return list(allattrs) # the galsim tests use this internal attribute # so we add it here @@ -285,7 +333,7 @@ def tree_unflatten(cls, aux_data, children): val.update(children[1]) return cls(children[0], **val) - @doc_inherit + @_wraps(_galsim.InterpolatedImage.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: return self @@ -299,8 +347,13 @@ def withGSParams(self, gsparams=None, **kwargs): return ret +@partial(jax.jit, static_argnums=(1,)) +def _zeropad_image(arr, npad): + return jnp.pad(arr, npad, mode="constant", constant_values=0.0) + + @register_pytree_node_class -class InterpolatedImageImpl(GSObject): +class _InterpolatedImageImpl(GSObject): _cache_noise_pad = {} _has_hard_edges = False @@ -337,6 +390,7 @@ def __init__( # this class does a ton of munging of the inputs that I don't want to reconstruct when # flattening and unflattening the class. # thus I am going to make some refs here so we have it when we need it + self._cached_comps = {} self._jax_children = ( image, dict( @@ -346,6 +400,7 @@ def __init__( pad_image=pad_image, offset=offset, ), + self._cached_comps, ) self._jax_aux_data = dict( x_interpolant=x_interpolant, @@ -366,17 +421,6 @@ def __init__( _recenter_image=_recenter_image, hdu=hdu, ) - self._params = { - "image": image, - } - self._params.update(self._jax_aux_data) - self._params.update(self._jax_children[1]) - - # If the "image" is not actually an image, try to read the image as a file. - if isinstance(image, str): - image = fits.read(image, hdu=hdu) - elif not isinstance(image, Image): - raise TypeError("Supplied image must be an Image or file name") # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor if not image.bounds.isDefined(): @@ -409,76 +453,143 @@ def __init__( self._gsparams ) - # Store the image as an attribute and make sure we don't change the original image - # in anything we do here. (e.g. set scale, etc.) - if depixelize: - # FIXME: no depixelize in jax_galsim - # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) + if pad_image is not None: + raise NotImplementedError("pad_image not implemented in jax_galsim.") + + if pad_factor <= 0.0: + raise GalSimRangeError( + "Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.0 + ) + + if noise_pad_size: raise NotImplementedError( - "InterpolatedImages do not support 'depixelize' in jax_galsim." + "InterpolatedImages do not support noise padding in jax_galsim." ) else: - self._image = image.view(dtype=jnp.float64, contiguous=True) - if _recenter_image: - self._image.setCenter(0, 0) + if noise_pad: + raise NotImplementedError( + "InterpolatedImages do not support noise padding in jax_galsim." + ) - # Set the wcs if necessary if scale is not None: if wcs is not None: raise GalSimIncompatibleValuesError( "Cannot provide both scale and wcs to InterpolatedImage", - scale=scale, - wcs=wcs, + scale=self._jax_children[1]["scale"], + wcs=self._jax_children[1]["wcs"], ) - self._image.wcs = PixelScale(scale) elif wcs is not None: if not isinstance(wcs, BaseWCS): raise TypeError("wcs parameter is not a galsim.BaseWCS instance") - self._image.wcs = wcs - elif self._image.wcs is None: - raise GalSimIncompatibleValuesError( - "No information given with Image or keywords about pixel scale!", - scale=scale, - wcs=wcs, - image=image, - ) + else: + if self._jax_children[0].wcs is None: + raise GalSimIncompatibleValuesError( + "No information given with Image or keywords about pixel scale!", + scale=self._jax_children[1]["scale"], + wcs=self._jax_children[1]["wcs"], + image=self._jax_children[0], + ) + + @property + def _flux_ratio(self): + if self._jax_children[1]["flux"] is None: + flux = self._image_flux + if self._jax_aux_data["normalization"].lower() in ( + "surface brightness", + "sb", + ): + flux *= self._wcs.pixelArea() + else: + flux = self._jax_children[1]["flux"] + + # If the user specified a flux, then set the flux ratio for the transform that wraps + # this class + return flux / self._image_flux + @property + def _image_flux(self): + return jnp.sum(self._image.array, dtype=float) + + @property + def _offset(self): # Figure out the offset to apply based on the original image (not the padded one). # We will apply this below in _sbp. - offset = self._parse_offset(offset) - self._offset = self._adjust_offset( - self._image.bounds, offset, None, use_true_center + offset = self._parse_offset(self._jax_children[1]["offset"]) + return self._adjust_offset( + self._image.bounds, offset, None, self._jax_aux_data["use_true_center"] ) - self._jac_arr, self._wcs = _get_image_jac_arr_wcs(self._image, use_true_center) - - # Build the fully padded real-space image according to the various pad options. - self._buildImages( - pad_factor, - pad_image, - noise_pad_size, - noise_pad, - rng, - use_cache, - flux, - normalization, + @property + def _image(self): + # Store the image as an attribute and make sure we don't change the original image + # in anything we do here. (e.g. set scale, etc.) + if self._jax_aux_data["depixelize"]: + # FIXME: no depixelize in jax_galsim + # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) + raise NotImplementedError( + "InterpolatedImages do not support 'depixelize' in jax_galsim." + ) + else: + image = self._jax_children[0].view(dtype=float) + + if self._jax_aux_data["_recenter_image"]: + image.setCenter(0, 0) + + return image + + @property + def _wcs(self): + im_cen = ( + self._jax_children[0].true_center + if self._jax_aux_data["use_true_center"] + else self._jax_children[0].center ) - # I think the only things that will mess up if flux == 0 are the - # calculateStepK and calculateMaxK functions, and rescaling the flux to some value. - if ( - calculate_stepk or calculate_maxk or flux is not None - ) and self._image_flux == 0.0: - raise GalSimValueError( - "This input image has zero total flux. It does not define a " - "valid surface brightness profile.", - image, - ) + # error checking was done on init + if self._jax_children[1]["scale"] is not None: + wcs = PixelScale(self._jax_children[1]["scale"]) + elif self._jax_children[1]["wcs"] is not None: + wcs = self._jax_children[1]["wcs"] + else: + wcs = self._jax_children[0].wcs + + return wcs.local(image_pos=im_cen) - major, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) + @property + def _jac_arr(self): + image = self._jax_children[0] + im_cen = ( + image.true_center if self._jax_aux_data["use_true_center"] else image.center + ) + return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel() + + @property + def _maxk(self): + if self._jax_aux_data["_force_maxk"]: + major, minor = compute_major_minor_from_jacobian( + self._jac_arr.reshape((2, 2)) + ) + return self._jax_aux_data["_force_maxk"] * minor + else: + if "_maxk" not in self._cached_comps: + self._cached_comps["_maxk"] = self._getMaxK( + self._jax_aux_data["calculate_maxk"] + ) + return self._cached_comps["_maxk"] - self._maxk = self._getMaxK(calculate_maxk, _force_maxk * minor) - self._stepk = self._getStepK(calculate_stepk, _force_stepk * major) + @property + def _stepk(self): + if self._jax_aux_data["_force_stepk"]: + major, minor = compute_major_minor_from_jacobian( + self._jac_arr.reshape((2, 2)) + ) + return self._jax_aux_data["_force_stepk"] * minor + else: + if "_stepk" not in self._cached_comps: + self._cached_comps["_stepk"] = self._getStepK( + self._jax_aux_data["calculate_stepk"] + ) + return self._cached_comps["_stepk"] @doc_inherit def withGSParams(self, gsparams=None, **kwargs): @@ -504,183 +615,76 @@ def tree_unflatten(cls, aux_data, children): val = {} val.update(aux_data) val.update(children[1]) - return cls(children[0], **val) + ret = cls(children[0], **val) + ret._cached_comps.update(children[2]) - def _buildImages( - self, - pad_factor, - pad_image, - noise_pad_size, - noise_pad, - rng, - use_cache, - flux, - normalization, - ): - # If the user specified a surface brightness normalization for the input Image, then - # need to rescale flux by the pixel area to get proper normalization. - self._image_flux = jnp.sum(self._image.array, dtype=float) - if flux is None: - flux = self._image_flux - if normalization.lower() in ("surface brightness", "sb"): - flux *= self._wcs.pixelArea() - _flux = flux + @property + def _xim(self): + if "_xim" not in self._cached_comps: + pad_factor = self._jax_aux_data["pad_factor"] + + # The size of the final padded image is the largest of the various size specifications + pad_size = max(self._image.array.shape) + if pad_factor > 1.0: + pad_size = int(math.ceil(pad_factor * pad_size)) + # And round up to a good fft size + pad_size = Image.good_fft_size(pad_size) + + xim = Image( + _zeropad_image( + self._image.array, (pad_size - max(self._image.array.shape)) // 2 + ), + wcs=PixelScale(1.0), + ) + xim.setCenter(0, 0) + xim.wcs = PixelScale(1.0) - # If the user specified a flux, then set the flux ratio for the transform that wraps - # this class - self._flux_ratio = _flux / self._image_flux + nz_bounds = self._image.bounds - # Check that given pad_image is valid: - if pad_image is not None: - if isinstance(pad_image, str): - pad_image = fits.read(pad_image).view(dtype=jnp.float64) - elif isinstance(pad_image, Image): - pad_image = pad_image.view(dtype=jnp.float64, contiguous=True) - else: - raise TypeError("Supplied pad_image must be an Image.", pad_image) + # Now place the given image in the center of the padding image: + # assert self._xim.bounds.includes(self._image.bounds) + xim[self._image.bounds] = self._image - if pad_factor <= 0.0: - raise GalSimRangeError( - "Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.0 - ) + self._pad_factor = pad_factor + self._nz_bounds = nz_bounds - # Convert noise_pad_size from arcsec to pixels according to the local wcs. - # Use the minimum scale, since we want to make sure noise_pad_size is - # as large as we need in any direction. - if noise_pad_size: - # FIXME: no BaseDeviate in jax_galsim so no noise padding - # if noise_pad_size < 0: - # raise GalSimValueError("noise_pad_size may not be negative", noise_pad_size) - # if not noise_pad: - # raise GalSimIncompatibleValuesError( - # "Must provide noise_pad if noise_pad_size > 0", - # noise_pad=noise_pad, noise_pad_size=noise_pad_size) - # noise_pad_size = int(math.ceil(noise_pad_size / self._wcs._minScale())) - # noise_pad_size = Image.good_fft_size(noise_pad_size) - raise NotImplementedError( - "InterpolatedImages do not support noise padding in jax_galsim." - ) - else: - if noise_pad: - # FIXME: no BaseDeviate in jax_galsim so no noise padding - # raise GalSimIncompatibleValuesError( - # "Must provide noise_pad_size if noise_pad != 0", - # noise_pad=noise_pad, noise_pad_size=noise_pad_size) - raise NotImplementedError( - "InterpolatedImages do not support noise padding in jax_galsim." - ) + self._cached_comps["_xim"] = xim - # The size of the final padded image is the largest of the various size specifications - pad_size = max(self._image.array.shape) - if pad_factor > 1.0: - pad_size = int(math.ceil(pad_factor * pad_size)) - if noise_pad_size: - pad_size = max(pad_size, noise_pad_size) - if pad_image: - pad_image.setCenter(0, 0) - pad_size = max(pad_size, *pad_image.array.shape) - # And round up to a good fft size - pad_size = Image.good_fft_size(pad_size) - - self._xim = Image(pad_size, pad_size, dtype=jnp.float64, wcs=PixelScale(1.0)) - self._xim.setCenter(0, 0) - self._image.wcs = PixelScale(1.0) - - # If requested, fill (some of) this image with noise padding. - nz_bounds = self._image.bounds - # FIXME: no BaseDeviate in jax_galsim so no noise padding - # if noise_pad: - # # This is a bit involved, so pass this off to another helper function. - # b = self._buildNoisePadImage(noise_pad_size, noise_pad, rng, use_cache) - # nz_bounds += b - - # The the user gives us a pad image to use, fill the relevant portion with that. - if pad_image: - # assert self._xim.bounds.includes(pad_image.bounds) - self._xim[pad_image.bounds] = pad_image - nz_bounds += pad_image.bounds - - # Now place the given image in the center of the padding image: - # assert self._xim.bounds.includes(self._image.bounds) - self._xim[self._image.bounds] = self._image - self._xim.wcs = self._image.wcs - - # And update the _image to be that portion of the full real image rather than the - # input image. - self._image = self._xim[self._image.bounds] + return self._cached_comps["_xim"] + @property + def _pad_image(self): # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. - self._pad_image = self._xim[nz_bounds] - self._pad_factor = pad_factor + xim = self._xim + return xim[self._nz_bounds] - # we always make this - self._kim = self._xim.calculate_fft() + @property + def _kim(self): + if "_kim" in self._cached_comps: + return self._cached_comps["_kim"] + else: + kim = self._xim.calculate_fft() + self._cached_comps["_kim"] = kim + return kim + @property + def _pos_neg_fluxes(self): # record pos and neg fluxes now too pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) - nflux = jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) + nflux = jnp.abs( + jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) + ) pint = self._x_interpolant.positive_flux nint = self._x_interpolant.negative_flux pint2d = pint * pint + nint * nint nint2d = 2 * pint * nint - self._pos_neg_fluxes = [ + return [ pint2d * pflux + nint2d * nflux, pint2d * nflux + nint2d * pflux, ] - # FIXME: no BaseDeviate in jax_galsim so no noise padding - # def _buildNoisePadImage(self, noise_pad_size, noise_pad, rng, use_cache): - # """A helper function that builds the ``pad_image`` from the given ``noise_pad`` - # specification. - # """ - # from .random import BaseDeviate - # from .noise import GaussianNoise - # from .correlatednoise import BaseCorrelatedNoise, CorrelatedNoise - - # # Make sure we make rng a BaseDeviate if rng is None - # rng1 = BaseDeviate(rng) - - # # Figure out what kind of noise to apply to the image - # try: - # noise_pad = float(noise_pad) - # except (TypeError, ValueError): - # if isinstance(noise_pad, BaseCorrelatedNoise): - # noise = noise_pad.copy(rng=rng1) - # elif isinstance(noise_pad, Image): - # noise = CorrelatedNoise(noise_pad, rng1) - # elif use_cache and noise_pad in InterpolatedImage._cache_noise_pad: - # noise = InterpolatedImage._cache_noise_pad[noise_pad] - # if rng: - # # Make sure that we are using a specified RNG by resetting that in this cached - # # CorrelatedNoise instance, otherwise preserve the cached RNG - # noise = noise.copy(rng=rng1) - # elif isinstance(noise_pad, basestring): - # noise = CorrelatedNoise(fits.read(noise_pad), rng1) - # if use_cache: - # InterpolatedImage._cache_noise_pad[noise_pad] = noise - # else: - # raise GalSimValueError( - # "Input noise_pad must be a float/int, a CorrelatedNoise, Image, or filename " - # "containing an image to use to make a CorrelatedNoise.", noise_pad) - - # else: - # if noise_pad < 0.: - # raise GalSimRangeError("Noise variance may not be negative.", noise_pad, 0.) - # noise = GaussianNoise(rng1, sigma = np.sqrt(noise_pad)) - - # # Find the portion of xim to fill with noise. - # # It's allowed for the noise padding to not cover the whole pad image - # half_size = noise_pad_size // 2 - # b = _BoundsI(-half_size, -half_size + noise_pad_size-1, - # -half_size, -half_size + noise_pad_size-1) - # #assert self._xim.bounds.includes(b) - # noise_image = self._xim[b] - # # Add the noise - # noise_image.addNoise(noise) - # return b - - def _getStepK(self, calculate_stepk, _force_stepk): + def _getStepK(self, calculate_stepk): # GalSim cannot automatically know what stepK and maxK are appropriate for the # input image. So it is usually worth it to do a manual calculation (below). # @@ -691,39 +695,33 @@ def _getStepK(self, calculate_stepk, _force_stepk): # units required by the C++ layer below. Also note that profile recentering for even-sized # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly # below what is provided here, while maxK is preserved. - if _force_stepk > 0.0: - return _force_stepk - elif calculate_stepk: + if calculate_stepk: if calculate_stepk is True: - im = self._image + im = self.image else: # If not a bool, then value is max_stepk R = (jnp.ceil(jnp.pi / calculate_stepk)).astype(int) b = BoundsI(-R, R, -R, R) - b = self._image.bounds & b - im = self._image[b] + b = self.image.bounds & b + im = self.image[b] thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux # this line appears buggy in galsim - I expect they meant to use im R = _calculate_size_containing_flux(im, thresh) else: - R = max(*self._image.array.shape) / 2.0 - 0.5 + R = max(*self.image.array.shape) / 2.0 - 0.5 return self._getSimpleStepK(R) def _getSimpleStepK(self, R): - min_scale = 1.0 # Add xInterp range in quadrature just like convolution: R2 = self._x_interpolant.xrange R = jnp.hypot(R, R2) - stepk = jnp.pi / (R * min_scale) + stepk = jnp.pi / R return stepk - def _getMaxK(self, calculate_maxk, _force_maxk): - max_scale = 1.0 - if _force_maxk > 0.0: - return _force_maxk - elif calculate_maxk: + def _getMaxK(self, calculate_maxk): + if calculate_maxk: _uscale = 1 / (2 * jnp.pi) - _maxk = self._x_interpolant.urange() / _uscale / max_scale + _maxk = self._x_interpolant.urange() / _uscale if calculate_maxk is True: maxk = _find_maxk( @@ -734,29 +732,9 @@ def _getMaxK(self, calculate_maxk, _force_maxk): self._kim, calculate_maxk, self._gsparams.maxk_threshold * self.flux ) - return maxk / max_scale - else: - return self._x_interpolant.krange / max_scale - - def __getstate__(self): - d = self.__dict__.copy() - # Only pickle _pad_image. Not _xim or _image - d["_xim_bounds"] = self._xim.bounds - d["_image_bounds"] = self._image.bounds - d.pop("_xim", None) - d.pop("_image", None) - return d - - def __setstate__(self, d): - xim_bounds = d.pop("_xim_bounds") - image_bounds = d.pop("_image_bounds") - self.__dict__ = d - if self._pad_image.bounds == xim_bounds: - self._xim = self._pad_image + return maxk else: - self._xim = Image(xim_bounds, wcs=PixelScale(1.0), dtype=jnp.float64) - self._xim[self._pad_image.bounds] = self._pad_image - self._image = self._xim[image_bounds] + return self._x_interpolant.krange @property def x_interpolant(self): @@ -771,7 +749,7 @@ def k_interpolant(self): @property def image(self): """The underlying `Image` being interpolated.""" - return self._image + return self._xim[self._image.bounds] @property def _flux(self): @@ -821,25 +799,26 @@ def _kValue(self, kpos): pkx += pky pfac = jnp.exp(pkx) - kx = jnp.array([kpos.x / self._kim.scale], dtype=float) - ky = jnp.array([kpos.y / self._kim.scale], dtype=float) + _kim = self._kim + kx = jnp.array([kpos.x / _kim.scale], dtype=float) + ky = jnp.array([kpos.y / _kim.scale], dtype=float) _uscale = 1.0 / (2.0 * jnp.pi) - _maxk_xint = self._x_interpolant.urange() / _uscale / self._kim.scale + _maxk_xint = self._x_interpolant.urange() / _uscale / _kim.scale val = _draw_with_interpolant_kval( kx, ky, - self._kim.bounds.ymin, - self._kim.bounds.ymin, - self._kim.array, + _kim.bounds.ymin, + _kim.bounds.ymin, + _kim.array, self._k_interpolant, ) msk = (jnp.abs(kx) <= _maxk_xint) & (jnp.abs(ky) <= _maxk_xint) xint_val = self._x_interpolant._kval_noraise( - kx * self._kim.scale - ) * self._x_interpolant._kval_noraise(ky * self._kim.scale) + kx * _kim.scale + ) * self._x_interpolant._kval_noraise(ky * _kim.scale) return jnp.where(msk, val * xint_val * pfac, 0.0)[0] def _shoot(self, photons, rng): @@ -931,7 +910,7 @@ def _body(i, args): -interp.xrange, interp.xrange + 1, _body, - [jnp.zeros(x.shape, dtype=zp.dtype), xi, yi, xp, yp, zp], + [jnp.zeros(x.shape, dtype=float), xi, yi, xp, yp, zp], )[0] return z.reshape(orig_shape) @@ -983,7 +962,7 @@ def _body(i, args): -interp.xrange, interp.xrange + 1, _body, - [jnp.zeros(kx.shape, dtype=zp.dtype), kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp], + [jnp.zeros(kx.shape, dtype=complex), kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp], )[0] return z.reshape(orig_shape) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 03bef064..61de3981 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -59,17 +59,38 @@ def __init__( "flux_ratio": flux_ratio, } - if isinstance(obj, Transformation): + # this import is here to avoid circular imports + # we do not want to mess with the transform properties of the interpolated image + from .interpolatedimage import InterpolatedImage + + if isinstance(obj, Transformation) and not isinstance(obj, InterpolatedImage): # Combine the two affine transformations into one. - dx, dy = self._fwd(obj.offset.x, obj.offset.y) - self._params["offset"].x += dx - self._params["offset"].y += dy - self._params["jac"] = self._jac.dot(obj.jac) - self._params["flux_ratio"] *= obj._params["flux_ratio"] - self._original = obj.original + dx, dy = self._fwd(obj._params["offset"].x, obj._params["offset"].y) + self._offset.x += dx + self._offset.y += dy + self._params["jac"] = self._jac.dot(obj._jac) + self._params["flux_ratio"] *= obj._flux_ratio + self._original = obj._original else: self._original = obj + ############################################################## + # The internal code of the methods of the Transform class + # should only aceess _offset, _flux_ratio, and _jac. It + # should pull these direct from _params + # Things are structured this way since the interpolated image + # class inherits and overrides these methods. + + @property + def _offset(self): + return self._params["offset"] + + # we use this property so that the interpolated image can override + # how flux ratio is computer / stored + @property + def _flux_ratio(self): + return self._params["flux_ratio"] + @property def _jac(self): jac = self._params["jac"] @@ -79,7 +100,7 @@ def _jac(self): lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), jac, ) - return jnp.asarray(jac, dtype=float).reshape(2, 2) + return jnp.asarray(jac, dtype=float).reshape((2, 2)) @property def original(self): @@ -94,17 +115,18 @@ def jac(self): @property def offset(self): """The offset of the transformation.""" - return self._params["offset"] + return self._offset @property def flux_ratio(self): """The flux ratio of the transformation.""" - return self._params["flux_ratio"] + return self._flux_ratio @property def _flux(self): return self._flux_scaling * self._original.flux + @_wraps(_galsim.Transformation.withGSParams) def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -113,11 +135,11 @@ def withGSParams(self, gsparams=None, **kwargs): Unless you set ``propagate_gsparams=False``, this method will also update the gsparams of the wrapped component object. """ - if gsparams == self.gsparams: + if gsparams == self._gsparams: return self chld, aux = self.tree_flatten() - aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) + aux["gsparams"] = GSParams.check(gsparams, self._gsparams, **kwargs) if self._propagate_gsparams: new_obj = chld[0].withGSParams(aux["gsparams"]) chld = (new_obj,) + chld[1:] @@ -127,11 +149,11 @@ def withGSParams(self, gsparams=None, **kwargs): def __eq__(self, other): return self is other or ( isinstance(other, Transformation) - and self.original == other.original - and jnp.array_equal(self.jac, other.jac) - and self.offset == other.offset - and self.flux_ratio == other.flux_ratio - and self.gsparams == other.gsparams + and self._original == other._original + and jnp.array_equal(self._jac, other._jac) + and self._offset == other._params["offset"] + and self._flux_ratio == other._flux_ratio + and self._gsparams == other._gsparams and self._propagate_gsparams == other._propagate_gsparams ) @@ -139,12 +161,12 @@ def __hash__(self): return hash( ( "galsim.Transformation", - self.original, + self._original, ensure_hashable(self._jac.ravel()), - ensure_hashable(self.offset.x), - ensure_hashable(self.offset.y), - ensure_hashable(self.flux_ratio), - self.gsparams, + ensure_hashable(self._offset.x), + ensure_hashable(self._offset.y), + ensure_hashable(self._flux_ratio), + self._gsparams, self._propagate_gsparams, ) ) @@ -154,11 +176,11 @@ def __repr__(self): "galsim.Transformation(%r, jac=%r, offset=%r, flux_ratio=%r, gsparams=%r, " "propagate_gsparams=%r)" ) % ( - self.original, + self._original, ensure_hashable(self._jac.ravel()), - self.offset, - ensure_hashable(self.flux_ratio), - self.gsparams, + self._offset, + ensure_hashable(self._flux_ratio), + self._gsparams, self._propagate_gsparams, ) @@ -200,15 +222,15 @@ def _str_from_jac(cls, jac): return "" def __str__(self): - s = str(self.original) + s = str(self._original) s += self._str_from_jac(self._jac) - if self.offset.x != 0 or self.offset.y != 0: + if self._offset.x != 0 or self._offset.y != 0: s += ".shift(%s,%s)" % ( - ensure_hashable(self.offset.x), - ensure_hashable(self.offset.y), + ensure_hashable(self._offset.x), + ensure_hashable(self._offset.y), ) - if self.flux_ratio != 1.0: - s += " * %s" % ensure_hashable(self.flux_ratio) + if self._flux_ratio != 1.0: + s += " * %s" % ensure_hashable(self._flux_ratio) return s @property @@ -227,11 +249,11 @@ def _invjac(self): # than flux_ratio, which is really an amplitude scaling. @property def _amp_scaling(self): - return self._params["flux_ratio"] + return self._flux_ratio @property def _flux_scaling(self): - return jnp.abs(self._det) * self._params["flux_ratio"] + return jnp.abs(self._det) * self._flux_ratio def _fwd(self, x, y): res = jnp.dot(self._jac, jnp.array([x, y])) @@ -246,8 +268,8 @@ def _inv(self, x, y): return res[0], res[1] def _kfactor(self, kx, ky): - kx *= -1j * self.offset.x - ky *= -1j * self.offset.y + kx *= -1j * self._offset.x + ky *= -1j * self._offset.y kx += ky return self._flux_scaling * jnp.exp(kx) @@ -268,7 +290,7 @@ def _stepk(self): # stepk = Pi/R # R <- R + |shift| # stepk <- Pi/(Pi/stepk + |shift|) - dr = jnp.hypot(self.offset.x, self.offset.y) + dr = jnp.hypot(self._offset.x, self._offset.y) stepk = jnp.pi / (jnp.pi / stepk + dr) return stepk @@ -282,7 +304,7 @@ def _is_axisymmetric(self): self._original.is_axisymmetric and self._jac[0, 0] == self._jac[1, 1] and self._jac[0, 1] == -self._jac[1, 0] - and self.offset == PositionD(0.0, 0.0) + and self._offset == PositionD(0.0, 0.0) ) @property @@ -297,7 +319,7 @@ def _is_analytic_k(self): def _centroid(self): cen = self._original.centroid cen = PositionD(self._fwd(cen.x, cen.y)) - cen += self.offset + cen += self._offset return cen @property @@ -317,7 +339,7 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self.offset + pos -= self._offset inv_pos = PositionD(self._inv(pos.x, pos.y)) return self._original._xValue(inv_pos) * self._amp_scaling @@ -328,12 +350,12 @@ def _kValue(self, kpos): def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): dx, dy = offset if jac is not None: - x1 = jac.dot(self.offset._array) + x1 = jac.dot(self._offset._array) dx += x1[0] dy += x1[1] else: - dx += self.offset.x - dy += self.offset.y + dx += self._offset.x + dy += self._offset.y flux_scaling *= self._flux_scaling jac = ( self._jac @@ -357,7 +379,7 @@ def _drawKImage(self, image, jac=None): image = self._original._drawKImage(image, jac1) _jac = jnp.eye(2) if jac is None else jac - image = apply_kImage_phases(self.offset, image, _jac) + image = apply_kImage_phases(self._offset, image, _jac) image = image * self._flux_scaling return image @@ -369,7 +391,7 @@ def tree_flatten(self): children = (self._original, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = { - "gsparams": self.gsparams, + "gsparams": self._gsparams, "propagate_gsparams": self._propagate_gsparams, } return (children, aux_data) diff --git a/tests/GalSim b/tests/GalSim index 380bab83..bf287a91 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 380bab83e57042fe7c2530636453908ee2482add +Subproject commit bf287a91314b56db67308e3946878ae2ab52a8c4 diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index a2ba8dd9..6a21de43 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -76,6 +76,19 @@ def _attempt_init(cls, kwargs): else: raise e + if cls in [jax_galsim.InterpolatedImage]: + try: + return cls( + jax_galsim.ImageD(jnp.arange(100).reshape((10, 10))), + scale=1.3, + **kwargs + ) + except Exception as e: + if any(estr in repr(e) for estr in OK_ERRORS): + pass + else: + raise e + return None @@ -101,6 +114,8 @@ def _kfun(x, prof): def _run_object_checks(obj, cls, kind): if kind == "pickle-eval-repr": + from numpy import array # noqa: F401 + # eval repr is identity mapping assert eval(repr(obj)) == obj @@ -352,6 +367,7 @@ def test_api_gsobject(kind): assert "Moffat" in cls_tested assert "Box" in cls_tested assert "Pixel" in cls_tested + assert "InterpolatedImage" in cls_tested @pytest.mark.parametrize( diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index b0547bf2..35e17485 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -131,6 +131,36 @@ def test_interpolatedimage_utils_stepk_maxk(): np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.2, atol=0) +@pytest.mark.parametrize("normalization", ["sb", "flux"]) +@pytest.mark.parametrize("use_true_center", [True, False]) +@pytest.mark.parametrize( + "wcs", + [ + _galsim.PixelScale(2.0), + _galsim.JacobianWCS(2.1, 0.3, -0.4, 2.3), + _galsim.AffineTransform(-0.3, 2.1, 1.8, 0.1, _galsim.PositionD(0.3, -0.4)), + ], +) +@pytest.mark.parametrize( + "offset_x", + [ + -4.35, + -0.45, + 0.0, + 0.67, + 3.78, + ], +) +@pytest.mark.parametrize( + "offset_y", + [ + -2.12, + -0.33, + 0.0, + 0.12, + 1.45, + ], +) @pytest.mark.parametrize( "ref_array", [ @@ -156,51 +186,58 @@ def test_interpolatedimage_utils_stepk_maxk(): ], ) @pytest.mark.parametrize("method", ["xValue", "kValue"]) -def test_interpolatedimage_utils_comp_to_galsim(method, ref_array): +def test_interpolatedimage_utils_comp_to_galsim( + method, ref_array, offset_x, offset_y, wcs, use_true_center, normalization +): gimage_in = _galsim.Image(ref_array, scale=1) jgimage_in = jax_galsim.Image(ref_array, scale=1) - for wcs in [ - _galsim.PixelScale(2.0), - _galsim.JacobianWCS(2.1, 0.3, -0.4, 2.3), - _galsim.AffineTransform(-0.3, 2.1, 1.8, 0.1, _galsim.PositionD(0.3, -0.4)), - ]: - gii = _galsim.InterpolatedImage(gimage_in, wcs=wcs) - jgii = jax_galsim.InterpolatedImage( - jgimage_in, wcs=jax_galsim.BaseWCS.from_galsim(wcs) - ) - - np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) - np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) - kxvals = [ - (0, 0), - (-5, -5), - (-10, 10), - (1, 1), - (1, -2), - (-1, 0), - (0, -1), - (-1, -1), - (-2, 2), - (-5, 0), - (3, -4), - (-3, 4), - ] - for x, y in kxvals: - if method == "kValue": - dk = jgii._original._kim.scale - np.testing.assert_allclose( - gii.kValue(x * dk, y * dk), - jgii.kValue(x * dk, y * dk), - err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", - ) - else: - dx = jnp.sqrt(jgii._original._wcs.pixelArea()) - np.testing.assert_allclose( - gii.xValue(x * dx, y * dx), - jgii.xValue(x * dx, y * dx), - err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", - ) + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + ) + + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) + kxvals = [ + (0, 0), + (-5, -5), + (-10, 10), + (1, 1), + (1, -2), + (-1, 0), + (0, -1), + (-1, -1), + (-2, 2), + (-5, 0), + (3, -4), + (-3, 4), + ] + for x, y in kxvals: + if method == "kValue": + dk = jgii._original._kim.scale + np.testing.assert_allclose( + gii.kValue(x * dk, y * dk), + jgii.kValue(x * dk, y * dk), + err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", + ) + else: + dx = jnp.sqrt(jgii._original._wcs.pixelArea()) + np.testing.assert_allclose( + gii.xValue(x * dx, y * dx), + jgii.xValue(x * dx, y * dx), + err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", + ) def _compute_fft_with_numpy_jax_galsim(im): From 619ea10f45fab3942304c0fbdf1da153d2b851d2 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 23 Oct 2023 16:07:02 -0500 Subject: [PATCH 39/67] Update jax_galsim/transform.py --- jax_galsim/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 61de3981..598ec7bb 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -77,7 +77,7 @@ def __init__( ############################################################## # The internal code of the methods of the Transform class # should only aceess _offset, _flux_ratio, and _jac. It - # should pull these direct from _params + # should not pull these directly from _params. # Things are structured this way since the interpolated image # class inherits and overrides these methods. From 4014d5fe48a5933d9272806bc08de37ee6606960 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 23 Oct 2023 16:11:35 -0500 Subject: [PATCH 40/67] Update jax_galsim/interpolatedimage.py --- jax_galsim/interpolatedimage.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index dfc0605f..5801fb93 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -29,6 +29,11 @@ from jax_galsim.utilities import convert_interpolant from jax_galsim.wcs import BaseWCS, PixelScale +# These keys are removed from the public API of +# InterpolatedImage so that it matches the galsim +# one. +# The DirMeta class does this along with the changes to +# __getattribute__ and __dir__ below. _KEYS_TO_REMOVE = [ "flux_ratio", "jac", From 47b02808129ddd1be18f8369d5535566ac1b3780 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Oct 2023 16:44:43 -0500 Subject: [PATCH 41/67] BUG cache kids --- jax_galsim/interpolatedimage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index dfc0605f..a0552b6a 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -617,6 +617,7 @@ def tree_unflatten(cls, aux_data, children): val.update(children[1]) ret = cls(children[0], **val) ret._cached_comps.update(children[2]) + return ret @property def _xim(self): From 898ca57d0bbc96cf8ee91c37a56c79966bcd598c Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Oct 2023 16:50:44 -0500 Subject: [PATCH 42/67] remove extra function --- jax_galsim/interpolatedimage.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 968309fb..3aede15f 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -867,13 +867,6 @@ def _InterpolatedImage( ) -def _get_image_jac_arr_wcs(image, use_true_center): - im_cen = image.true_center if use_true_center else image.center - _jac_arr = image.wcs.jacobian(image_pos=im_cen).getMatrix().ravel() - _wcs = image.wcs.local(image_pos=im_cen) - return _jac_arr, _wcs - - @partial(jax.jit, static_argnums=(5,)) def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): orig_shape = x.shape From 7ca8f1e5ad5eb6d1884ed3989c12d51dd9eeb99b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Oct 2023 18:04:39 -0500 Subject: [PATCH 43/67] BUG make sure to have cached nz_bounds --- jax_galsim/interpolatedimage.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3aede15f..18fc31cd 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -645,15 +645,10 @@ def _xim(self): xim.setCenter(0, 0) xim.wcs = PixelScale(1.0) - nz_bounds = self._image.bounds - # Now place the given image in the center of the padding image: # assert self._xim.bounds.includes(self._image.bounds) xim[self._image.bounds] = self._image - self._pad_factor = pad_factor - self._nz_bounds = nz_bounds - self._cached_comps["_xim"] = xim return self._cached_comps["_xim"] @@ -663,7 +658,8 @@ def _pad_image(self): # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. xim = self._xim - return xim[self._nz_bounds] + nz_bounds = self._image.bounds + return xim[nz_bounds] @property def _kim(self): From 6b359f75dba25844d6b4ada09e81b3d793aedcbc Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 23 Oct 2023 22:31:59 -0500 Subject: [PATCH 44/67] fix another bug --- jax_galsim/interpolatedimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 18fc31cd..9d9043ab 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -255,7 +255,7 @@ def __hash__(self): ensure_hashable(self.flux), ensure_hashable(self._stepk), ensure_hashable(self._maxk), - ensure_hashable(self._original._pad_factor), + ensure_hashable(self._original._jax_aux_data["pad_factor"]), ) ) self._hash ^= hash( @@ -295,7 +295,7 @@ def __repr__(self): if self._original._pad_image.bounds != self._original.image.bounds: s += ", pad_image=%r" % (self._pad_image) s += ", pad_factor=%f, flux=%r, offset=%r" % ( - ensure_hashable(self._original._pad_factor), + ensure_hashable(self._original._jax_aux_data["pad_factor"]), ensure_hashable(self.flux), self._original._offset, ) From d5ec4efec1cdf86caf35fc7c62376544f406612f Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Oct 2023 07:56:32 -0500 Subject: [PATCH 45/67] TST add test of metacal --- jax_galsim/interpolatedimage.py | 4 +- tests/conftest.py | 24 ++-- tests/jax/test_interpolatedimage_utils.py | 33 ++++-- tests/jax/test_metacal.py | 128 ++++++++++++++++++++++ 4 files changed, 165 insertions(+), 24 deletions(-) create mode 100644 tests/jax/test_metacal.py diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 9d9043ab..7d62a5f6 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -986,6 +986,7 @@ def _body(d, args): )[0] +@jax.jit def _calculate_size_containing_flux(image, thresh): cenx, ceny = image.center.x, image.center.y x, y = image.get_pixel_centers() @@ -1030,8 +1031,9 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky): return jnp.maximum(max_kx, max_ky) +@jax.jit def _find_maxk(kim, max_maxk, thresh): kx, ky = kim.get_pixel_centers() kx *= kim.scale ky *= kim.scale - return _inner_comp_find_maxk(kim.array, thresh, kx, ky) * 1.15 + return _inner_comp_find_maxk(kim.array, thresh, kx, ky) diff --git a/tests/conftest.py b/tests/conftest.py index 5db13ef4..6780f5fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,19 @@ -import inspect -import os -import sys -from functools import lru_cache -from unittest.mock import patch - -import galsim -import pytest -import yaml - # Define the accuracy for running the tests from jax.config import config -import jax_galsim +config.update("jax_enable_x64", True) + +import inspect # noqa: E402 +import os # noqa: E402 +import sys # noqa: E402 +from functools import lru_cache # noqa: E402 +from unittest.mock import patch # noqa: E402 + +import galsim # noqa: E402 +import pytest # noqa: E402 +import yaml # noqa: E402 + +import jax_galsim # noqa: E402 config.update("jax_enable_x64", True) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 35e17485..30473a1e 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -113,22 +113,31 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): def test_interpolatedimage_utils_stepk_maxk(): - ref_array = np.array( - [ - [0.01, 0.08, 0.07, 0.02], - [0.13, 0.38, 0.52, 0.06], - [0.09, 0.41, 0.44, 0.09], - [0.04, 0.11, 0.10, 0.01], - ] + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + + ref_array = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=53, + ny=53, + scale=scale, + ) + .array.astype(np.float64) ) - test_scale = 2.0 + gimage_in = _galsim.Image(ref_array) jgimage_in = jax_galsim.Image(ref_array) - gii = _galsim.InterpolatedImage(gimage_in, scale=test_scale) - jgii = jax_galsim.InterpolatedImage(jgimage_in, scale=test_scale) + gii = _galsim.InterpolatedImage(gimage_in, scale=scale) + jgii = jax_galsim.InterpolatedImage(jgimage_in, scale=scale) - np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.2, atol=0) - np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.2, atol=0) + rtol = 1e-1 + np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=rtol, atol=0) + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=rtol, atol=0) @pytest.mark.parametrize("normalization", ["sb", "flux"]) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py new file mode 100644 index 00000000..363c9454 --- /dev/null +++ b/tests/jax/test_metacal.py @@ -0,0 +1,128 @@ +import galsim as _galsim +import numpy as np + +import jax_galsim + + +def _metacal_galsim(gs, im, psf, nse_im, scale, target_fwhm, g1): + res = {} + iim = gs.InterpolatedImage(gs.ImageD(im), scale=scale, x_interpolant="lanczos15") + ipsf = gs.InterpolatedImage(gs.ImageD(psf), scale=scale, x_interpolant="lanczos15") + inse = gs.InterpolatedImage( + gs.ImageD(np.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" + ) + + res["maxk"] = (iim.maxk, ipsf.maxk, inse.maxk) + res["stepk"] = (iim.stepk, ipsf.stepk, inse.stepk) + + res["iim"] = iim.drawImage( + nx=33, ny=33, scale=scale, method="no_pixel" + ).array.astype(np.float64) + res["ipsf"] = ipsf.drawImage( + nx=33, ny=33, scale=scale, method="no_pixel" + ).array.astype(np.float64) + + ppsf_iim = gs.Convolve(iim, gs.Deconvolve(ipsf)) + ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) + + sim = ( + gs.Convolve(ppsf_iim, gs.Gaussian(fwhm=target_fwhm)) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + res["im"] = sim + + ppsf_inse = gs.Convolve(inse, gs.Deconvolve(ipsf)) + ppsf_inse = ppsf_iim.shear(g1=g1, g2=0.0) + snse = ( + gs.Convolve(ppsf_inse, gs.Gaussian(fwhm=target_fwhm)) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + res["nse"] = snse + res["tot"] = sim + np.rot90(snse, 3) + return res + + +def test_metacal_comp_to_galsim(): + seed = 42 + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nse = 1e-3 + g1 = 0.01 + target_fwhm = 1.0 + + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + psf = ( + _galsim.Gaussian(fwhm=fwhm) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + nse_im = rng.normal(size=im.shape) * nse + im += rng.normal(size=im.shape) * nse + + gres = _metacal_galsim( + _galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 + ) + jgres = _metacal_galsim( + jax_galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 + ) + + np.testing.assert_allclose(gres["maxk"], jgres["maxk"], rtol=2e-2) + np.testing.assert_allclose(gres["stepk"], jgres["stepk"], rtol=2e-2) + + for k in ["iim", "ipsf", "im", "nse", "tot"]: + gim = gres[k] + jgim = jgres[k] + + if k in ["iim", "ipsf"]: + atol = 1e-7 + else: + atol = 5e-5 + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + import proplot as pplt + + fig, axs = pplt.subplots(ncols=3, nrows=5, figsize=(4.5, 7.5)) + print(axs.shape) + for row, _k in enumerate(["iim", "ipsf", "im", "nse", "tot"]): + _gim = gres[_k] + _jgim = jgres[_k] + axs[row, 0].imshow(np.arcsinh(_gim / nse)) + axs[row, 1].imshow(np.arcsinh(_jgim / nse)) + axs[row, 2].imshow(_jgim - _gim) + fig.show() + + np.testing.assert_allclose( + gim, jgim, err_msg=f"Failed for {k}", rtol=0, atol=atol + ) From c4b1b14b2a06e253a74028c0efe7c566db3ec8a0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Oct 2023 12:54:48 -0500 Subject: [PATCH 46/67] TST add metacal tests --- jax_galsim/image.py | 5 +- jax_galsim/interpolant.py | 110 +++++++++++++++++++------------------- tests/jax/test_metacal.py | 87 ++++++++++++++++++++++++++++-- 3 files changed, 140 insertions(+), 62 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index caedcb84..ad6d19cc 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -759,9 +759,6 @@ def good_fft_size(cls, input_size): # Reference from GalSim C++ # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/Image.cpp#L1009 - input_size = int(input_size) - if input_size <= 2: - return 2 # Reduce slightly to eliminate potential rounding errors: insize = (1.0 - 1.0e-5) * input_size log2n = math.log(2.0) * math.ceil(math.log(insize) / math.log(2.0)) @@ -769,7 +766,7 @@ def good_fft_size(cls, input_size): (math.log(insize) - math.log(3.0)) / math.log(2.0) ) log2n3 = max(log2n3, math.log(6.0)) # must be even number - Nk = int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)) + Nk = max(int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)), 2) return Nk def copyFrom(self, rhs): diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index ef3c1db5..01880cc3 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -1330,66 +1330,52 @@ def __init__( conserve_dc=True, tol=None, gsparams=None, - _K=None, - _C=None, - _umax=None, - _du=None, ): if tol is not None: from galsim.deprecated import depr depr("tol", 2.2, "gsparams=GSParams(kvalue_accuracy=tol)") gsparams = GSParams(kvalue_accuracy=tol) - self._n = int(n) - self._conserve_dc = bool(conserve_dc) + self._n = n + self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) + self._workspace = {} - if _C is None or _K is None: - _K = [0.0] + [Lanczos._raw_uval(i + 1.0, n).item() for i in range(5)] - _C = [0.0] * 6 - _C[0] = 1.0 + 2.0 * ( - _K[1] * (1.0 + 3.0 * _K[1] + _K[2] + _K[3]) - + _K[2] - + _K[3] - + _K[4] - + _K[5] - ) - _C[1] = -_K[1] * (1.0 + 4.0 * _K[1] + _K[2] + 2.0 * _K[3]) - _C[2] = _K[1] * (_K[1] - 2.0 * _K[2] + _K[3]) - _K[2] - _C[3] = _K[1] * (_K[1] - 2.0 * _K[3]) - _K[3] - _C[4] = _K[1] * _K[3] - _K[4] - _C[5] = -_K[5] - _K = tuple(_K) - _C = tuple(_C) - self._K = _K - self._C = _C - else: - self._K = _K - self._C = _C - - self._K_arr = jnp.array(self._K, dtype=float) - self._C_arr = jnp.array(self._C, dtype=float) - - if _du is None: - _du = ( - self._gsparams.table_spacing - * jnp.power(self._gsparams.kvalue_accuracy / 200.0, 0.25) - / self._n - ).item() - self._du = _du - else: - self._du = _du + @property + def _K_arr(self): + if "_K_arr" not in self._workspace: + _C_arr, _K_arr = _compute_C_K_lanczos(self._n) + self._workspace["_K_arr"] = _K_arr + self._workspace["_C_arr"] = _C_arr + return self._workspace["_K_arr"] + + @property + def _C_arr(self): + if "_C_arr" not in self._workspace: + _C_arr, _K_arr = _compute_C_K_lanczos(self._n) + self._workspace["_K_arr"] = _K_arr + self._workspace["_C_arr"] = _C_arr + return self._workspace["_C_arr"] + + @property + def _du(self): + return ( + self._gsparams.table_spacing + * jnp.power(self._gsparams.kvalue_accuracy / 200.0, 0.25) + / self._n + ) - if _umax is None: - self._umax = _find_umax_lanczos( + @property + def _umax(self): + if "_umax" not in self._workspace: + self._workspace["_umax"] = _find_umax_lanczos( self._du, self._n, self._conserve_dc, - self._C, + self._C_arr, self._gsparams.kvalue_accuracy, - ).item() - else: - self._umax = _umax + ) + return self._workspace["_umax"] def tree_flatten(self): """This function flattens the Interpolant into a list of children @@ -1402,20 +1388,14 @@ def tree_flatten(self): "n": self._n, "conserve_dc": self._conserve_dc, } - if hasattr(self, "_du"): - aux_data["_du"] = self._du - if hasattr(self, "_umax"): - aux_data["_umax"] = self._umax - if hasattr(self, "_K"): - aux_data["_K"] = self._K - aux_data["_C"] = self._C return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flattened representation""" n = aux_data.pop("n") - return cls(n, **aux_data) + ret = cls(n, **aux_data) + return ret def __repr__(self): return "galsim.Lanczos(%r, %r, gsparams=%r)" % ( @@ -1647,3 +1627,23 @@ def _body(vals): _body, [0.0, 0.0], )[0] + + +@jax.jit +def _compute_C_K_lanczos(n): + _K = jnp.concatenate( + (jnp.zeros(1), Lanczos._raw_uval(jnp.arange(5) + 1.0, n)), axis=0 + ) + _C = jnp.zeros(6) + _C = _C.at[0].set( + 1.0 + + 2.0 + * (_K[1] * (1.0 + 3.0 * _K[1] + _K[2] + _K[3]) + _K[2] + _K[3] + _K[4] + _K[5]) + ) + _C = _C.at[1].set(-_K[1] * (1.0 + 4.0 * _K[1] + _K[2] + 2.0 * _K[3])) + _C = _C.at[2].set(_K[1] * (_K[1] - 2.0 * _K[2] + _K[3]) - _K[2]) + _C = _C.at[3].set(_K[1] * (_K[1] - 2.0 * _K[3]) - _K[3]) + _C = _C.at[4].set(_K[1] * _K[3] - _K[4]) + _C = _C.at[5].set(-_K[5]) + + return _C, _K diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 363c9454..f910d594 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -1,4 +1,8 @@ +import time + import galsim as _galsim +import jax +import jax.numpy as jnp import numpy as np import jax_galsim @@ -54,6 +58,66 @@ def _metacal_galsim(gs, im, psf, nse_im, scale, target_fwhm, g1): return res +def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): + gs = jax_galsim + iim = gs.InterpolatedImage(gs.ImageD(im), scale=scale, x_interpolant="lanczos15") + ipsf = gs.InterpolatedImage(gs.ImageD(psf), scale=scale, x_interpolant="lanczos15") + inse = gs.InterpolatedImage( + gs.ImageD(jnp.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" + ) + + res = {} + + res["maxk"] = (iim.maxk, ipsf.maxk, inse.maxk) + res["stepk"] = (iim.stepk, ipsf.stepk, inse.stepk) + + res["iim"] = iim.drawImage( + nx=33, ny=33, scale=scale, method="no_pixel" + ).array.astype(np.float64) + res["ipsf"] = ipsf.drawImage( + nx=33, ny=33, scale=scale, method="no_pixel" + ).array.astype(np.float64) + + ppsf_iim = gs.Convolve(iim, gs.Deconvolve(ipsf)) + ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) + + sim = ( + gs.Convolve( + ppsf_iim, + gs.Gaussian(fwhm=target_fwhm), + gsparams=gs.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + res["im"] = sim + + ppsf_inse = gs.Convolve(inse, gs.Deconvolve(ipsf)) + ppsf_inse = ppsf_iim.shear(g1=g1, g2=0.0) + snse = ( + gs.Convolve( + ppsf_inse, + gs.Gaussian(fwhm=target_fwhm), + gsparams=gs.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + res["nse"] = snse + res["tot"] = sim + jnp.rot90(snse, 3) + return res + + def test_metacal_comp_to_galsim(): seed = 42 hlr = 0.5 @@ -91,12 +155,29 @@ def test_metacal_comp_to_galsim(): nse_im = rng.normal(size=im.shape) * nse im += rng.normal(size=im.shape) * nse + gt0 = time.time() gres = _metacal_galsim( _galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 ) - jgres = _metacal_galsim( - jax_galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 - ) + gt0 = time.time() - gt0 + + jit_mcal = jax.jit(_metacal_jax_galsim, static_argnums=6) + + for _ in range(2): + jgt0 = time.time() + jgres = jit_mcal( + im, + psf, + nse_im, + scale, + target_fwhm, + g1, + 128, + ) + jgt0 = time.time() - jgt0 + print("Jax-Galsim time: ", jgt0 * 1e3, " [ms]") + + print("Galsim time: ", gt0 * 1e3, " [ms]") np.testing.assert_allclose(gres["maxk"], jgres["maxk"], rtol=2e-2) np.testing.assert_allclose(gres["stepk"], jgres["stepk"], rtol=2e-2) From 15815236c501dd6d73136f5ad42c694ead0539fd Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 24 Oct 2023 22:30:12 -0500 Subject: [PATCH 47/67] remove cache --- jax_galsim/interpolant.py | 42 ++++++++--------- jax_galsim/interpolatedimage.py | 67 ++++++++++----------------- tests/jax/test_metacal.py | 81 +++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 66 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 01880cc3..2682c2df 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -138,7 +138,7 @@ def _i(self): def __eq__(self, other): return (self is other) or ( type(other) is self.__class__ - and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) + and is_equal_with_arrays(self.tree_flatten()[1], other.tree_flatten()[1]) ) def __ne__(self, other): @@ -1339,23 +1339,14 @@ def __init__( self._n = n self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) - self._workspace = {} @property def _K_arr(self): - if "_K_arr" not in self._workspace: - _C_arr, _K_arr = _compute_C_K_lanczos(self._n) - self._workspace["_K_arr"] = _K_arr - self._workspace["_C_arr"] = _C_arr - return self._workspace["_K_arr"] + return self._K_arrs[self._n] @property def _C_arr(self): - if "_C_arr" not in self._workspace: - _C_arr, _K_arr = _compute_C_K_lanczos(self._n) - self._workspace["_K_arr"] = _K_arr - self._workspace["_C_arr"] = _C_arr - return self._workspace["_C_arr"] + return self._C_arrs[self._n] @property def _du(self): @@ -1367,15 +1358,13 @@ def _du(self): @property def _umax(self): - if "_umax" not in self._workspace: - self._workspace["_umax"] = _find_umax_lanczos( - self._du, - self._n, - self._conserve_dc, - self._C_arr, - self._gsparams.kvalue_accuracy, - ) - return self._workspace["_umax"] + return _find_umax_lanczos( + self._du, + self._n, + self._conserve_dc, + self._C_arr, + self._gsparams.kvalue_accuracy, + ) def tree_flatten(self): """This function flattens the Interpolant into a list of children @@ -1394,8 +1383,7 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flattened representation""" n = aux_data.pop("n") - ret = cls(n, **aux_data) - return ret + return cls(n, **aux_data) def __repr__(self): return "galsim.Lanczos(%r, %r, gsparams=%r)" % ( @@ -1647,3 +1635,11 @@ def _compute_C_K_lanczos(n): _C = _C.at[5].set(-_K[5]) return _C, _K + + +Lanczos._K_arrs = {} +Lanczos._C_arrs = {} +for n in range(1, 31): + _C_arr, _K_arr = _compute_C_K_lanczos(n) + Lanczos._K_arrs[n] = _K_arr + Lanczos._C_arrs[n] = _C_arr diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 7d62a5f6..a957ca6c 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -395,7 +395,6 @@ def __init__( # this class does a ton of munging of the inputs that I don't want to reconstruct when # flattening and unflattening the class. # thus I am going to make some refs here so we have it when we need it - self._cached_comps = {} self._jax_children = ( image, dict( @@ -405,7 +404,6 @@ def __init__( pad_image=pad_image, offset=offset, ), - self._cached_comps, ) self._jax_aux_data = dict( x_interpolant=x_interpolant, @@ -576,11 +574,7 @@ def _maxk(self): ) return self._jax_aux_data["_force_maxk"] * minor else: - if "_maxk" not in self._cached_comps: - self._cached_comps["_maxk"] = self._getMaxK( - self._jax_aux_data["calculate_maxk"] - ) - return self._cached_comps["_maxk"] + return self._getMaxK(self._jax_aux_data["calculate_maxk"]) @property def _stepk(self): @@ -590,11 +584,7 @@ def _stepk(self): ) return self._jax_aux_data["_force_stepk"] * minor else: - if "_stepk" not in self._cached_comps: - self._cached_comps["_stepk"] = self._getStepK( - self._jax_aux_data["calculate_stepk"] - ) - return self._cached_comps["_stepk"] + return self._getStepK(self._jax_aux_data["calculate_stepk"]) @doc_inherit def withGSParams(self, gsparams=None, **kwargs): @@ -621,37 +611,33 @@ def tree_unflatten(cls, aux_data, children): val.update(aux_data) val.update(children[1]) ret = cls(children[0], **val) - ret._cached_comps.update(children[2]) return ret @property def _xim(self): - if "_xim" not in self._cached_comps: - pad_factor = self._jax_aux_data["pad_factor"] - - # The size of the final padded image is the largest of the various size specifications - pad_size = max(self._image.array.shape) - if pad_factor > 1.0: - pad_size = int(math.ceil(pad_factor * pad_size)) - # And round up to a good fft size - pad_size = Image.good_fft_size(pad_size) - - xim = Image( - _zeropad_image( - self._image.array, (pad_size - max(self._image.array.shape)) // 2 - ), - wcs=PixelScale(1.0), - ) - xim.setCenter(0, 0) - xim.wcs = PixelScale(1.0) - - # Now place the given image in the center of the padding image: - # assert self._xim.bounds.includes(self._image.bounds) - xim[self._image.bounds] = self._image + pad_factor = self._jax_aux_data["pad_factor"] + + # The size of the final padded image is the largest of the various size specifications + pad_size = max(self._image.array.shape) + if pad_factor > 1.0: + pad_size = int(math.ceil(pad_factor * pad_size)) + # And round up to a good fft size + pad_size = Image.good_fft_size(pad_size) + + xim = Image( + _zeropad_image( + self._image.array, (pad_size - max(self._image.array.shape)) // 2 + ), + wcs=PixelScale(1.0), + ) + xim.setCenter(0, 0) + xim.wcs = PixelScale(1.0) - self._cached_comps["_xim"] = xim + # Now place the given image in the center of the padding image: + # assert self._xim.bounds.includes(self._image.bounds) + xim[self._image.bounds] = self._image - return self._cached_comps["_xim"] + return xim @property def _pad_image(self): @@ -663,12 +649,7 @@ def _pad_image(self): @property def _kim(self): - if "_kim" in self._cached_comps: - return self._cached_comps["_kim"] - else: - kim = self._xim.calculate_fft() - self._cached_comps["_kim"] = kim - return kim + return self._xim.calculate_fft() @property def _pos_neg_fluxes(self): diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index f910d594..08b13126 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -207,3 +207,84 @@ def test_metacal_comp_to_galsim(): np.testing.assert_allclose( gim, jgim, err_msg=f"Failed for {k}", rtol=0, atol=atol ) + + +def test_metacal_vmap(): + start_seed = 42 + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nse = 1e-3 + g1 = 0.01 + target_fwhm = 1.0 + + ims = [] + nse_ims = [] + psfs = [] + for _seed in range(1000): + seed = _seed + start_seed + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + psf = ( + _galsim.Gaussian(fwhm=fwhm) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + nse_im = rng.normal(size=im.shape) * nse + im += rng.normal(size=im.shape) * nse + + ims.append(im) + psfs.append(psf) + nse_ims.append(nse_im) + + ims = np.stack(ims) + psfs = np.stack(psfs) + nse_ims = np.stack(nse_ims) + + gt0 = time.time() + for im, psf, nse_im in zip(ims, psfs, nse_ims): + _metacal_galsim( + _galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 + ) + gt0 = time.time() - gt0 + print("Galsim time: ", gt0 * 1e3, " [ms]") + + jit_mcal = jax.jit( + jax.vmap( + _metacal_jax_galsim, + in_axes=(0, 0, 0, None, None, None, None), + ), + static_argnums=6, + ) + + for _ in range(2): + jgt0 = time.time() + jit_mcal( + ims, + psfs, + nse_ims, + scale, + target_fwhm, + g1, + 128, + ) + jgt0 = time.time() - jgt0 + print("Jax-Galsim time: ", jgt0 * 1e3, " [ms]") From 20bc7063dfa78b6f41a9b214358c9afdebdc813c Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 25 Oct 2023 05:53:26 -0500 Subject: [PATCH 48/67] TST add metacal tests --- tests/jax/test_metacal.py | 188 +++++++++++++++----------------------- 1 file changed, 76 insertions(+), 112 deletions(-) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 08b13126..258e4c89 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -1,4 +1,5 @@ import time +from functools import partial import galsim as _galsim import jax @@ -8,29 +9,22 @@ import jax_galsim -def _metacal_galsim(gs, im, psf, nse_im, scale, target_fwhm, g1): - res = {} - iim = gs.InterpolatedImage(gs.ImageD(im), scale=scale, x_interpolant="lanczos15") - ipsf = gs.InterpolatedImage(gs.ImageD(psf), scale=scale, x_interpolant="lanczos15") - inse = gs.InterpolatedImage( - gs.ImageD(np.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" +def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1): + iim = _galsim.InterpolatedImage( + _galsim.ImageD(im), scale=scale, x_interpolant="lanczos15" + ) + ipsf = _galsim.InterpolatedImage( + _galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + inse = _galsim.InterpolatedImage( + _galsim.ImageD(np.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" ) - res["maxk"] = (iim.maxk, ipsf.maxk, inse.maxk) - res["stepk"] = (iim.stepk, ipsf.stepk, inse.stepk) - - res["iim"] = iim.drawImage( - nx=33, ny=33, scale=scale, method="no_pixel" - ).array.astype(np.float64) - res["ipsf"] = ipsf.drawImage( - nx=33, ny=33, scale=scale, method="no_pixel" - ).array.astype(np.float64) - - ppsf_iim = gs.Convolve(iim, gs.Deconvolve(ipsf)) + ppsf_iim = _galsim.Convolve(iim, _galsim.Deconvolve(ipsf)) ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) sim = ( - gs.Convolve(ppsf_iim, gs.Gaussian(fwhm=target_fwhm)) + _galsim.Convolve(ppsf_iim, _galsim.Gaussian(fwhm=target_fwhm)) .drawImage( nx=33, ny=33, @@ -39,12 +33,11 @@ def _metacal_galsim(gs, im, psf, nse_im, scale, target_fwhm, g1): ) .array.astype(np.float64) ) - res["im"] = sim - ppsf_inse = gs.Convolve(inse, gs.Deconvolve(ipsf)) - ppsf_inse = ppsf_iim.shear(g1=g1, g2=0.0) + ppsf_inse = _galsim.Convolve(inse, _galsim.Deconvolve(ipsf)) + ppsf_inse = ppsf_inse.shear(g1=g1, g2=0.0) snse = ( - gs.Convolve(ppsf_inse, gs.Gaussian(fwhm=target_fwhm)) + _galsim.Convolve(ppsf_inse, _galsim.Gaussian(fwhm=target_fwhm)) .drawImage( nx=33, ny=33, @@ -53,39 +46,19 @@ def _metacal_galsim(gs, im, psf, nse_im, scale, target_fwhm, g1): ) .array.astype(np.float64) ) - res["nse"] = snse - res["tot"] = sim + np.rot90(snse, 3) - return res - - -def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): - gs = jax_galsim - iim = gs.InterpolatedImage(gs.ImageD(im), scale=scale, x_interpolant="lanczos15") - ipsf = gs.InterpolatedImage(gs.ImageD(psf), scale=scale, x_interpolant="lanczos15") - inse = gs.InterpolatedImage( - gs.ImageD(jnp.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" - ) + return sim + np.rot90(snse, 3) - res = {} - res["maxk"] = (iim.maxk, ipsf.maxk, inse.maxk) - res["stepk"] = (iim.stepk, ipsf.stepk, inse.stepk) +@partial(jax.jit, static_argnames=("nk",)) +def _metacal_jax_galsim_render(im, psf, g1, target_psf, scale, nk): + prepsf_im = jax_galsim.Convolve(im, jax_galsim.Deconvolve(psf)) + prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) - res["iim"] = iim.drawImage( - nx=33, ny=33, scale=scale, method="no_pixel" - ).array.astype(np.float64) - res["ipsf"] = ipsf.drawImage( - nx=33, ny=33, scale=scale, method="no_pixel" - ).array.astype(np.float64) - - ppsf_iim = gs.Convolve(iim, gs.Deconvolve(ipsf)) - ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) - - sim = ( - gs.Convolve( - ppsf_iim, - gs.Gaussian(fwhm=target_fwhm), - gsparams=gs.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + return ( + jax_galsim.Convolve( + prepsf_im, + target_psf, + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), ) .drawImage( nx=33, @@ -95,27 +68,25 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): ) .array.astype(np.float64) ) - res["im"] = sim - ppsf_inse = gs.Convolve(inse, gs.Deconvolve(ipsf)) - ppsf_inse = ppsf_iim.shear(g1=g1, g2=0.0) - snse = ( - gs.Convolve( - ppsf_inse, - gs.Gaussian(fwhm=target_fwhm), - gsparams=gs.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array.astype(np.float64) + +def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), scale=scale, x_interpolant="lanczos15" + ) + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" ) - res["nse"] = snse - res["tot"] = sim + jnp.rot90(snse, 3) - return res + inse = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(jnp.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" + ) + target_psf = jax_galsim.Gaussian(fwhm=target_fwhm) + + sim = _metacal_jax_galsim_render(iim, ipsf, g1, target_psf, scale, nk) + + snse = _metacal_jax_galsim_render(inse, ipsf, g1, target_psf, scale, nk) + + return sim + jnp.rot90(snse, 3) def test_metacal_comp_to_galsim(): @@ -156,16 +127,18 @@ def test_metacal_comp_to_galsim(): im += rng.normal(size=im.shape) * nse gt0 = time.time() - gres = _metacal_galsim( - _galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 - ) + gres = _metacal_galsim(im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1) gt0 = time.time() - gt0 - jit_mcal = jax.jit(_metacal_jax_galsim, static_argnums=6) + print("Galsim time: ", gt0 * 1e3, " [ms]") - for _ in range(2): + for i in range(2): + if i == 0: + msg = "jit warmup" + elif i == 1: + msg = "jit" jgt0 = time.time() - jgres = jit_mcal( + jgres = _metacal_jax_galsim( im, psf, nse_im, @@ -174,39 +147,27 @@ def test_metacal_comp_to_galsim(): g1, 128, ) + jgres = jax.block_until_ready(jgres) jgt0 = time.time() - jgt0 - print("Jax-Galsim time: ", jgt0 * 1e3, " [ms]") + print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") - print("Galsim time: ", gt0 * 1e3, " [ms]") + gim = gres + jgim = jgres - np.testing.assert_allclose(gres["maxk"], jgres["maxk"], rtol=2e-2) - np.testing.assert_allclose(gres["stepk"], jgres["stepk"], rtol=2e-2) - - for k in ["iim", "ipsf", "im", "nse", "tot"]: - gim = gres[k] - jgim = jgres[k] - - if k in ["iim", "ipsf"]: - atol = 1e-7 - else: - atol = 5e-5 - - if not np.allclose(gim, jgim, rtol=0, atol=atol): - import proplot as pplt - - fig, axs = pplt.subplots(ncols=3, nrows=5, figsize=(4.5, 7.5)) - print(axs.shape) - for row, _k in enumerate(["iim", "ipsf", "im", "nse", "tot"]): - _gim = gres[_k] - _jgim = jgres[_k] - axs[row, 0].imshow(np.arcsinh(_gim / nse)) - axs[row, 1].imshow(np.arcsinh(_jgim / nse)) - axs[row, 2].imshow(_jgim - _gim) - fig.show() - - np.testing.assert_allclose( - gim, jgim, err_msg=f"Failed for {k}", rtol=0, atol=atol - ) + atol = 7e-5 + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + import proplot as pplt + + fig, axs = pplt.subplots(ncols=3, nrows=1, figsize=(4.5, 7.5)) + _gim = gres + _jgim = jgres + axs[0].imshow(np.arcsinh(_gim / nse)) + axs[1].imshow(np.arcsinh(_jgim / nse)) + axs[2].imshow(_jgim - _gim) + fig.show() + + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) def test_metacal_vmap(): @@ -261,9 +222,7 @@ def test_metacal_vmap(): gt0 = time.time() for im, psf, nse_im in zip(ims, psfs, nse_ims): - _metacal_galsim( - _galsim, im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1 - ) + _metacal_galsim(im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1) gt0 = time.time() - gt0 print("Galsim time: ", gt0 * 1e3, " [ms]") @@ -275,7 +234,12 @@ def test_metacal_vmap(): static_argnums=6, ) - for _ in range(2): + for i in range(2): + if i == 0: + msg = "jit warmup" + elif i == 1: + msg = "jit" + jgt0 = time.time() jit_mcal( ims, @@ -287,4 +251,4 @@ def test_metacal_vmap(): 128, ) jgt0 = time.time() - jgt0 - print("Jax-Galsim time: ", jgt0 * 1e3, " [ms]") + print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") From 5037a50986bfaa837be7f289f3c587241bf90289 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 25 Oct 2023 07:10:17 -0500 Subject: [PATCH 49/67] REF use vectorized drawing directly without vmap --- jax_galsim/interpolatedimage.py | 149 +++++++++++++++++----- tests/jax/test_interpolatedimage_utils.py | 2 +- 2 files changed, 118 insertions(+), 33 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index a957ca6c..d6ef34c2 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -18,7 +18,6 @@ from jax_galsim import fits from jax_galsim.bounds import BoundsI -from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams @@ -761,59 +760,145 @@ def _max_sb(self): def _flux_per_photon(self): return self._calculate_flux_per_photon() - def _xValue(self, pos): - pos += self._offset + @jax.jit + def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): vals = _draw_with_interpolant_xval( - jnp.array([pos.x], dtype=float), - jnp.array([pos.y], dtype=float), + x + x_offset, + y + y_offset, + xmin, + ymin, + arr, + x_interpolant, + ) + return vals + + def _xValue(self, pos): + x = jnp.array([pos.x], dtype=float) + y = jnp.array([pos.y], dtype=float) + return _InterpolatedImageImpl._xValue_arr( + x, + y, + self._offset.x, + self._offset.y, self._pad_image.bounds.xmin, self._pad_image.bounds.ymin, self._pad_image.array, self._x_interpolant, - ) - return vals[0] + )[0] - def _kValue(self, kpos): + @jax.jit + def _kValue_arr( + kx, + ky, + x_offset, + y_offset, + kxmin, + kymin, + arr, + scale, + x_interpolant, + k_interpolant, + ): # phase factor due to offset - # not we shift by -offset which explains the signs - # in pkx, pky - pkx = kpos.x * 1j * self._offset.x - pky = kpos.y * 1j * self._offset.y - pkx += pky - pfac = jnp.exp(pkx) + # not we shift by -offset which explains the sign + # in the exponent + pfac = jnp.exp(1j * (kx * x_offset + ky * y_offset)) - _kim = self._kim - kx = jnp.array([kpos.x / _kim.scale], dtype=float) - ky = jnp.array([kpos.y / _kim.scale], dtype=float) + kxi = kx / scale + kyi = ky / scale _uscale = 1.0 / (2.0 * jnp.pi) - _maxk_xint = self._x_interpolant.urange() / _uscale / _kim.scale + _maxk_xint = x_interpolant.urange() / _uscale / scale val = _draw_with_interpolant_kval( + kxi, + kyi, + kymin, # this is not a bug! we need the minimum for the full periodic space + kymin, + arr, + k_interpolant, + ) + + msk = (jnp.abs(kxi) <= _maxk_xint) & (jnp.abs(kyi) <= _maxk_xint) + xint_val = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky) + return jnp.where(msk, val * xint_val * pfac, 0.0) + + def _kValue(self, kpos): + kx = jnp.array([kpos.x], dtype=float) + ky = jnp.array([kpos.y], dtype=float) + return _InterpolatedImageImpl._kValue_arr( kx, ky, - _kim.bounds.ymin, - _kim.bounds.ymin, - _kim.array, + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, self._k_interpolant, - ) - - msk = (jnp.abs(kx) <= _maxk_xint) & (jnp.abs(ky) <= _maxk_xint) - xint_val = self._x_interpolant._kval_noraise( - kx * _kim.scale - ) * self._x_interpolant._kval_noraise(ky * _kim.scale) - return jnp.where(msk, val * xint_val * pfac, 0.0)[0] + )[0] def _shoot(self, photons, rng): raise NotImplementedError("Photon shooting not implemented.") def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): - _jac = jnp.eye(2) if jac is None else jac - return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling) + jacobian = jnp.eye(2) if jac is None else jac + + flux_scaling *= image.scale**2 + + # Create an array of coordinates + coords = jnp.stack(image.get_pixel_centers(), axis=-1) + coords = coords * image.scale # Scale by the image pixel scale + coords = coords - jnp.asarray(offset) # Add the offset + + # Apply the jacobian transformation + inv_jacobian = jnp.linalg.inv(jacobian) + _, logdet = jnp.linalg.slogdet(inv_jacobian) + coords = jnp.dot(coords, inv_jacobian.T) + flux_scaling *= jnp.exp(logdet) + + im = _InterpolatedImageImpl._xValue_arr( + coords[..., 0], + coords[..., 1], + self._offset.x, + self._offset.y, + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + ) + + # Apply the flux scaling + im = (im * flux_scaling).astype(image.dtype) + + # Return an image + return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) def _drawKImage(self, image, jac=None): - _jac = jnp.eye(2) if jac is None else jac - return draw_by_kValue(self, image, _jac) + jacobian = jnp.eye(2) if jac is None else jac + + # Create an array of coordinates + coords = jnp.stack(image.get_pixel_centers(), axis=-1) + coords = coords * image.scale # Scale by the image pixel scale + coords = jnp.dot(coords, jacobian) + + im = _InterpolatedImageImpl._kValue_arr( + coords[..., 0], + coords[..., 1], + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, + self._k_interpolant, + ) + im = (im).astype(image.dtype) + + # Return an image + return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) @_wraps(_galsim._InterpolatedImage) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 30473a1e..a25dd8b1 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -194,7 +194,7 @@ def test_interpolatedimage_utils_stepk_maxk(): ), ], ) -@pytest.mark.parametrize("method", ["xValue", "kValue"]) +@pytest.mark.parametrize("method", ["kValue", "xValue"]) def test_interpolatedimage_utils_comp_to_galsim( method, ref_array, offset_x, offset_y, wcs, use_true_center, normalization ): From 8e0669f4f7deeee0e5c3609dd74c6ff30967ded3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 26 Oct 2023 05:15:09 -0500 Subject: [PATCH 50/67] REF use lazy property --- jax_galsim/interpolant.py | 21 +++++--------- jax_galsim/interpolatedimage.py | 51 +++++++++++++++++---------------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 2682c2df..2b5ab344 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp from galsim.errors import GalSimValueError +from galsim.utilities import lazy_property from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -1340,15 +1341,15 @@ def __init__( self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) - @property + @lazy_property def _K_arr(self): - return self._K_arrs[self._n] + return _compute_C_K_lanczos(self._n)[1] - @property + @lazy_property def _C_arr(self): - return self._C_arrs[self._n] + return _compute_C_K_lanczos(self._n)[0] - @property + @lazy_property def _du(self): return ( self._gsparams.table_spacing @@ -1356,7 +1357,7 @@ def _du(self): / self._n ) - @property + @lazy_property def _umax(self): return _find_umax_lanczos( self._du, @@ -1635,11 +1636,3 @@ def _compute_C_K_lanczos(n): _C = _C.at[5].set(-_K[5]) return _C, _K - - -Lanczos._K_arrs = {} -Lanczos._C_arrs = {} -for n in range(1, 31): - _C_arr, _K_arr = _compute_C_K_lanczos(n) - Lanczos._K_arrs[n] = _K_arr - Lanczos._C_arrs[n] = _C_arr diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index d6ef34c2..9c815662 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -12,7 +12,7 @@ GalSimUndefinedBoundsError, GalSimValueError, ) -from galsim.utilities import doc_inherit +from galsim.utilities import doc_inherit, lazy_property from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -55,12 +55,9 @@ def __dir__(cls): """The JAX equivalent of galsim.InterpolatedImage does not support - noise padding + - the pad_image options - depixelize - most of the type checks and dtype casts done by galsim - - Further, it always computes the FFT of the image as opposed to galsim - where this is done as needed. One almost always needs the FFT and JAX - generally works best with pure functions that do not modify state. """ ), ) @@ -492,7 +489,7 @@ def __init__( image=self._jax_children[0], ) - @property + @lazy_property def _flux_ratio(self): if self._jax_children[1]["flux"] is None: flux = self._image_flux @@ -508,11 +505,11 @@ def _flux_ratio(self): # this class return flux / self._image_flux - @property + @lazy_property def _image_flux(self): return jnp.sum(self._image.array, dtype=float) - @property + @lazy_property def _offset(self): # Figure out the offset to apply based on the original image (not the padded one). # We will apply this below in _sbp. @@ -521,7 +518,7 @@ def _offset(self): self._image.bounds, offset, None, self._jax_aux_data["use_true_center"] ) - @property + @lazy_property def _image(self): # Store the image as an attribute and make sure we don't change the original image # in anything we do here. (e.g. set scale, etc.) @@ -539,7 +536,7 @@ def _image(self): return image - @property + @lazy_property def _wcs(self): im_cen = ( self._jax_children[0].true_center @@ -557,7 +554,7 @@ def _wcs(self): return wcs.local(image_pos=im_cen) - @property + @lazy_property def _jac_arr(self): image = self._jax_children[0] im_cen = ( @@ -565,7 +562,7 @@ def _jac_arr(self): ) return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel() - @property + @lazy_property def _maxk(self): if self._jax_aux_data["_force_maxk"]: major, minor = compute_major_minor_from_jacobian( @@ -575,7 +572,7 @@ def _maxk(self): else: return self._getMaxK(self._jax_aux_data["calculate_maxk"]) - @property + @lazy_property def _stepk(self): if self._jax_aux_data["_force_stepk"]: major, minor = compute_major_minor_from_jacobian( @@ -612,7 +609,7 @@ def tree_unflatten(cls, aux_data, children): ret = cls(children[0], **val) return ret - @property + @lazy_property def _xim(self): pad_factor = self._jax_aux_data["pad_factor"] @@ -638,7 +635,7 @@ def _xim(self): return xim - @property + @lazy_property def _pad_image(self): # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. @@ -646,11 +643,11 @@ def _pad_image(self): nz_bounds = self._image.bounds return xim[nz_bounds] - @property + @lazy_property def _kim(self): return self._xim.calculate_fft() - @property + @lazy_property def _pos_neg_fluxes(self): # record pos and neg fluxes now too pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) @@ -731,7 +728,9 @@ def k_interpolant(self): @property def image(self): """The underlying `Image` being interpolated.""" - return self._xim[self._image.bounds] + if not hasattr(self, "_image_val"): + self._image_val = self._xim[self._image.bounds] + return self._image_val @property def _flux(self): @@ -739,11 +738,13 @@ def _flux(self): @property def _centroid(self): - x, y = self._pad_image.get_pixel_centers() - tot = jnp.sum(self._pad_image.array) - xpos = jnp.sum(x * self._pad_image.array) / tot - ypos = jnp.sum(y * self._pad_image.array) / tot - return PositionD(xpos, ypos) + if not hasattr(self, "_centroid_val"): + x, y = self._pad_image.get_pixel_centers() + tot = jnp.sum(self._pad_image.array) + xpos = jnp.sum(x * self._pad_image.array) / tot + ypos = jnp.sum(y * self._pad_image.array) / tot + self._centroid_val = PositionD(xpos, ypos) + return self._centroid_val @property def _positive_flux(self): @@ -755,7 +756,9 @@ def _negative_flux(self): @property def _max_sb(self): - return jnp.max(jnp.abs(self._pad_image.array)) + if not hasattr(self, "_max_sb_val"): + self._max_sb_val = jnp.max(jnp.abs(self._pad_image.array)) + return self._max_sb_val def _flux_per_photon(self): return self._calculate_flux_per_photon() From b9c9f08ed7be4565b0d755c6463c3f44b5fa4c2c Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 26 Oct 2023 05:52:22 -0500 Subject: [PATCH 51/67] TST add extra jit for testing --- tests/jax/test_metacal.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 258e4c89..304de2a4 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -132,13 +132,15 @@ def test_metacal_comp_to_galsim(): print("Galsim time: ", gt0 * 1e3, " [ms]") + _func = jax.jit(_metacal_jax_galsim, static_argnames=("nk",)) + for i in range(2): if i == 0: msg = "jit warmup" elif i == 1: msg = "jit" jgt0 = time.time() - jgres = _metacal_jax_galsim( + jgres = _func( im, psf, nse_im, From 75c332b142dbc4b84021069795871ac61f29486a Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 26 Oct 2023 06:17:22 -0500 Subject: [PATCH 52/67] REF do not use lazy property here --- jax_galsim/interpolatedimage.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 9c815662..65035876 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -12,7 +12,7 @@ GalSimUndefinedBoundsError, GalSimValueError, ) -from galsim.utilities import doc_inherit, lazy_property +from galsim.utilities import doc_inherit from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -489,7 +489,7 @@ def __init__( image=self._jax_children[0], ) - @lazy_property + @property def _flux_ratio(self): if self._jax_children[1]["flux"] is None: flux = self._image_flux @@ -505,11 +505,11 @@ def _flux_ratio(self): # this class return flux / self._image_flux - @lazy_property + @property def _image_flux(self): return jnp.sum(self._image.array, dtype=float) - @lazy_property + @property def _offset(self): # Figure out the offset to apply based on the original image (not the padded one). # We will apply this below in _sbp. @@ -518,7 +518,7 @@ def _offset(self): self._image.bounds, offset, None, self._jax_aux_data["use_true_center"] ) - @lazy_property + @property def _image(self): # Store the image as an attribute and make sure we don't change the original image # in anything we do here. (e.g. set scale, etc.) @@ -536,7 +536,7 @@ def _image(self): return image - @lazy_property + @property def _wcs(self): im_cen = ( self._jax_children[0].true_center @@ -554,7 +554,7 @@ def _wcs(self): return wcs.local(image_pos=im_cen) - @lazy_property + @property def _jac_arr(self): image = self._jax_children[0] im_cen = ( @@ -562,7 +562,7 @@ def _jac_arr(self): ) return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel() - @lazy_property + @property def _maxk(self): if self._jax_aux_data["_force_maxk"]: major, minor = compute_major_minor_from_jacobian( @@ -572,7 +572,7 @@ def _maxk(self): else: return self._getMaxK(self._jax_aux_data["calculate_maxk"]) - @lazy_property + @property def _stepk(self): if self._jax_aux_data["_force_stepk"]: major, minor = compute_major_minor_from_jacobian( @@ -609,7 +609,7 @@ def tree_unflatten(cls, aux_data, children): ret = cls(children[0], **val) return ret - @lazy_property + @property def _xim(self): pad_factor = self._jax_aux_data["pad_factor"] @@ -635,7 +635,7 @@ def _xim(self): return xim - @lazy_property + @property def _pad_image(self): # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. @@ -643,11 +643,11 @@ def _pad_image(self): nz_bounds = self._image.bounds return xim[nz_bounds] - @lazy_property + @property def _kim(self): return self._xim.calculate_fft() - @lazy_property + @property def _pos_neg_fluxes(self): # record pos and neg fluxes now too pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) From 69be6570683f4c09eaf2be5684882e186d96e4d0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 26 Oct 2023 06:26:18 -0500 Subject: [PATCH 53/67] REF do not cache items which have gradients --- jax_galsim/interpolatedimage.py | 20 +++++++------------- jax_galsim/transform.py | 12 ++++-------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 65035876..0e8b9a79 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -728,9 +728,7 @@ def k_interpolant(self): @property def image(self): """The underlying `Image` being interpolated.""" - if not hasattr(self, "_image_val"): - self._image_val = self._xim[self._image.bounds] - return self._image_val + return self._xim[self._image.bounds] @property def _flux(self): @@ -738,13 +736,11 @@ def _flux(self): @property def _centroid(self): - if not hasattr(self, "_centroid_val"): - x, y = self._pad_image.get_pixel_centers() - tot = jnp.sum(self._pad_image.array) - xpos = jnp.sum(x * self._pad_image.array) / tot - ypos = jnp.sum(y * self._pad_image.array) / tot - self._centroid_val = PositionD(xpos, ypos) - return self._centroid_val + x, y = self._pad_image.get_pixel_centers() + tot = jnp.sum(self._pad_image.array) + xpos = jnp.sum(x * self._pad_image.array) / tot + ypos = jnp.sum(y * self._pad_image.array) / tot + return PositionD(xpos, ypos) @property def _positive_flux(self): @@ -756,9 +752,7 @@ def _negative_flux(self): @property def _max_sb(self): - if not hasattr(self, "_max_sb_val"): - self._max_sb_val = jnp.max(jnp.abs(self._pad_image.array)) - return self._max_sb_val + return jnp.max(jnp.abs(self._pad_image.array)) def _flux_per_photon(self): return self._calculate_flux_per_photon() diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 598ec7bb..aa34eec5 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -273,19 +273,15 @@ def _kfactor(self, kx, ky): kx += ky return self._flux_scaling * jnp.exp(kx) - def _major_minor(self): - if not hasattr(self, "_major"): - self._major, self._minor = compute_major_minor_from_jacobian(self._jac) - @property def _maxk(self): - self._major_minor() - return self._original.maxk / self._minor + _, minor = compute_major_minor_from_jacobian(self._jac) + return self._original.maxk / minor @property def _stepk(self): - self._major_minor() - stepk = self._original.stepk / self._major + major, _ = compute_major_minor_from_jacobian(self._jac) + stepk = self._original.stepk / major # If we have a shift, we need to further modify stepk # stepk = Pi/R # R <- R + |shift| From d38086d877a8b8542b092220c5f28f4ee9d26dfc Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 27 Oct 2023 06:20:46 -0500 Subject: [PATCH 54/67] ENH add lazy property decorator with explicit workspace --- jax_galsim/interpolant.py | 15 +++- jax_galsim/interpolatedimage.py | 14 +++- jax_galsim/utilities.py | 23 ++++++ tests/jax/test_interpolatedimage_utils.py | 47 ++++++------ tests/jax/test_metacal.py | 91 ++++++++++++++++++++--- 5 files changed, 152 insertions(+), 38 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 2b5ab344..ffba7613 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -8,13 +8,13 @@ import jax import jax.numpy as jnp from galsim.errors import GalSimValueError -from galsim.utilities import lazy_property from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams +from jax_galsim.utilities import lazy_property @_wraps(_galsim.interpolant.Interpolant) @@ -1340,6 +1340,7 @@ def __init__( self._n = n self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) + self._workspace = {} @lazy_property def _K_arr(self): @@ -1384,7 +1385,8 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flattened representation""" n = aux_data.pop("n") - return cls(n, **aux_data) + ret = cls(n, **aux_data) + return ret def __repr__(self): return "galsim.Lanczos(%r, %r, gsparams=%r)" % ( @@ -1396,6 +1398,15 @@ def __repr__(self): def __str__(self): return "galsim.Lanczos(%s)" % (self._n) + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_workspace") + return d + + def __setstate__(self, d): + self.__dict__ = d + self._workspace = {} + # this is a pure function and we apply JIT ahead of time since this # one is pretty slow @jax.jit diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0e8b9a79..4e1049af 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -25,7 +25,7 @@ from jax_galsim.interpolant import Quintic from jax_galsim.position import PositionD from jax_galsim.transform import Transformation -from jax_galsim.utilities import convert_interpolant +from jax_galsim.utilities import convert_interpolant, lazy_property from jax_galsim.wcs import BaseWCS, PixelScale # These keys are removed from the public API of @@ -391,6 +391,7 @@ def __init__( # this class does a ton of munging of the inputs that I don't want to reconstruct when # flattening and unflattening the class. # thus I am going to make some refs here so we have it when we need it + self._workspace = {} self._jax_children = ( image, dict( @@ -609,6 +610,15 @@ def tree_unflatten(cls, aux_data, children): ret = cls(children[0], **val) return ret + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_workspace") + return d + + def __setstate__(self, d): + self.__dict__ = d + self._workspace = {} + @property def _xim(self): pad_factor = self._jax_aux_data["pad_factor"] @@ -643,7 +653,7 @@ def _pad_image(self): nz_bounds = self._image.bounds return xim[nz_bounds] - @property + @lazy_property def _kim(self): return self._xim.calculate_fft() diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index d21886ef..3ef35e11 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -1,3 +1,5 @@ +import functools + import galsim as _galsim import jax.numpy as jnp from jax._src.numpy.util import _wraps @@ -7,6 +9,27 @@ printoptions = _galsim.utilities.printoptions +@_wraps( + _galsim.utilities.lazy_property, + lax_description=( + "The LAX version of this decorator uses an `_workspace` attribute " + "attached to the object so that the cache can easily be discarded " + "for certain operations." + ), +) +def lazy_property(func): + attname = func.__name__ + "_cached" + + @property + @functools.wraps(func) + def _func(self): + if attname not in self._workspace: + self._workspace[attname] = func(self) + return self._workspace[attname] + + return _func + + @_wraps(_galsim.utilities.parse_pos_args) def parse_pos_args(args, kwargs, name1, name2, integer=False, others=[]): def canindex(arg): diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index a25dd8b1..527b0711 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -140,14 +140,15 @@ def test_interpolatedimage_utils_stepk_maxk(): np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=rtol, atol=0) +@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) @pytest.mark.parametrize("normalization", ["sb", "flux"]) @pytest.mark.parametrize("use_true_center", [True, False]) @pytest.mark.parametrize( "wcs", [ - _galsim.PixelScale(2.0), - _galsim.JacobianWCS(2.1, 0.3, -0.4, 2.3), - _galsim.AffineTransform(-0.3, 2.1, 1.8, 0.1, _galsim.PositionD(0.3, -0.4)), + _galsim.PixelScale(0.2), + _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), + _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), ], ) @pytest.mark.parametrize( @@ -173,30 +174,20 @@ def test_interpolatedimage_utils_stepk_maxk(): @pytest.mark.parametrize( "ref_array", [ - np.array( - [ - [0.01, 0.08, 0.07, 0.02, 0.0, 0.0], - [0.13, 0.38, 0.52, 0.06, 0.0, 0.05], - [0.09, 0.41, 0.44, 0.09, 0.0, 0.2], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.5], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.3], - [0.04, 0.11, 0.10, 0.01, 0.0, 0.1], - ] - ), - np.array( - [ - [0.01, 0.08, 0.07, 0.02, 0.0], - [0.13, 0.38, 0.52, 0.06, 0.0], - [0.09, 0.41, 0.44, 0.09, 0.0], - [0.04, 0.11, 0.10, 0.01, 0.0], - [0.04, 0.11, 0.10, 0.01, 0.0], - ] - ), + _galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array, + _galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array, ], ) @pytest.mark.parametrize("method", ["kValue", "xValue"]) def test_interpolatedimage_utils_comp_to_galsim( - method, ref_array, offset_x, offset_y, wcs, use_true_center, normalization + method, + ref_array, + offset_x, + offset_y, + wcs, + use_true_center, + normalization, + x_interp, ): gimage_in = _galsim.Image(ref_array, scale=1) jgimage_in = jax_galsim.Image(ref_array, scale=1) @@ -207,6 +198,7 @@ def test_interpolatedimage_utils_comp_to_galsim( offset=_galsim.PositionD(offset_x, offset_y), use_true_center=use_true_center, normalization=normalization, + x_interpolant=x_interp, ) jgii = jax_galsim.InterpolatedImage( jgimage_in, @@ -214,8 +206,11 @@ def test_interpolatedimage_utils_comp_to_galsim( offset=jax_galsim.PositionD(offset_x, offset_y), use_true_center=use_true_center, normalization=normalization, + x_interpolant=x_interp, ) + rng = np.random.RandomState(seed=42) + np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) kxvals = [ @@ -234,14 +229,16 @@ def test_interpolatedimage_utils_comp_to_galsim( ] for x, y in kxvals: if method == "kValue": - dk = jgii._original._kim.scale + dk = jgii._original._kim.scale * rng.uniform(low=0.5, high=1.5) np.testing.assert_allclose( gii.kValue(x * dk, y * dk), jgii.kValue(x * dk, y * dk), err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", ) else: - dx = jnp.sqrt(jgii._original._wcs.pixelArea()) + dx = jnp.sqrt(jgii._original._wcs.pixelArea()) * rng.uniform( + low=0.5, high=1.5 + ) np.testing.assert_allclose( gii.xValue(x * dx, y * dx), jgii.xValue(x * dx, y * dx), diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 304de2a4..35f36735 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -9,7 +9,7 @@ import jax_galsim -def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1): +def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): iim = _galsim.InterpolatedImage( _galsim.ImageD(im), scale=scale, x_interpolant="lanczos15" ) @@ -24,7 +24,11 @@ def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1): ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) sim = ( - _galsim.Convolve(ppsf_iim, _galsim.Gaussian(fwhm=target_fwhm)) + _galsim.Convolve( + ppsf_iim, + _galsim.Gaussian(fwhm=target_fwhm), + gsparams=_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) .drawImage( nx=33, ny=33, @@ -37,7 +41,11 @@ def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1): ppsf_inse = _galsim.Convolve(inse, _galsim.Deconvolve(ipsf)) ppsf_inse = ppsf_inse.shear(g1=g1, g2=0.0) snse = ( - _galsim.Convolve(ppsf_inse, _galsim.Gaussian(fwhm=target_fwhm)) + _galsim.Convolve( + ppsf_inse, + _galsim.Gaussian(fwhm=target_fwhm), + gsparams=_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) .drawImage( nx=33, ny=33, @@ -127,7 +135,9 @@ def test_metacal_comp_to_galsim(): im += rng.normal(size=im.shape) * nse gt0 = time.time() - gres = _metacal_galsim(im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1) + gres = _metacal_galsim( + im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1, 128 + ) gt0 = time.time() - gt0 print("Galsim time: ", gt0 * 1e3, " [ms]") @@ -161,12 +171,73 @@ def test_metacal_comp_to_galsim(): if not np.allclose(gim, jgim, rtol=0, atol=atol): import proplot as pplt - fig, axs = pplt.subplots(ncols=3, nrows=1, figsize=(4.5, 7.5)) + fig, axs = pplt.subplots(ncols=3, nrows=3) _gim = gres _jgim = jgres - axs[0].imshow(np.arcsinh(_gim / nse)) - axs[1].imshow(np.arcsinh(_jgim / nse)) - axs[2].imshow(_jgim - _gim) + + gpsf = ( + _galsim.InterpolatedImage( + _galsim.Image(psf, scale=scale), x_interpolant="lanczos15" + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array + ) + jgpsf = ( + jax_galsim.InterpolatedImage( + jax_galsim.Image(psf, scale=scale), x_interpolant="lanczos15" + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array + ) + + axs[0, 0].imshow(gpsf) + axs[0, 1].imshow(jgpsf) + m = axs[0, 2].imshow((jgpsf - gpsf) / 1e-5) + axs[0, 2].colorbar(m, loc="r") + + gpsf = ( + _galsim.InterpolatedImage( + _galsim.Image(psf, scale=scale), x_interpolant="lanczos15" + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array + ) + jgpsf = ( + jax_galsim.InterpolatedImage( + jax_galsim.Image(psf, scale=scale), x_interpolant="lanczos15" + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array + ) + + axs[1, 0].imshow(gpsf) + axs[1, 1].imshow(jgpsf) + m = axs[1, 2].imshow((jgpsf - gpsf) / 1e-5) + axs[1, 2].colorbar(m, loc="r") + + axs[2, 0].imshow(np.arcsinh(_gim / nse)) + axs[2, 1].imshow(np.arcsinh(_jgim / nse)) + m = axs[2, 2].imshow((_jgim - _gim) / 1e-5) + axs[2, 2].colorbar(m, loc="r") + fig.show() np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) @@ -224,7 +295,9 @@ def test_metacal_vmap(): gt0 = time.time() for im, psf, nse_im in zip(ims, psfs, nse_ims): - _metacal_galsim(im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1) + _metacal_galsim( + im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1, 128 + ) gt0 = time.time() - gt0 print("Galsim time: ", gt0 * 1e3, " [ms]") From 326de3762371d65424d009242fe3b972a332a4a3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 27 Oct 2023 06:21:45 -0500 Subject: [PATCH 55/67] ENH add lazy property decorator with explicit workspace --- jax_galsim/interpolatedimage.py | 212 ++++++++++++++++---------------- jax_galsim/utilities.py | 2 + tests/jax/test_metacal.py | 139 +++++++++++++++++++++ 3 files changed, 247 insertions(+), 106 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 4e1049af..15ed6416 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -490,7 +490,74 @@ def __init__( image=self._jax_children[0], ) + @doc_inherit + def withGSParams(self, gsparams=None, **kwargs): + if gsparams == self.gsparams: + return self + # Checking gsparams + gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + # Flattening the representation to instantiate a clean new object + children, aux_data = self.tree_flatten() + aux_data["gsparams"] = gsparams + ret = self.tree_unflatten(aux_data, children) + + return ret + + def tree_flatten(self): + """This function flattens the InterpolatedImage into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + return (self._jax_children, copy.copy(self._jax_aux_data)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + val = {} + val.update(aux_data) + val.update(children[1]) + ret = cls(children[0], **val) + return ret + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_workspace") + return d + + def __setstate__(self, d): + self.__dict__ = d + self._workspace = {} + + @property + def x_interpolant(self): + """The real-space `Interpolant` for this profile.""" + return self._x_interpolant + @property + def k_interpolant(self): + """The Fourier-space `Interpolant` for this profile.""" + return self._k_interpolant + + @lazy_property + def image(self): + """The underlying `Image` being interpolated.""" + return self._xim[self._image.bounds] + + @property + def _flux(self): + return self._image_flux + + @lazy_property + def _centroid(self): + x, y = self._pad_image.get_pixel_centers() + tot = jnp.sum(self._pad_image.array) + xpos = jnp.sum(x * self._pad_image.array) / tot + ypos = jnp.sum(y * self._pad_image.array) / tot + return PositionD(xpos, ypos) + + @lazy_property + def _max_sb(self): + return jnp.max(jnp.abs(self._pad_image.array)) + + @lazy_property def _flux_ratio(self): if self._jax_children[1]["flux"] is None: flux = self._image_flux @@ -506,11 +573,11 @@ def _flux_ratio(self): # this class return flux / self._image_flux - @property + @lazy_property def _image_flux(self): return jnp.sum(self._image.array, dtype=float) - @property + @lazy_property def _offset(self): # Figure out the offset to apply based on the original image (not the padded one). # We will apply this below in _sbp. @@ -519,7 +586,7 @@ def _offset(self): self._image.bounds, offset, None, self._jax_aux_data["use_true_center"] ) - @property + @lazy_property def _image(self): # Store the image as an attribute and make sure we don't change the original image # in anything we do here. (e.g. set scale, etc.) @@ -537,7 +604,7 @@ def _image(self): return image - @property + @lazy_property def _wcs(self): im_cen = ( self._jax_children[0].true_center @@ -555,7 +622,7 @@ def _wcs(self): return wcs.local(image_pos=im_cen) - @property + @lazy_property def _jac_arr(self): image = self._jax_children[0] im_cen = ( @@ -563,63 +630,7 @@ def _jac_arr(self): ) return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel() - @property - def _maxk(self): - if self._jax_aux_data["_force_maxk"]: - major, minor = compute_major_minor_from_jacobian( - self._jac_arr.reshape((2, 2)) - ) - return self._jax_aux_data["_force_maxk"] * minor - else: - return self._getMaxK(self._jax_aux_data["calculate_maxk"]) - - @property - def _stepk(self): - if self._jax_aux_data["_force_stepk"]: - major, minor = compute_major_minor_from_jacobian( - self._jac_arr.reshape((2, 2)) - ) - return self._jax_aux_data["_force_stepk"] * minor - else: - return self._getStepK(self._jax_aux_data["calculate_stepk"]) - - @doc_inherit - def withGSParams(self, gsparams=None, **kwargs): - if gsparams == self.gsparams: - return self - # Checking gsparams - gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) - # Flattening the representation to instantiate a clean new object - children, aux_data = self.tree_flatten() - aux_data["gsparams"] = gsparams - ret = self.tree_unflatten(aux_data, children) - - return ret - - def tree_flatten(self): - """This function flattens the InterpolatedImage into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - return (self._jax_children, copy.copy(self._jax_aux_data)) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - val = {} - val.update(aux_data) - val.update(children[1]) - ret = cls(children[0], **val) - return ret - - def __getstate__(self): - d = self.__dict__.copy() - d.pop("_workspace") - return d - - def __setstate__(self, d): - self.__dict__ = d - self._workspace = {} - - @property + @lazy_property def _xim(self): pad_factor = self._jax_aux_data["pad_factor"] @@ -645,7 +656,7 @@ def _xim(self): return xim - @property + @lazy_property def _pad_image(self): # These next two allow for easy pickling/repring. We don't need to serialize all the # zeros around the edge. But we do need to keep any non-zero padding as a pad_image. @@ -657,7 +668,7 @@ def _pad_image(self): def _kim(self): return self._xim.calculate_fft() - @property + @lazy_property def _pos_neg_fluxes(self): # record pos and neg fluxes now too pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) @@ -673,6 +684,37 @@ def _pos_neg_fluxes(self): pint2d * nflux + nint2d * pflux, ] + @property + def _positive_flux(self): + return self._pos_neg_fluxes[0] + + @property + def _negative_flux(self): + return self._pos_neg_fluxes[1] + + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + + @lazy_property + def _maxk(self): + if self._jax_aux_data["_force_maxk"]: + major, minor = compute_major_minor_from_jacobian( + self._jac_arr.reshape((2, 2)) + ) + return self._jax_aux_data["_force_maxk"] * minor + else: + return self._getMaxK(self._jax_aux_data["calculate_maxk"]) + + @lazy_property + def _stepk(self): + if self._jax_aux_data["_force_stepk"]: + major, minor = compute_major_minor_from_jacobian( + self._jac_arr.reshape((2, 2)) + ) + return self._jax_aux_data["_force_stepk"] * minor + else: + return self._getStepK(self._jax_aux_data["calculate_stepk"]) + def _getStepK(self, calculate_stepk): # GalSim cannot automatically know what stepK and maxK are appropriate for the # input image. So it is usually worth it to do a manual calculation (below). @@ -725,48 +767,6 @@ def _getMaxK(self, calculate_maxk): else: return self._x_interpolant.krange - @property - def x_interpolant(self): - """The real-space `Interpolant` for this profile.""" - return self._x_interpolant - - @property - def k_interpolant(self): - """The Fourier-space `Interpolant` for this profile.""" - return self._k_interpolant - - @property - def image(self): - """The underlying `Image` being interpolated.""" - return self._xim[self._image.bounds] - - @property - def _flux(self): - return self._image_flux - - @property - def _centroid(self): - x, y = self._pad_image.get_pixel_centers() - tot = jnp.sum(self._pad_image.array) - xpos = jnp.sum(x * self._pad_image.array) / tot - ypos = jnp.sum(y * self._pad_image.array) / tot - return PositionD(xpos, ypos) - - @property - def _positive_flux(self): - return self._pos_neg_fluxes[0] - - @property - def _negative_flux(self): - return self._pos_neg_fluxes[1] - - @property - def _max_sb(self): - return jnp.max(jnp.abs(self._pad_image.array)) - - def _flux_per_photon(self): - return self._calculate_flux_per_photon() - @jax.jit def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): vals = _draw_with_interpolant_xval( diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 3ef35e11..d6c74af7 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -23,6 +23,8 @@ def lazy_property(func): @property @functools.wraps(func) def _func(self): + if not hasattr(self, "_workspace"): + self._workspace = {} if attname not in self._workspace: self._workspace[attname] = func(self) return self._workspace[attname] diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 35f36735..f28c64ab 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -97,6 +97,145 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): return sim + jnp.rot90(snse, 3) +def test_metacal_jit_timing(): + seed = 42 + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nse = 1e-3 + g1 = 0.01 + target_fwhm = 1.0 + + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + psf = ( + _galsim.Gaussian(fwhm=fwhm) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + + # nse_im = rng.normal(size=im.shape) * nse + im += rng.normal(size=im.shape) * nse + + def _f1(im, psf, g1, target_fwhm, scale, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + return iim.drawImage( + nx=33, + ny=33, + scale=scale, + ).array + + def _f2(im, psf, g1, target_fwhm, scale, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + return ( + iim.drawImage( + nx=33, + ny=33, + scale=scale, + ).array, + iim.drawImage( + nx=33, + ny=33, + scale=scale * 1.1, + ).array, + ) + + def _f3(im, psf, g1, target_fwhm, scale, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + prepsf_im = jax_galsim.Convolve(iim, jax_galsim.Deconvolve(ipsf)) + prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) + + return prepsf_im + + def _f4(im, psf, g1, target_fwhm, scale, nk): + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + prepsf_im = jax_galsim.Convolve(iim, jax_galsim.Deconvolve(ipsf)) + prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) + target_psf = jax_galsim.Gaussian(fwhm=target_fwhm) + + return ( + jax_galsim.Convolve( + prepsf_im, + target_psf, + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ) + .array.astype(np.float64) + ) + + for k, func in enumerate([_f4, _f2, _f1, _f2, _f3, _f4, _f2, _f4]): + print("Timing: ", func.__name__) + for i in range(3): + if i == 0: + msg = "no jit" + _func = func + elif i == 1: + msg = "jit warmup" + _func = jax.jit(func, static_argnames=["nk"]) + elif i == 2: + msg = "jit" + jgt0 = time.time() + jgres = _func(im, psf, g1, target_fwhm, scale, 128) + jgres = jax.block_until_ready(jgres) + jgt0 = time.time() - jgt0 + print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") + if k == 0 and i == 0: + first_time = jgt0 + if k == 0 and i == 2: + jit_time = jgt0 + + assert first_time > jgt0 + np.testing.assert_allclose(jit_time, jgt0, rtol=0.2) + + def test_metacal_comp_to_galsim(): seed = 42 hlr = 0.5 From 13edd60c3b4e83a87ee0f32bbc2bd0716ee87aef Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 28 Oct 2023 13:56:05 -0500 Subject: [PATCH 56/67] PERF faster tests --- jax_galsim/interpolant.py | 34 +++++----- jax_galsim/interpolatedimage.py | 110 ++++++++++++++++---------------- 2 files changed, 70 insertions(+), 74 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index ffba7613..acce2c21 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -14,7 +14,6 @@ from jax_galsim.bessel import si from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.gsparams import GSParams -from jax_galsim.utilities import lazy_property @_wraps(_galsim.interpolant.Interpolant) @@ -1340,17 +1339,16 @@ def __init__( self._n = n self._conserve_dc = conserve_dc self._gsparams = GSParams.check(gsparams) - self._workspace = {} - @lazy_property - def _K_arr(self): - return _compute_C_K_lanczos(self._n)[1] - - @lazy_property + @property def _C_arr(self): - return _compute_C_K_lanczos(self._n)[0] + return self._C_arr_vals[self._n] + + @property + def _K_arr(self): + return self._K_arr_vals[self._n] - @lazy_property + @property def _du(self): return ( self._gsparams.table_spacing @@ -1358,7 +1356,7 @@ def _du(self): / self._n ) - @lazy_property + @property def _umax(self): return _find_umax_lanczos( self._du, @@ -1398,15 +1396,6 @@ def __repr__(self): def __str__(self): return "galsim.Lanczos(%s)" % (self._n) - def __getstate__(self): - d = self.__dict__.copy() - d.pop("_workspace") - return d - - def __setstate__(self, d): - self.__dict__ = d - self._workspace = {} - # this is a pure function and we apply JIT ahead of time since this # one is pretty slow @jax.jit @@ -1647,3 +1636,10 @@ def _compute_C_K_lanczos(n): _C = _C.at[5].set(-_K[5]) return _C, _K + + +Lanczos._C_arr_vals = {} +Lanczos._K_arr_vals = {} +for n in range(1, 31): + Lanczos._C_arr_vals[n] = _compute_C_K_lanczos(n)[0] + Lanczos._K_arr_vals[n] = _compute_C_K_lanczos(n)[1] diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 15ed6416..3db82e0d 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -767,22 +767,10 @@ def _getMaxK(self, calculate_maxk): else: return self._x_interpolant.krange - @jax.jit - def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): - vals = _draw_with_interpolant_xval( - x + x_offset, - y + y_offset, - xmin, - ymin, - arr, - x_interpolant, - ) - return vals - def _xValue(self, pos): x = jnp.array([pos.x], dtype=float) y = jnp.array([pos.y], dtype=float) - return _InterpolatedImageImpl._xValue_arr( + return _xValue_arr( x, y, self._offset.x, @@ -793,47 +781,10 @@ def _xValue(self, pos): self._x_interpolant, )[0] - @jax.jit - def _kValue_arr( - kx, - ky, - x_offset, - y_offset, - kxmin, - kymin, - arr, - scale, - x_interpolant, - k_interpolant, - ): - # phase factor due to offset - # not we shift by -offset which explains the sign - # in the exponent - pfac = jnp.exp(1j * (kx * x_offset + ky * y_offset)) - - kxi = kx / scale - kyi = ky / scale - - _uscale = 1.0 / (2.0 * jnp.pi) - _maxk_xint = x_interpolant.urange() / _uscale / scale - - val = _draw_with_interpolant_kval( - kxi, - kyi, - kymin, # this is not a bug! we need the minimum for the full periodic space - kymin, - arr, - k_interpolant, - ) - - msk = (jnp.abs(kxi) <= _maxk_xint) & (jnp.abs(kyi) <= _maxk_xint) - xint_val = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky) - return jnp.where(msk, val * xint_val * pfac, 0.0) - def _kValue(self, kpos): kx = jnp.array([kpos.x], dtype=float) ky = jnp.array([kpos.y], dtype=float) - return _InterpolatedImageImpl._kValue_arr( + return _kValue_arr( kx, ky, self._offset.x, @@ -865,7 +816,7 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): coords = jnp.dot(coords, inv_jacobian.T) flux_scaling *= jnp.exp(logdet) - im = _InterpolatedImageImpl._xValue_arr( + im = _xValue_arr( coords[..., 0], coords[..., 1], self._offset.x, @@ -890,7 +841,7 @@ def _drawKImage(self, image, jac=None): coords = coords * image.scale # Scale by the image pixel scale coords = jnp.dot(coords, jacobian) - im = _InterpolatedImageImpl._kValue_arr( + im = _kValue_arr( coords[..., 0], coords[..., 1], self._offset.x, @@ -936,7 +887,19 @@ def _InterpolatedImage( ) -@partial(jax.jit, static_argnums=(5,)) +def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): + vals = _draw_with_interpolant_xval( + x + x_offset, + y + y_offset, + xmin, + ymin, + arr, + x_interpolant, + ) + return vals + + +@partial(jax.jit, static_argnames=("interp",)) def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): orig_shape = x.shape x = x.ravel() @@ -983,7 +946,44 @@ def _body(i, args): return z.reshape(orig_shape) -@partial(jax.jit, static_argnums=(5,)) +def _kValue_arr( + kx, + ky, + x_offset, + y_offset, + kxmin, + kymin, + arr, + scale, + x_interpolant, + k_interpolant, +): + # phase factor due to offset + # not we shift by -offset which explains the sign + # in the exponent + pfac = jnp.exp(1j * (kx * x_offset + ky * y_offset)) + + kxi = kx / scale + kyi = ky / scale + + _uscale = 1.0 / (2.0 * jnp.pi) + _maxk_xint = x_interpolant.urange() / _uscale / scale + + val = _draw_with_interpolant_kval( + kxi, + kyi, + kymin, # this is not a bug! we need the minimum for the full periodic space + kymin, + arr, + k_interpolant, + ) + + msk = (jnp.abs(kxi) <= _maxk_xint) & (jnp.abs(kyi) <= _maxk_xint) + xint_val = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky) + return jnp.where(msk, val * xint_val * pfac, 0.0) + + +@partial(jax.jit, static_argnames=("interp",)) def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): orig_shape = kx.shape kx = kx.ravel() From 4138b03c39cebeeee1e717025e35b288b73d1ac5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 29 Oct 2023 11:22:09 -0500 Subject: [PATCH 57/67] TST make tests faster by skipping some at random --- tests/jax/test_interpolatedimage_utils.py | 29 +++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 527b0711..c977eb46 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -1,3 +1,5 @@ +import hashlib + import galsim as _galsim import jax.numpy as jnp import numpy as np @@ -189,8 +191,29 @@ def test_interpolatedimage_utils_comp_to_galsim( normalization, x_interp, ): - gimage_in = _galsim.Image(ref_array, scale=1) - jgimage_in = jax_galsim.Image(ref_array, scale=1) + seed = max( + abs( + int( + hashlib.sha1( + f"{method}{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( + "utf-8" + ) + ).hexdigest(), + 16, + ) + ) + % (10**7), + 1, + ) + + rng = np.random.RandomState(seed=seed) + if rng.uniform() < 0.75: + pytest.skip( + "Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time." + ) + + gimage_in = _galsim.Image(ref_array, scale=0.2) + jgimage_in = jax_galsim.Image(ref_array, scale=0.2) gii = _galsim.InterpolatedImage( gimage_in, @@ -209,8 +232,6 @@ def test_interpolatedimage_utils_comp_to_galsim( x_interpolant=x_interp, ) - rng = np.random.RandomState(seed=42) - np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) kxvals = [ From b478eb5bb4c3731be6640782faf46358198d5f08 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 29 Oct 2023 12:32:26 -0500 Subject: [PATCH 58/67] TST faster tests --- tests/jax/test_metacal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index f28c64ab..90168b3a 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -211,7 +211,7 @@ def _f4(im, psf, g1, target_fwhm, scale, nk): .array.astype(np.float64) ) - for k, func in enumerate([_f4, _f2, _f1, _f2, _f3, _f4, _f2, _f4]): + for k, func in enumerate([_f4, _f3, _f2, _f1, _f4, _f3, _f2, _f1]): print("Timing: ", func.__name__) for i in range(3): if i == 0: @@ -394,7 +394,7 @@ def test_metacal_vmap(): ims = [] nse_ims = [] psfs = [] - for _seed in range(1000): + for _seed in range(10): seed = _seed + start_seed rng = np.random.RandomState(seed) From 32d49d2cc323ee1536965a658c1c474380dad62b Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sun, 29 Oct 2023 15:19:05 -0500 Subject: [PATCH 59/67] Update test_metacal.py --- tests/jax/test_metacal.py | 137 -------------------------------------- 1 file changed, 137 deletions(-) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 90168b3a..98372419 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -97,143 +97,6 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): return sim + jnp.rot90(snse, 3) -def test_metacal_jit_timing(): - seed = 42 - hlr = 0.5 - fwhm = 0.9 - scale = 0.2 - nse = 1e-3 - g1 = 0.01 - target_fwhm = 1.0 - - rng = np.random.RandomState(seed) - - im = ( - _galsim.Convolve( - _galsim.Exponential(half_light_radius=hlr), - _galsim.Gaussian(fwhm=fwhm), - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - ) - .array.astype(np.float64) - ) - - psf = ( - _galsim.Gaussian(fwhm=fwhm) - .drawImage( - nx=33, - ny=33, - scale=scale, - ) - .array.astype(np.float64) - ) - - # nse_im = rng.normal(size=im.shape) * nse - im += rng.normal(size=im.shape) * nse - - def _f1(im, psf, g1, target_fwhm, scale, nk): - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - return iim.drawImage( - nx=33, - ny=33, - scale=scale, - ).array - - def _f2(im, psf, g1, target_fwhm, scale, nk): - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - return ( - iim.drawImage( - nx=33, - ny=33, - scale=scale, - ).array, - iim.drawImage( - nx=33, - ny=33, - scale=scale * 1.1, - ).array, - ) - - def _f3(im, psf, g1, target_fwhm, scale, nk): - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - prepsf_im = jax_galsim.Convolve(iim, jax_galsim.Deconvolve(ipsf)) - prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) - - return prepsf_im - - def _f4(im, psf, g1, target_fwhm, scale, nk): - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - prepsf_im = jax_galsim.Convolve(iim, jax_galsim.Deconvolve(ipsf)) - prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) - target_psf = jax_galsim.Gaussian(fwhm=target_fwhm) - - return ( - jax_galsim.Convolve( - prepsf_im, - target_psf, - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array.astype(np.float64) - ) - - for k, func in enumerate([_f4, _f3, _f2, _f1, _f4, _f3, _f2, _f1]): - print("Timing: ", func.__name__) - for i in range(3): - if i == 0: - msg = "no jit" - _func = func - elif i == 1: - msg = "jit warmup" - _func = jax.jit(func, static_argnames=["nk"]) - elif i == 2: - msg = "jit" - jgt0 = time.time() - jgres = _func(im, psf, g1, target_fwhm, scale, 128) - jgres = jax.block_until_ready(jgres) - jgt0 = time.time() - jgt0 - print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") - if k == 0 and i == 0: - first_time = jgt0 - if k == 0 and i == 2: - jit_time = jgt0 - - assert first_time > jgt0 - np.testing.assert_allclose(jit_time, jgt0, rtol=0.2) def test_metacal_comp_to_galsim(): From 0410565b28391461d26d05ee1e02298246f78532 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sun, 29 Oct 2023 15:20:03 -0500 Subject: [PATCH 60/67] Update test_metacal.py --- tests/jax/test_metacal.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 98372419..0843c77c 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -97,8 +97,6 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): return sim + jnp.rot90(snse, 3) - - def test_metacal_comp_to_galsim(): seed = 42 hlr = 0.5 From cfb6770ce022c33bf9802f28845ef4cc3db5bf2a Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 1 Nov 2023 06:29:49 -0500 Subject: [PATCH 61/67] TST metacal passes --- jax_galsim/core/wrap_image.py | 40 +++- jax_galsim/gsobject.py | 2 +- jax_galsim/image.py | 28 ++- jax_galsim/interpolatedimage.py | 7 +- tests/jax/galsim/test_image_jax.py | 102 +++++---- tests/jax/test_metacal.py | 318 ++++++++++++++++++----------- 6 files changed, 319 insertions(+), 178 deletions(-) diff --git a/jax_galsim/core/wrap_image.py b/jax_galsim/core/wrap_image.py index 94ced023..4aa89a79 100644 --- a/jax_galsim/core/wrap_image.py +++ b/jax_galsim/core/wrap_image.py @@ -4,7 +4,7 @@ @jax.jit -def wrap_nonhermition(im, xmin, ymin, nxwrap, nywrap): +def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap): def _body_j(j, vals): i, im = vals @@ -33,3 +33,41 @@ def _body_i(i, vals): im = jax.lax.fori_loop(0, im.shape[0], _body_i, im) return im + + +@jax.jit +def expand_hermitian_x(im): + return jnp.concatenate([im[:, 1:][::-1, ::-1].conjugate(), im], axis=1) + + +@jax.jit +def contract_hermitian_x(im): + return im[:, im.shape[1] // 2 :] + + +@jax.jit +def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_x(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_x(im_exp) + + +@jax.jit +def expand_hermitian_y(im): + return jnp.concatenate([im[1:, :][::-1, ::-1].conjugate(), im], axis=0) + + +@jax.jit +def contract_hermitian_y(im): + return im[im.shape[0] // 2 :, :] + + +@jax.jit +def wrap_hermitian_y(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny): + im_exp = expand_hermitian_y(im) + im_exp = wrap_nonhermitian( + im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny + ) + return contract_hermitian_y(im_exp) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 75be0b09..49458091 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -772,7 +772,7 @@ def drawFFT_makeKImage(self, image): with jax.ensure_compile_time_eval(): Nk = self.gsparams.maximum_fft_size N = Nk - dk = 2.0 * np.pi / (N * image.scale) + dk = 2.0 * np.pi / (N * image.scale) else: # Start with what this profile thinks a good size would be given the image's pixel scale. N = self.getGoodImageSize(image.scale) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index ad6d19cc..ea38ca85 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -631,9 +631,9 @@ def _wrap(self, bounds, hermx, hermy): Equivalent to ``image.wrap(bounds, hermitian=='x', hermitian=='y')``. """ if not hermx and not hermy: - from jax_galsim.core.wrap_image import wrap_nonhermition + from jax_galsim.core.wrap_image import wrap_nonhermitian - self._array = wrap_nonhermition( + self._array = wrap_nonhermitian( self._array, # zero indexed location of subimage bounds.xmin - self.xmin, @@ -642,6 +642,30 @@ def _wrap(self, bounds, hermx, hermy): bounds.xmax - bounds.xmin + 1, bounds.ymax - bounds.ymin + 1, ) + elif hermx and not hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_x + + self._array = wrap_hermitian_x( + self._array, + -self.xmax, + self.ymin, + -bounds.xmax + 1, + bounds.ymin, + 2 * bounds.xmax, + bounds.ymax - bounds.ymin + 1, + ) + elif not hermx and hermy: + from jax_galsim.core.wrap_image import wrap_hermitian_y + + self._array = wrap_hermitian_y( + self._array, + self.xmin, + -self.ymax, + bounds.xmin, + -bounds.ymax + 1, + bounds.xmax - bounds.xmin + 1, + 2 * bounds.ymax, + ) # FIXME: Wrapping not yet implemented for hermitian images return self.subImage(bounds) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3db82e0d..c37fe2c3 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1086,7 +1086,7 @@ def _calculate_size_containing_flux(image, thresh): @jax.jit def _inner_comp_find_maxk(arr, thresh, kx, ky): - msk = arr * arr.conjugate() > thresh * thresh + msk = (arr * arr.conjugate()).real > thresh * thresh max_kx = jnp.max( jnp.where( msk, @@ -1109,4 +1109,7 @@ def _find_maxk(kim, max_maxk, thresh): kx, ky = kim.get_pixel_centers() kx *= kim.scale ky *= kim.scale - return _inner_comp_find_maxk(kim.array, thresh, kx, ky) + return jnp.minimum( + _inner_comp_find_maxk(kim.array, thresh, kx, ky), + max_maxk, + ) diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index c14e87ca..1722b9b0 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -4502,59 +4502,57 @@ def test_wrap(): im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" ) - # FIXME: turn on when hermitian wrapping is implemented - if False: - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_array_almost_equal( - im2_wrap.array, - im_test[b2].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) + im2_wrap = im2.wrap(b2, hermitian="y") + # print('im_test = ',im_test[b2].array) + # print('im2_wrap = ',im2_wrap.array) + # print('diff = ',im2_wrap.array-im_test[b2].array) + np.testing.assert_array_almost_equal( + im2_wrap.array, + im_test[b2].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im2_wrap.array, + im2[b2].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" + ) - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_array_almost_equal( - im3_wrap.array, - im_test[b3].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") + im3_wrap = im3.wrap(b3, hermitian="x") + # print('im_test = ',im_test[b3].array) + # print('im3_wrap = ',im3_wrap.array) + # print('diff = ',im3_wrap.array-im_test[b3].array) + np.testing.assert_array_almost_equal( + im3_wrap.array, + im_test[b3].array, + 12, + "image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im3_wrap.array, + im3[b3].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" + ) + + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + b2 = galsim.BoundsI(-K + 1, K, 0, L) + b3 = galsim.BoundsI(0, K, -L + 1, L) + assert_raises(TypeError, im.wrap, bounds=None) + assert_raises(ValueError, im3.wrap, b, hermitian="x") + assert_raises(ValueError, im3.wrap, b2, hermitian="x") + assert_raises(ValueError, im.wrap, b3, hermitian="x") + assert_raises(ValueError, im2.wrap, b, hermitian="y") + assert_raises(ValueError, im2.wrap, b3, hermitian="y") + assert_raises(ValueError, im.wrap, b2, hermitian="y") + assert_raises(ValueError, im.wrap, b, hermitian="invalid") + assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") @timer diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal.py index 0843c77c..0b21a072 100644 --- a/tests/jax/test_metacal.py +++ b/tests/jax/test_metacal.py @@ -5,46 +5,56 @@ import jax import jax.numpy as jnp import numpy as np +import pytest import jax_galsim -def _metacal_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): +def _metacal_galsim( + im, psf, nse_im, scale, target_fwhm, g1, iim_kwargs, ipsf_kwargs, inse_kwargs, nk +): iim = _galsim.InterpolatedImage( - _galsim.ImageD(im), scale=scale, x_interpolant="lanczos15" + _galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + **iim_kwargs, ) ipsf = _galsim.InterpolatedImage( - _galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + _galsim.ImageD(psf), + scale=scale, + x_interpolant="lanczos15", + **ipsf_kwargs, ) inse = _galsim.InterpolatedImage( - _galsim.ImageD(np.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" + _galsim.ImageD(np.rot90(nse_im, 1)), + scale=scale, + x_interpolant="lanczos15", + **inse_kwargs, ) ppsf_iim = _galsim.Convolve(iim, _galsim.Deconvolve(ipsf)) ppsf_iim = ppsf_iim.shear(g1=g1, g2=0.0) - sim = ( - _galsim.Convolve( - ppsf_iim, - _galsim.Gaussian(fwhm=target_fwhm), - gsparams=_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array.astype(np.float64) + prof = _galsim.Convolve( + ppsf_iim, + _galsim.Gaussian(fwhm=target_fwhm), + gsparams=_galsim.GSParams(minimum_fft_size=nk), ) + sim = prof.drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ).array.astype(np.float64) + ppsf_inse = _galsim.Convolve(inse, _galsim.Deconvolve(ipsf)) ppsf_inse = ppsf_inse.shear(g1=g1, g2=0.0) snse = ( _galsim.Convolve( ppsf_inse, _galsim.Gaussian(fwhm=target_fwhm), - gsparams=_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), + gsparams=_galsim.GSParams(minimum_fft_size=nk), ) .drawImage( nx=33, @@ -62,21 +72,19 @@ def _metacal_jax_galsim_render(im, psf, g1, target_psf, scale, nk): prepsf_im = jax_galsim.Convolve(im, jax_galsim.Deconvolve(psf)) prepsf_im = prepsf_im.shear(g1=g1, g2=0.0) - return ( - jax_galsim.Convolve( - prepsf_im, - target_psf, - gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array.astype(np.float64) + prof = jax_galsim.Convolve( + prepsf_im, + target_psf, + gsparams=jax_galsim.GSParams(minimum_fft_size=nk, maximum_fft_size=nk), ) + return prof.drawImage( + nx=33, + ny=33, + scale=scale, + method="no_pixel", + ).array.astype(np.float64) + def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): iim = jax_galsim.InterpolatedImage( @@ -86,8 +94,9 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" ) inse = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(jnp.rot90(nse_im)), scale=scale, x_interpolant="lanczos15" + jax_galsim.ImageD(jnp.rot90(nse_im, 1)), scale=scale, x_interpolant="lanczos15" ) + target_psf = jax_galsim.Gaussian(fwhm=target_fwhm) sim = _metacal_jax_galsim_render(iim, ipsf, g1, target_psf, scale, nk) @@ -97,12 +106,12 @@ def _metacal_jax_galsim(im, psf, nse_im, scale, target_fwhm, g1, nk): return sim + jnp.rot90(snse, 3) -def test_metacal_comp_to_galsim(): +@pytest.mark.parametrize("nse", [1e-3, 1e-10]) +def test_metacal_comp_to_galsim(nse): seed = 42 hlr = 0.5 fwhm = 0.9 scale = 0.2 - nse = 1e-3 g1 = 0.01 target_fwhm = 1.0 @@ -131,18 +140,56 @@ def test_metacal_comp_to_galsim(): .array.astype(np.float64) ) - nse_im = rng.normal(size=im.shape) * nse - im += rng.normal(size=im.shape) * nse + nse_im = rng.normal(size=im.shape, scale=nse) + im += rng.normal(size=im.shape, scale=nse) + + # jax galsim and galsim set stepk and maxk differently due to slight + # algorithmic differences. We force them to be the same here for this + # test so it passes. + iim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=128), + ) + iim_kwargs = { + "_force_stepk": iim.stepk.item(), + "_force_maxk": iim.maxk.item(), + } + inse = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(jnp.rot90(nse_im, 1)), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=128), + ) + inse_kwargs = { + "_force_stepk": inse.stepk.item(), + "_force_maxk": inse.maxk.item(), + } + ipsf = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" + ) + ipsf_kwargs = { + "_force_stepk": ipsf.stepk.item(), + "_force_maxk": ipsf.maxk.item(), + } gt0 = time.time() gres = _metacal_galsim( - im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1, 128 + im.copy(), + psf.copy(), + nse_im.copy(), + scale, + target_fwhm, + g1, + iim_kwargs, + ipsf_kwargs, + inse_kwargs, + 128, ) gt0 = time.time() - gt0 - print("Galsim time: ", gt0 * 1e3, " [ms]") - - _func = jax.jit(_metacal_jax_galsim, static_argnames=("nk",)) + print("galsim time: ", gt0 * 1e3, " [ms]") for i in range(2): if i == 0: @@ -150,10 +197,10 @@ def test_metacal_comp_to_galsim(): elif i == 1: msg = "jit" jgt0 = time.time() - jgres = _func( - im, - psf, - nse_im, + jgres = _metacal_jax_galsim( + im.copy(), + psf.copy(), + nse_im.copy(), scale, target_fwhm, g1, @@ -161,82 +208,21 @@ def test_metacal_comp_to_galsim(): ) jgres = jax.block_until_ready(jgres) jgt0 = time.time() - jgt0 - print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") + print("jax-galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") gim = gres jgim = jgres - atol = 7e-5 - + atol = 1e-8 if not np.allclose(gim, jgim, rtol=0, atol=atol): import proplot as pplt - fig, axs = pplt.subplots(ncols=3, nrows=3) - _gim = gres - _jgim = jgres + fig, axs = pplt.subplots(ncols=3, nrows=1) - gpsf = ( - _galsim.InterpolatedImage( - _galsim.Image(psf, scale=scale), x_interpolant="lanczos15" - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - ) - .array - ) - jgpsf = ( - jax_galsim.InterpolatedImage( - jax_galsim.Image(psf, scale=scale), x_interpolant="lanczos15" - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - ) - .array - ) - - axs[0, 0].imshow(gpsf) - axs[0, 1].imshow(jgpsf) - m = axs[0, 2].imshow((jgpsf - gpsf) / 1e-5) - axs[0, 2].colorbar(m, loc="r") - - gpsf = ( - _galsim.InterpolatedImage( - _galsim.Image(psf, scale=scale), x_interpolant="lanczos15" - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array - ) - jgpsf = ( - jax_galsim.InterpolatedImage( - jax_galsim.Image(psf, scale=scale), x_interpolant="lanczos15" - ) - .drawImage( - nx=33, - ny=33, - scale=scale, - method="no_pixel", - ) - .array - ) - - axs[1, 0].imshow(gpsf) - axs[1, 1].imshow(jgpsf) - m = axs[1, 2].imshow((jgpsf - gpsf) / 1e-5) - axs[1, 2].colorbar(m, loc="r") - - axs[2, 0].imshow(np.arcsinh(_gim / nse)) - axs[2, 1].imshow(np.arcsinh(_jgim / nse)) - m = axs[2, 2].imshow((_jgim - _gim) / 1e-5) - axs[2, 2].colorbar(m, loc="r") + axs[0].imshow(np.arcsinh(gres / nse)) + axs[1].imshow(np.arcsinh(jgres / nse)) + m = axs[2].imshow(jgres - gres) + axs[2].colorbar(m, loc="r") fig.show() @@ -296,17 +282,23 @@ def test_metacal_vmap(): gt0 = time.time() for im, psf, nse_im in zip(ims, psfs, nse_ims): _metacal_galsim( - im.copy(), psf.copy(), nse_im.copy(), scale, target_fwhm, g1, 128 + im.copy(), + psf.copy(), + nse_im.copy(), + scale, + target_fwhm, + g1, + {}, + {}, + {}, + 128, ) gt0 = time.time() - gt0 print("Galsim time: ", gt0 * 1e3, " [ms]") - jit_mcal = jax.jit( - jax.vmap( - _metacal_jax_galsim, - in_axes=(0, 0, 0, None, None, None, None), - ), - static_argnums=6, + vmap_mcal = jax.vmap( + _metacal_jax_galsim, + in_axes=(0, 0, 0, None, None, None, None), ) for i in range(2): @@ -316,7 +308,7 @@ def test_metacal_vmap(): msg = "jit" jgt0 = time.time() - jit_mcal( + vmap_mcal( ims, psfs, nse_ims, @@ -327,3 +319,89 @@ def test_metacal_vmap(): ) jgt0 = time.time() - jgt0 print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") + + +@pytest.mark.parametrize( + "draw_method", + [ + "no_pixel", + "auto", + ], +) +@pytest.mark.parametrize( + "nse", + [ + 4e-3, + 1e-3, + 1e-10, + ], +) +def test_metacal_iimage_with_noise(nse, draw_method): + hlr = 0.5 + fwhm = 0.9 + scale = 0.2 + nk = 128 + seed = 42 + + rng = np.random.RandomState(seed) + + im = ( + _galsim.Convolve( + _galsim.Exponential(half_light_radius=hlr), + _galsim.Gaussian(fwhm=fwhm), + ) + .drawImage( + nx=33, + ny=33, + scale=scale, + ) + .array.astype(np.float64) + ) + im += rng.normal(size=im.shape) * nse + + jgiim = jax_galsim.InterpolatedImage( + jax_galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=jax_galsim.GSParams(minimum_fft_size=nk), + ) + + giim = _galsim.InterpolatedImage( + _galsim.ImageD(im), + scale=scale, + x_interpolant="lanczos15", + gsparams=_galsim.GSParams(minimum_fft_size=nk), + _force_stepk=jgiim.stepk.item(), + _force_maxk=jgiim.maxk.item(), + ) + + def _plot_real(gim, jgim): + import proplot as pplt + + fig, axs = pplt.subplots(ncols=3, nrows=1) + + axs[0].imshow(gim) + axs[1].imshow(jgim) + m = axs[0, 2].imshow((jgim - gim)) + axs[2].colorbar(m, loc="r") + + fig.show() + + atol = 1e-8 + np.testing.assert_allclose(giim.maxk, jgiim.maxk) + np.testing.assert_allclose(giim.maxk, jgiim.maxk) + + if draw_method == "no_pixel": + gim = giim.drawImage(nx=33, ny=33, scale=scale, method="no_pixel").array + jgim = jgiim.drawImage(nx=33, ny=33, scale=scale, method="no_pixel").array + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + _plot_real(gim, jgim) + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) + elif draw_method == "auto": + gim = giim.drawImage(nx=33, ny=33, scale=scale).array + jgim = jgiim.drawImage(nx=33, ny=33, scale=scale).array + + if not np.allclose(gim, jgim, rtol=0, atol=atol): + _plot_real(gim, jgim) + np.testing.assert_allclose(gim, jgim, rtol=0, atol=atol) From bf2c17dada1c44b14d87149f4746ba018b1502fe Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 1 Nov 2023 06:33:19 -0500 Subject: [PATCH 62/67] STY isort --- tests/jax/test_image_wrapping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 268e10b4..bbf9e9f4 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -1,13 +1,13 @@ import jax -import jax_galsim as galsim import numpy as np from galsim_test_helpers import timer +import jax_galsim as galsim from jax_galsim.core.wrap_image import ( - expand_hermitian_x, - expand_hermitian_y, contract_hermitian_x, contract_hermitian_y, + expand_hermitian_x, + expand_hermitian_y, ) From c48ba92d5683910b63c5d10eedbab49ffc013a82 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 1 Nov 2023 06:44:31 -0500 Subject: [PATCH 63/67] TST add test of fwd mode too --- tests/jax/test_image_wrapping.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index bbf9e9f4..14bfe100 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -17,8 +17,6 @@ def test_image_wrapping_expand_contract(): # Hermitian, so we only need to keep around half of it. M = 38 N = 25 - K = 8 - L = 5 im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian im2 = galsim.ImageCD( 2 * M + 1, N + 1, xmin=-M, ymin=0 @@ -29,8 +27,6 @@ def test_image_wrapping_expand_contract(): # print('im = ',im) # print('im2 = ',im2) # print('im3 = ',im3) - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - im_test = galsim.ImageCD(b, init_value=0) for i in range(-M, M + 1): for j in range(-N, N + 1): # An arbitrary, complicated Hermitian function. @@ -46,9 +42,6 @@ def test_image_wrapping_expand_contract(): if i >= 0: im3[i, j] = val - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) # print("im = ",im.array) # Confirm that the image is Hermitian. @@ -137,9 +130,9 @@ def _wrapit(im): # make sure this runs p, grad = jax.vjp(_wrapit, im3) - grad = jax.jit(grad) grad(p) + jax.jvp(_wrapit, (im3,), (im3 * 2,)) def _wrapit(im): b3 = galsim.BoundsI(0, K, -L + 1, L) @@ -148,3 +141,4 @@ def _wrapit(im): # make sure this runs p, grad = jax.vjp(_wrapit, im3) grad(p) + jax.jvp(_wrapit, (im3,), (im3 * 2,)) From cd444a89fd9a8ce779e8817b4ea3a7b3fc74fa2c Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 9 Nov 2023 06:36:19 -0600 Subject: [PATCH 64/67] ENH respond to code review --- jax_galsim/core/utils.py | 6 ++++-- jax_galsim/image.py | 4 ++++ jax_galsim/interpolant.py | 3 +-- jax_galsim/interpolatedimage.py | 11 +++++------ tests/GalSim | 2 +- tests/conftest.py | 2 -- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f6eb786c..554cc46a 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -4,11 +4,12 @@ import jax.numpy as jnp +@jax.jit def compute_major_minor_from_jacobian(jac): h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) - major = 0.5 * abs(h1 + h2) - minor = 0.5 * abs(h1 - h2) + major = 0.5 * jnp.abs(h1 + h2) + minor = 0.5 * jnp.abs(h1 - h2) return major, minor @@ -94,6 +95,7 @@ def is_equal_with_arrays(x, y): elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( isinstance(y, jax.Array) and jnp.ndim(y) == 0 ): + # this case covers comparing an array scalar to a python scalar or vice versa return jnp.array_equal(x, y) else: return x == y diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 293107c3..bfe32b7b 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -707,6 +707,8 @@ def calculate_fft(self): dk = jnp.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) + # we shift the image before and after the FFT to match the layout of the modes + # used by GalSim out._array = jnp.fft.fftshift( jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 ) @@ -761,6 +763,7 @@ def calculate_inverse_fft(self): # For the inverse, we need a bit of extra space for the fft. out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) + # we shift the image before and after the FFT to match the layout used by galsim out_extra._array = jnp.fft.fftshift( jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) ) @@ -778,6 +781,7 @@ def good_fft_size(cls, input_size): going to be performing FFTs on an image, these will tend to be faster at performing the FFT. """ + # we use the math module here since this function should not be jitted. import math # Reference from GalSim C++ diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index acce2c21..9ed587e6 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -1383,8 +1383,7 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flattened representation""" n = aux_data.pop("n") - ret = cls(n, **aux_data) - return ret + return cls(n, **aux_data) def __repr__(self): return "galsim.Lanczos(%r, %r, gsparams=%r)" % ( diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index c37fe2c3..0a9df248 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -648,6 +648,8 @@ def _xim(self): wcs=PixelScale(1.0), ) xim.setCenter(0, 0) + # after the call to setCenter you get a WCS with an offset in + # it instead of a pure pixel scale xim.wcs = PixelScale(1.0) # Now place the given image in the center of the padding image: @@ -671,6 +673,7 @@ def _kim(self): @lazy_property def _pos_neg_fluxes(self): # record pos and neg fluxes now too + # see code here: https://github.com/GalSim-developers/GalSim/blob/releases/2.5/src/SBInterpolatedImage.cpp#L1225 pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) nflux = jnp.abs( jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) @@ -698,9 +701,7 @@ def _flux_per_photon(self): @lazy_property def _maxk(self): if self._jax_aux_data["_force_maxk"]: - major, minor = compute_major_minor_from_jacobian( - self._jac_arr.reshape((2, 2)) - ) + _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) return self._jax_aux_data["_force_maxk"] * minor else: return self._getMaxK(self._jax_aux_data["calculate_maxk"]) @@ -708,9 +709,7 @@ def _maxk(self): @lazy_property def _stepk(self): if self._jax_aux_data["_force_stepk"]: - major, minor = compute_major_minor_from_jacobian( - self._jac_arr.reshape((2, 2)) - ) + _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) return self._jax_aux_data["_force_stepk"] * minor else: return self._getStepK(self._jax_aux_data["calculate_stepk"]) diff --git a/tests/GalSim b/tests/GalSim index bf287a91..1ed5131a 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit bf287a91314b56db67308e3946878ae2ab52a8c4 +Subproject commit 1ed5131a54b4dbee384fee6b82b5e2e478ef0492 diff --git a/tests/conftest.py b/tests/conftest.py index 6780f5fc..17175c1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,8 +15,6 @@ import jax_galsim # noqa: E402 -config.update("jax_enable_x64", True) - # Identify the path to this current file test_directory = os.path.dirname(os.path.abspath(__file__)) From cb94ce62d5bac9dc2aee15b4a66a6a98cf91f33e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 9 Nov 2023 15:50:29 -0600 Subject: [PATCH 65/67] ENH respond to CR --- CHANGELOG.md | 1 + jax_galsim/interpolatedimage.py | 78 ++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e85e6e8..8b3a6818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * `Transformation` * `Shear` * `Convolve` + * `InterpolatedImage` and `Interpolant` * Added implementation of fundamental operations: * `drawImage` * `drawReal` diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0a9df248..acc8da4f 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -653,7 +653,6 @@ def _xim(self): xim.wcs = PixelScale(1.0) # Now place the given image in the center of the padding image: - # assert self._xim.bounds.includes(self._image.bounds) xim[self._image.bounds] = self._image return xim @@ -900,7 +899,27 @@ def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant): @partial(jax.jit, static_argnames=("interp",)) def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): + """This helper function interpolates an image (`zp`) with an interpolant `interp` + at the pixel locations given by `x`, `y`. The lower-left corner of the image is + `xmin` / `ymin`. + + A more standard C/C++ code would have a set of nested for loops that iterates over each + location to interpolate and then over the nterpolation kernel. + + In JAX, we instead write things such that the loop over the points to be interpolated + is vectorized in the code. We represent the loops over the interpolation kernel as explicit + for loops. + """ + # the vectorization over the interpolation points is easier to think about + # if they are in a 1D array. So we use ravel to flatten them and then reshape + # at the end. orig_shape = x.shape + + # the variables here are + # x/y: the x/y coordinates of the points to be interpolated + # xi/yi: the index of the nerest pixel below the point + # xp/yp: the x/y coordinate of the nearest pixel below the point + # nx/ny: the size of the x/y arrays x = x.ravel() xi = jnp.floor(x - xmin).astype(jnp.int32) xp = xi + xmin @@ -911,37 +930,63 @@ def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp): yp = yi + ymin ny = zp.shape[0] + # this function is the inner loop over the x direction + # the variables are + # i: the index of the location in the interpolation kernel + # z: the final interpolated values + # wy: the weight of the interpolation kernel in the y direction + # msky: a mask that is true if the y index is in bounds + # yind: the y index of the interpolation point needed by the kernel def _body_1d(i, args): z, wy, msky, yind, xi, xp, zp = args + # this block computes the x weight using the + # offset in the interpolation kernel i xind = xi + i mskx = (xind >= 0) & (xind < nx) _x = x - (xp + i) wx = interp._xval_noraise(_x) + # the actual interpolation is done here. + # we use jnp.where to only do the interpolation + # where the x and y indices are in bounds. + # the total weight is the product of the x and y weights. w = wx * wy msk = msky & mskx z += jnp.where(msk, zp[yind, xind] * w, 0) return [z, wy, msky, yind, xi, xp, zp] + # this function is the outer loop over the y direction + # the variables are + # i: the index of the location in the interpolation kernel + # z: the final interpolated values def _body(i, args): z, xi, yi, xp, yp, zp = args + + # this block computes the x weight using the + # offset in the interpolation kernel i yind = yi + i msk = (yind >= 0) & (yind < ny) _y = y - (yp + i) wy = interp._xval_noraise(_y) + + # this call computes the interpolant for each x locatoon that gets + # paired with this y location z = jax.lax.fori_loop( -interp.xrange, interp.xrange + 1, _body_1d, [z, wy, msk, yind, xi, xp, zp] )[0] return [z, xi, yi, xp, yp, zp] + # the actual loop call for y is here z = jax.lax.fori_loop( -interp.xrange, interp.xrange + 1, _body, [jnp.zeros(x.shape, dtype=float), xi, yi, xp, yp, zp], )[0] + + # we reshape on the way out to match the input shape return z.reshape(orig_shape) @@ -984,10 +1029,23 @@ def _kValue_arr( @partial(jax.jit, static_argnames=("interp",)) def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): + """This function interpolates complex k-space images and follows the + same basic structure as _draw_with_interpolant_xval above. + + The key difference is that the k-space images are Hermitian and so + only half of the data is actually in memory. We account for this by + computing all of the interpolation weights and indicies as if we had + the full image. Then finally, if we need a value that is not in memory, + we get it from the values we have via the Hermitian symmetry. + """ + # all of the code below is almost line-for-line the same as the + # _draw_with_interpolant_xval function above. orig_shape = kx.shape kx = kx.ravel() kxi = jnp.floor(kx - kxmin).astype(jnp.int32) kxp = kxi + kxmin + # this is the number of pixels in the half image and is needed + # for computing values via Hermition symmetry below nkx_2 = zp.shape[1] - 1 nkx = nkx_2 * 2 @@ -1003,6 +1061,16 @@ def _body_1d(i, args): _kx = kx - (kxp + i) wkx = interp._xval_noraise(_kx) + # this is the key difference from the xval function + # we need to use the Hermitian symmetry to get the + # values that are not in memory + # in memory we have the values at nkx_2 to nkx - 1 + # the Hermitian symmetry is that + # f(ky, kx) = conjugate(f(-kx, -ky)) + # In indices this is a symmetric flip about the central + # pixels at kx = ky = 0. + # we do not need to mask any values that run off the edge of the image + # since we rewrap them using the periodicity of the image. val = jnp.where( kxind < nkx_2, zp[(nky - kyind) % nky, nkx - kxind - nkx_2].conjugate(), @@ -1066,6 +1134,11 @@ def _calculate_size_containing_flux(image, thresh): msk = fluxes >= -jnp.inf fluxes = jnp.where(msk, fluxes, jnp.max(fluxes)) d = jnp.arange(image.array.shape[0]) + 1.0 + # below we use a linear interpolation table to find the maximum size + # in pixels that contains a given flux (called thresh here) + # expfac controls how much we oversample the interpolation table + # in order to return a more accurate result + # we have it harded at 4 to compromise between speed and accuracy expfac = 4.0 dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0 fluxes = jnp.interp(dint, d, fluxes) @@ -1108,6 +1181,9 @@ def _find_maxk(kim, max_maxk, thresh): kx, ky = kim.get_pixel_centers() kx *= kim.scale ky *= kim.scale + # this minimum bounds the empirically determined + # maxk from the image (computed by _inner_comp_find_maxk) + # by max_maxk from above return jnp.minimum( _inner_comp_find_maxk(kim.array, thresh, kx, ky), max_maxk, From 0e865714c4846990c04867972a48794d62a9a354 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 9 Nov 2023 15:51:53 -0600 Subject: [PATCH 66/67] ENH respond to CR --- jax_galsim/interpolatedimage.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index acc8da4f..c5886c40 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1013,6 +1013,7 @@ def _kValue_arr( _uscale = 1.0 / (2.0 * jnp.pi) _maxk_xint = x_interpolant.urange() / _uscale / scale + # here we do the actual inteprolation in k space val = _draw_with_interpolant_kval( kxi, kyi, @@ -1022,6 +1023,9 @@ def _kValue_arr( k_interpolant, ) + # finally we multiply by the FFT of the real-space interpolation function + # and mask any values that are outside the range of the real-space interpolation + # FFT msk = (jnp.abs(kxi) <= _maxk_xint) & (jnp.abs(kyi) <= _maxk_xint) xint_val = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky) return jnp.where(msk, val * xint_val * pfac, 0.0) From e0fd4ce879700311d134897eb4ade1aaa7855709 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 9 Nov 2023 15:58:22 -0600 Subject: [PATCH 67/67] DOC typo in comment --- jax_galsim/interpolatedimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index c5886c40..586652a1 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1142,7 +1142,7 @@ def _calculate_size_containing_flux(image, thresh): # in pixels that contains a given flux (called thresh here) # expfac controls how much we oversample the interpolation table # in order to return a more accurate result - # we have it harded at 4 to compromise between speed and accuracy + # we have it hard coded at 4 to compromise between speed and accuracy expfac = 4.0 dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0 fluxes = jnp.interp(dint, d, fluxes)