From cc4d7158b761756d70d051dfcd23df019fd7bda2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 13 Nov 2023 07:05:24 -0600 Subject: [PATCH 01/85] ENH add photon shooting --- jax_galsim/__init__.py | 4 + jax_galsim/box.py | 10 + jax_galsim/deprecated.py | 1 + jax_galsim/exponential.py | 30 + jax_galsim/gaussian.py | 10 + jax_galsim/gsobject.py | 371 +++- jax_galsim/interpolatedimage.py | 87 +- jax_galsim/moffat.py | 21 + jax_galsim/noise.py | 3 +- jax_galsim/photon_array.py | 637 +++++++ jax_galsim/random.py | 95 +- jax_galsim/sensor.py | 40 + jax_galsim/sum.py | 54 + jax_galsim/transform.py | 7 + tests/galsim_tests_config.yaml | 11 +- tests/jax/galsim/test_draw_jax.py | 4 +- tests/jax/galsim/test_noise_jax.py | 30 +- tests/jax/galsim/test_photon_array_jax.py | 1873 +++++++++++++++++++++ tests/jax/galsim/test_random_jax.py | 205 ++- 19 files changed, 3292 insertions(+), 201 deletions(-) create mode 100644 jax_galsim/deprecated.py create mode 100644 jax_galsim/photon_array.py create mode 100644 jax_galsim/sensor.py create mode 100644 tests/jax/galsim/test_photon_array_jax.py diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index a2371e52..f15cfcc2 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -86,6 +86,10 @@ ) from .interpolatedimage import InterpolatedImage, _InterpolatedImage +# Photon Shooting +from .photon_array import PhotonArray +from .sensor import Sensor + # packages kept separate from . import bessel from . import fits diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 3a8bce6d..5ac8ca00 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -6,6 +6,7 @@ from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.random import UniformDeviate @_wraps(_galsim.Box) @@ -115,6 +116,15 @@ def tree_unflatten(cls, aux_data, children): **aux_data ) + @_wraps(_galsim.Box._shoot) + def _shoot(self, photons, rng): + ud = UniformDeviate(rng) + + # this does not fill arrays like in galsim + photons.x = (ud.generate(photons.x) - 0.5) * self.width + photons.y = (ud.generate(photons.y) - 0.5) * self.height + photons.flux = self.flux / photons.size() + @_wraps(_galsim.Pixel) @register_pytree_node_class diff --git a/jax_galsim/deprecated.py b/jax_galsim/deprecated.py new file mode 100644 index 00000000..d537605a --- /dev/null +++ b/jax_galsim/deprecated.py @@ -0,0 +1 @@ +from galsim.deprecated import depr # noqa: F401 diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index f0eb486f..79f1dc44 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -6,6 +6,8 @@ from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.random import UniformDeviate +from jax_galsim.utilities import lazy_property @_wraps(_galsim.Exponential) @@ -145,3 +147,31 @@ def withFlux(self, flux): return Exponential( scale_radius=self.scale_radius, flux=flux, gsparams=self.gsparams ) + + @lazy_property + def _shoot_cdf(self): + # store this for later + _rmax = -jnp.log(self.gsparams.shoot_accuracy) + _umax = 1.0 - jnp.exp(-_rmax) + _u_cdf = jnp.linspace(0, _umax, 10000) + _cdf = _u_cdf - (_u_cdf - 1) * jnp.log(1 - _u_cdf) + _cdf /= _cdf[-1] + return _u_cdf, _cdf + + @_wraps(_galsim.Exponential._shoot) + def _shoot(self, photons, rng): + ud = UniformDeviate(rng) + + u = ud.generate( + photons.x + ) # this does not fill arrays like in galsim so is safe + _u_cdf, _cdf = self._shoot_cdf + u = jnp.interp(u, _cdf, _u_cdf) + r = -jnp.log(1.0 - u) * self._r0 + + ang = ( + ud.generate(photons.x) * 2.0 * jnp.pi + ) # this does not fill arrays like in galsim so is safe + photons.x = r * jnp.cos(ang) + photons.y = r * jnp.sin(ang) + photons.flux = self.flux / photons.size() diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index c2f78ba1..eccd0eeb 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -6,6 +6,7 @@ from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.utils import ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.random import GaussianDeviate @_wraps(_galsim.Gaussian) @@ -146,3 +147,12 @@ def _drawKImage(self, image, jac=None): @_wraps(_galsim.Gaussian.withFlux) def withFlux(self, flux): return Gaussian(sigma=self.sigma, flux=flux, gsparams=self.gsparams) + + @_wraps(_galsim.Gaussian._shoot) + def _shoot(self, photons, rng): + gd = GaussianDeviate(rng, sigma=self.sigma) + + # this does not fill arrays like in galsim + photons.x = gd.generate(photons.x) + photons.y = gd.generate(photons.y) + photons.flux = self.flux / photons.size() diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 1e66cbba..fd8285e0 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -4,9 +4,20 @@ import numpy as np from jax._src.numpy.util import _wraps +import jax_galsim.photon_array as pa from jax_galsim.core.utils import is_equal_with_arrays +from jax_galsim.errors import ( + GalSimError, + GalSimIncompatibleValuesError, + GalSimNotImplementedError, + GalSimRangeError, + GalSimValueError, + galsim_warn, +) from jax_galsim.gsparams import GSParams from jax_galsim.position import Position, PositionD, PositionI +from jax_galsim.random import BaseDeviate, PoissonDeviate +from jax_galsim.sensor import Sensor from jax_galsim.utilities import parse_pos_args @@ -15,6 +26,18 @@ class GSObject: def __init__(self, *, gsparams=None, **params): self._params = params # Dictionary containing all traced parameters self._gsparams = GSParams.check(gsparams) # Non-traced static parameters + self._workspace = {} # used by lazy_property + + def __getstate__(self): + d = self.__dict__.copy() + d["had_workspace"] = "_workspace" in d + d.pop("_workspace", None) + return d + + def __setstate__(self, d): + if d.pop("had_workspace", False): + d["_workspace"] = {} + self.__dict__ = d @property def flux(self): @@ -615,7 +638,12 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): @_wraps( _galsim.GSObject.drawImage, - lax_description="The JAX-GalSim version of `drawImage` does not do extensive (any?) checking of the input settings.", + lax_description="""\ +The JAX-GalSim version of `drawImage` does not + + - do extensive (any?) checking of the input settings. + - does not support the deprecated `surface_ops` argument +""", ) def drawImage( self, @@ -645,7 +673,6 @@ def drawImage( save_photons=False, bandpass=None, setup_only=False, - surface_ops=None, ): from jax_galsim.box import Pixel from jax_galsim.convolve import Convolve @@ -655,6 +682,11 @@ def drawImage( if image is not None and not isinstance(image, Image): raise TypeError("image is not an Image instance", image) + if method == "phot" and save_photons and maxN is not None: + raise GalSimIncompatibleValuesError( + "Setting maxN is incompatible with save_photons=True" + ) + # Figure out what wcs we are going to use. wcs = self._determine_wcs(scale, wcs, image) @@ -715,26 +747,44 @@ def drawImage( image.added_flux = 0.0 return image - if method == "phot": - raise NotImplementedError("Phot shooting not yet implemented in drawImage") - if sensor is not None: - raise NotImplementedError("Sensor not yet implemented in drawImage") - # Making a view of the image lets us change the center without messing up the original. original_center = image.center wcs = image.wcs image.setCenter(0, 0) image.wcs = PixelScale(1.0) - if prof.is_analytic_x: - added_photons = prof.drawReal(image, add_to_image) + if method == "phot": + added_photons, photons = prof.drawPhot( + image, + gain, + add_to_image, + n_photons, + rng, + max_extra_noise, + poisson_flux, + sensor, + photon_ops, + maxN, + original_center, + local_wcs, + ) else: - added_photons = prof.drawFFT(image, add_to_image) + if sensor is not None: + raise NotImplementedError( + "Sensor not yet implemented in drawImage for method != 'phot'." + ) + + if prof.is_analytic_x: + added_photons = prof.drawReal(image, add_to_image) + else: + added_photons = prof.drawFFT(image, add_to_image) image.added_flux = added_photons / flux_scale # Restore the original center and wcs image.shift(original_center) image.wcs = wcs + if save_photons: + image.photons = photons # Update image_in to satisfy GalSim API image_in._array = image._array @@ -742,6 +792,9 @@ def drawImage( image_in._bounds = image._bounds image_in.wcs = image.wcs image_in._dtype = image._dtype + if save_photons: + image_in.photons = photons + return image @_wraps(_galsim.GSObject.drawReal) @@ -1032,6 +1085,304 @@ def _drawKImage( "%s does not implement drawKImage" % self.__class__.__name__ ) + @_wraps(_galsim.GSObject._calculate_nphotons) + def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): + # For profiles that are positive definite, then N = flux. Easy. + # + # However, some profiles shoot some of their photons with negative flux. This means that + # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the + # fraction of shot photons that have negative flux. + # + # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 + # N^2 = Var(S) = (N+ + N-) = Ntot + # + # So flux = (S/N)^2 = Ntot (1-2eta)^2 + # Ntot = flux / (1-2eta)^2 + # + # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). + # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right + # total flux. + # + # That's all the easy case. The trickier case is when we are sky-background dominated. + # Then we can usually get away with fewer shot photons than the above. In particular, + # if the noise from the photon shooting is much less than the sky noise, then we can + # use fewer shot photons and essentially have each photon have a flux > 1. This is ok + # as long as the additional noise due to this approximation is "much less than" the + # noise we'll be adding to the image for the sky noise. + # + # Let's still have Ntot photons, but now each with a flux of g. And let's look at the + # noise we get in the brightest pixel that has a nominal total flux of Imax. + # + # The number of photons hitting this pixel will be Imax/flux * Ntot. + # The variance of this number is the same thing (Poisson counting). + # So the noise in that pixel is: + # + # N^2 = Imax/flux * Ntot * g^2 + # + # And the signal in that pixel will be: + # + # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so + # g = flux / Ntot(1-2eta) + # N^2 = Imax/Ntot * flux / (1-2eta)^2 + # + # As expected, we see that lowering Ntot will increase the noise in that (and every + # other) pixel. + # The input max_extra_noise parameter is the maximum value of spurious noise we want + # to allow. + # + # So setting N^2 = Imax + nu, we get + # + # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) + # g = (1 - 2eta) * (1 + nu/Imax) + # + # Returns the total flux placed inside the image bounds by photon shooting. + # + + flux = self.flux + if flux == 0.0: + return 0, 1.0 + + # The _flux_per_photon property is (1-2eta) + # This factor will already be accounted for by the shoot function, so don't include + # that as part of our scaling here. There may be other adjustments though, so g=1 here. + eta_factor = self._flux_per_photon + mod_flux = flux / (eta_factor * eta_factor) + g = 1.0 + + # If requested, let the target flux value vary as a Poisson deviate + if poisson_flux: + # If we have both positive and negative photons, then the mix of these + # already gives us some variation in the flux value from the variance + # of how many are positive and how many are negative. + # The number of negative photons varies as a binomial distribution. + # = eta * Ntot * g + # = (1-eta) * Ntot * g + # = (1-2eta) * Ntot * g = flux + # Var(F-) = eta * (1-eta) * Ntot * g^2 + # F+ = Ntot * g - F- is not an independent variable, so + # Var(F+ - F-) = Var(Ntot*g - 2*F-) + # = 4 * Var(F-) + # = 4 * eta * (1-eta) * Ntot * g^2 + # = 4 * eta * (1-eta) * flux + # We want the variance to be equal to flux, so we need an extra: + # delta Var = (1 - 4*eta + 4*eta^2) * flux + # = (1-2eta)^2 * flux + absflux = abs(flux) + mean = eta_factor * eta_factor * absflux + pd = PoissonDeviate(rng, mean) + pd_val = pd() - mean + absflux + ratio = pd_val / absflux + g *= ratio + mod_flux *= ratio + + if n_photons == 0.0: + n_photons = abs(mod_flux) + if max_extra_noise > 0.0: + gfactor = 1.0 + max_extra_noise / abs(self.max_sb) + n_photons /= gfactor + g *= gfactor + + # Make n_photons an integer. + iN = int(n_photons + 0.5) + + return iN, g + + @_wraps( + _galsim.GSObject.makePhot, + lax_description="The JAX-GalSim version of `makePhot` does not support the deprecated surface_ops argument.", + ) + def makePhot( + self, + n_photons=0, + rng=None, + max_extra_noise=0.0, + poisson_flux=None, + photon_ops=(), + local_wcs=None, + surface_ops=None, + ): + if surface_ops is not None: + from .deprecated import depr + + depr("surface_ops", 2.3, "photon_ops") + photon_ops = surface_ops + + # Make sure the type of n_photons is correct and has a valid value: + if not n_photons >= 0.0: + raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) + + if poisson_flux is None: + # If n_photons is given, poisson_flux = False + poisson_flux = n_photons == 0.0 + + # Check that either n_photons is set to something or flux is set to something + if n_photons == 0.0 and self.flux == 1.0: + galsim_warn( + "Warning: drawImage for object with flux == 1, area == 1, and " + "exptime == 1, but n_photons == 0. This will only shoot a single photon." + ) + + Ntot, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng + ) + + try: + photons = self.shoot(Ntot, rng) + except (GalSimError, NotImplementedError) as e: + raise GalSimNotImplementedError( + "Unable to draw this GSObject with photon shooting. Perhaps it " + "is a Deconvolve or is a compound including one or more " + "Deconvolve objects.\nOriginal error: %r" % (e) + ) + + if g != 1.0: + photons.scaleFlux(g) + + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + + return photons + + @_wraps( + _galsim.GSObject.drawPhot, + lax_description="The JAX-GalSim version of `drawPhot` does not support the deprecated surface_ops argument.", + ) + def drawPhot( + self, + image, + gain=1.0, + add_to_image=False, + n_photons=0, + rng=None, + max_extra_noise=0.0, + poisson_flux=None, + sensor=None, + photon_ops=(), + maxN=None, + orig_center=PositionI(0, 0), + local_wcs=None, + ): + # Make sure the type of n_photons is correct and has a valid value: + if not n_photons >= 0.0: + raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) + + if poisson_flux is None: + # If n_photons is given, poisson_flux = False + poisson_flux = n_photons == 0.0 + + # Check that either n_photons is set to something or flux is set to something + if n_photons == 0.0 and self.flux == 1.0: + galsim_warn( + "Warning: drawImage for object with flux == 1, area == 1, and " + "exptime == 1, but n_photons == 0. This will only shoot a single photon." + ) + + # Make sure the image is set up to have unit pixel scale and centered at 0,0. + if image.wcs is None or not image.wcs._isPixelScale: + raise GalSimValueError( + "drawPhot requires an image with a PixelScale wcs", image + ) + + if sensor is None: + sensor = Sensor() + elif not isinstance(sensor, Sensor): + raise TypeError("The sensor provided is not a Sensor instance") + + Ntot, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng + ) + + if gain != 1.0: + g /= gain + + # total flux falling inside image bounds, this will be returned on exit. + added_flux = 0.0 + + if maxN is None: + maxN = Ntot + + if not add_to_image: + image.setZero() + + # Nleft is the number of photons remaining to shoot. + Nleft = Ntot + photons = None # Just in case Nleft is already 0. + resume = False + while Nleft > 0: + # Shoot at most maxN at a time + thisN = min(maxN, Nleft) + + try: + photons = self.shoot(thisN, rng) + except (GalSimError, NotImplementedError) as e: + raise GalSimNotImplementedError( + "Unable to draw this GSObject with photon shooting. Perhaps it " + "is a Deconvolve or is a compound including one or more " + "Deconvolve objects.\nOriginal error: %r" % (e) + ) + + if g != 1.0 or thisN != Ntot: + photons.scaleFlux(g * thisN / Ntot) + + if image.scale != 1.0: + photons.scaleXY( + 1.0 / image.scale + ) # Convert x,y to image coords if necessary + + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + + if image.dtype in (np.float32, np.float64): + added_flux += sensor.accumulate( + photons, image, orig_center, resume=resume + ) + resume = ( + True # Resume from this point if there are any further iterations. + ) + else: + # Need a temporary + from jax_galsim.image import ImageD + + im1 = ImageD(bounds=image.bounds) + added_flux += sensor.accumulate(photons, im1, orig_center) + image += im1 + + Nleft -= thisN + + return added_flux, photons + + @_wraps(_galsim.GSObject.shoot) + def shoot(self, n_photons, rng=None): + photons = pa.PhotonArray(n_photons) + if n_photons == 0: + # It's ok to shoot 0, but downstream can have problems with it, so just stop now. + return photons + if rng is None: + rng = BaseDeviate() + + self._shoot(photons, rng) + return photons + + @_wraps(_galsim.GSObject._shoot) + def _shoot(self, photons, rng): + raise NotImplementedError( + "%s does not implement shoot" % self.__class__.__name__ + ) + + @_wraps(_galsim.GSObject.applyTo) + def applyTo(self, photon_array, local_wcs=None, rng=None): + p1 = pa.PhotonArray(len(photon_array)) + if photon_array.hasAllocatedWavelengths(): + p1._wave = photon_array._wave + if photon_array.hasAllocatedPupil(): + p1._pupil_u = photon_array._pupil_u + p1._pupil_v = photon_array._pupil_v + if photon_array.hasAllocatedTimes(): + p1._time = photon_array._time + obj = local_wcs.toImage(self) if local_wcs is not None else self + obj._shoot(p1, rng) + photon_array.convolve(p1, rng) + def tree_flatten(self): """This function flattens the GSObject into a list of children nodes that will be traced by JAX and auxiliary static data.""" diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index c37fe2c3..750f3d30 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -6,6 +6,7 @@ import galsim as _galsim import jax import jax.numpy as jnp +import jax.random as jrng from galsim.errors import ( GalSimIncompatibleValuesError, GalSimRangeError, @@ -24,6 +25,7 @@ from jax_galsim.image import Image from jax_galsim.interpolant import Quintic from jax_galsim.position import PositionD +from jax_galsim.random import UniformDeviate from jax_galsim.transform import Transformation from jax_galsim.utilities import convert_interpolant, lazy_property from jax_galsim.wcs import BaseWCS, PixelScale @@ -668,33 +670,6 @@ def _pad_image(self): def _kim(self): return self._xim.calculate_fft() - @lazy_property - def _pos_neg_fluxes(self): - # record pos and neg fluxes now too - pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) - nflux = jnp.abs( - jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) - ) - pint = self._x_interpolant.positive_flux - nint = self._x_interpolant.negative_flux - pint2d = pint * pint + nint * nint - nint2d = 2 * pint * nint - return [ - pint2d * pflux + nint2d * nflux, - pint2d * nflux + nint2d * pflux, - ] - - @property - def _positive_flux(self): - return self._pos_neg_fluxes[0] - - @property - def _negative_flux(self): - return self._pos_neg_fluxes[1] - - def _flux_per_photon(self): - return self._calculate_flux_per_photon() - @lazy_property def _maxk(self): if self._jax_aux_data["_force_maxk"]: @@ -797,9 +772,6 @@ def _kValue(self, kpos): self._k_interpolant, )[0] - def _shoot(self, photons, rng): - raise NotImplementedError("Photon shooting not implemented.") - def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): jacobian = jnp.eye(2) if jac is None else jac @@ -858,6 +830,61 @@ def _drawKImage(self, image, jac=None): # Return an image return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + @lazy_property + def _pos_neg_fluxes(self): + # record pos and neg fluxes now too + pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0)) + nflux = jnp.abs( + jnp.sum(jnp.where(self._pad_image.array < 0, self._pad_image.array, 0)) + ) + pint = self._x_interpolant.positive_flux + nint = self._x_interpolant.negative_flux + pint2d = pint * pint + nint * nint + nint2d = 2 * pint * nint + return [ + pint2d * pflux + nint2d * nflux, + pint2d * nflux + nint2d * pflux, + ] + + @property + def _positive_flux(self): + return self._pos_neg_fluxes[0] + + @property + def _negative_flux(self): + return self._pos_neg_fluxes[1] + + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + + def _shoot(self, photons, rng): + # we first draw the index location from the image + img = self._pad_image + subkey = rng._state.split_one() + inds = jrng.choice( + subkey, + img.array.size, + shape=(photons.size(),), + replace=True, + # we use abs here since some of the pixels could be negative + # and for a noise image this procedure results in a fair + # sampling of the noise + p=jnp.abs(img.array.ravel()) / jnp.sum(jnp.abs(img.array)), + ).astype(int) + yinds, xinds = jnp.unravel_index(inds, img.array.shape) + + xedges = jnp.arange(img.bounds.xmin, img.bounds.xmax + 2) - 0.5 + yedges = jnp.arange(img.bounds.ymin, img.bounds.ymax + 2) - 0.5 + + # now we draw the position within the pixel + ud = UniformDeviate(rng) + photons.x = ud.generate(photons.x) + xedges[xinds] + photons.y = ud.generate(photons.y) + yedges[yinds] + photons.flux = jnp.sign(img.array.ravel())[inds] * self._flux_per_photon() + + # now we convolve with the x interpolant + raise NotImplementedError("InterpolatedImages do not support photon shooting!") + @_wraps(_galsim._InterpolatedImage) def _InterpolatedImage( diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index a13dcd60..0b71cc4d 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -12,6 +12,7 @@ from jax_galsim.core.utils import bisect_for_root, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.position import PositionD +from jax_galsim.random import UniformDeviate @jax.jit @@ -380,3 +381,23 @@ def withFlux(self, flux): flux=flux, gsparams=self.gsparams, ) + + @_wraps(_galsim.Moffat.shoot) + def _shoot(self, photons, rng): + # from the galsim C++ in SBMoffat.cpp + ud = UniformDeviate(rng) + + # First get a point uniformly distributed on unit circle + theta = ud.generate(photons.x) * 2.0 * jnp.pi + rsq = ud.generate( + photons.x + ) # cumulative dist function P( self.size(): + raise GalSimValueError( + "The given rhs does not fit into this array starting at %d" % istart, + rhs, + ) + self._x = self._x.at[istart : istart + rhs.size()].set(rhs.x) + self._y = self._y.at[istart : istart + rhs.size()].set(rhs.y) + self._flux = self._flux.at[istart : istart + rhs.size()].set(rhs.flux) + if rhs.hasAllocatedAngles(): + self.allocateAngles() + self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) + self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) + if rhs.hasAllocatedWavelengths(): + self.allocateWavelengths() + self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) + if rhs.hasAllocatedPupil(): + self.allocatePupil() + self._pupil_u = self._pupil_u.at[istart : istart + rhs.size()].set( + rhs.pupil_u + ) + self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set( + rhs.pupil_v + ) + if rhs.hasAllocatedTimes(): + self.allocateTimes() + self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) + + def convolve(self, rhs, rng=None): + """Convolve this `PhotonArray` with another. + + ..note:: + + If both self and rhs have wavelengths, angles, pupil coordinates or times assigned, + then the values from the first array (i.e. self) take precedence. + """ + if rhs.size() != self.size(): + raise GalSimIncompatibleValuesError( + "PhotonArray.convolve with unequal size arrays", self_pa=self, rhs=rhs + ) + + if rhs.isCorrelated() and self.isCorrelated(): + rng = BaseDeviate(rng) + subkey = rng._state.split_one() + sinds = jrng.choice( + subkey, + self.size(), + shape=(self.size(),), + replace=False, + ) + else: + sinds = jnp.arange(self.size()) + + if rhs.hasAllocatedAngles() and not self.hasAllocatedAngles(): + self.dxdz = rhs.dxdz[sinds] + self.dydz = rhs.dydz[sinds] + + if rhs.hasAllocatedWavelengths() and not self.hasAllocatedWavelengths(): + self.wavelength = rhs.wavelength + + if rhs.hasAllocatedPupil() and not self.hasAllocatedPupil(): + self.pupil_u = rhs.pupil_u[sinds] + self.pupil_v = rhs.pupil_v[sinds] + + if rhs.hasAllocatedTimes() and not self.hasAllocatedTimes(): + self.time = rhs.time[sinds] + + if rhs.isCorrelated(): + self.setCorrelated() + + self._x = self._x + rhs._x[sinds] + self._y = self._y + rhs._y[sinds] + self._flux = self._flux * rhs._flux[sinds] * self.size() + + def __repr__(self): + s = "galsim.PhotonArray(%d, x=array(%r), y=array(%r), flux=array(%r)" % ( + int(cast_to_python_float(self.size())), + self.x.tolist(), + self.y.tolist(), + self.flux.tolist(), + ) + if self.hasAllocatedAngles(): + s += ", dxdz=array(%r), dydz=array(%r)" % ( + self.dxdz.tolist(), + self.dydz.tolist(), + ) + if self.hasAllocatedWavelengths(): + s += ", wavelength=array(%r)" % (self.wavelength.tolist()) + if self.hasAllocatedPupil(): + s += ", pupil_u=array(%r), pupil_v=array(%r)" % ( + self.pupil_u.tolist(), + self.pupil_v.tolist(), + ) + if self.hasAllocatedTimes(): + s += ", time=array(%r)" % (self.time.tolist()) + s += ")" + return s + + def __str__(self): + return "galsim.PhotonArray(%d)" % int(cast_to_python_float(self.size())) + + __hash__ = None + + def __eq__(self, other): + return self is other or ( + isinstance(other, PhotonArray) + and jnp.array_equal(self.x, other.x) + and jnp.array_equal(self.y, other.y) + and jnp.array_equal(self.flux, other.flux) + and self.hasAllocatedAngles() == other.hasAllocatedAngles() + and self.hasAllocatedWavelengths() == other.hasAllocatedWavelengths() + and self.hasAllocatedPupil() == other.hasAllocatedPupil() + and self.hasAllocatedTimes() == other.hasAllocatedTimes() + and ( + jnp.array_equal(self.dxdz, other.dxdz) + if self.hasAllocatedAngles() + else True + ) + and ( + jnp.array_equal(self.dydz, other.dydz) + if self.hasAllocatedAngles() + else True + ) + and ( + jnp.array_equal(self.wavelength, other.wavelength) + if self.hasAllocatedWavelengths() + else True + ) + and ( + jnp.array_equal(self.pupil_u, other.pupil_u) + if self.hasAllocatedPupil() + else True + ) + and ( + jnp.array_equal(self.pupil_v, other.pupil_v) + if self.hasAllocatedPupil() + else True + ) + and ( + jnp.array_equal(self.time, other.time) + if self.hasAllocatedTimes() + else True + ) + ) + + def __ne__(self, other): + return not self == other + + def addTo(self, image): + """Add flux of photons to an image by binning into pixels. + + Photons in this `PhotonArray` are binned into the pixels of the input + `Image` and their flux summed into the pixels. The `Image` is assumed to represent + surface brightness, so photons' fluxes are divided by image pixel area. + Photons past the edges of the image are discarded. + + Parameters: + image: The `Image` to which the photons' flux will be added. + + Returns: + the total flux of photons the landed inside the image bounds. + """ + if not image.bounds.isDefined(): + raise GalSimUndefinedBoundsError( + "Attempting to PhotonArray::addTo an Image with undefined Bounds" + ) + xbins = jnp.arange(image.bounds.xmin, image.bounds.xmax + 2) - 0.5 + ybins = jnp.arange(image.bounds.ymin, image.bounds.ymax + 2) - 0.5 + im = jnp.histogram2d( + self._x, self._y, bins=(xbins, ybins), weights=self._flux, density=False + )[0] + image._array += im + return im.sum() + + @classmethod + def makeFromImage(cls, image, max_flux=1.0, rng=None): + """Turn an existing `Image` into a `PhotonArray` that would accumulate into this image. + + The flux in each non-zero pixel will be turned into 1 or more photons with random positions + within the pixel bounds. The ``max_flux`` parameter (which defaults to 1) sets an upper + limit for the absolute value of the flux of any photon. Pixels with abs values > maxFlux + will spawn multiple photons. + + Parameters: + image: The image to turn into a `PhotonArray` + max_flux: The maximum flux value to use for any output photon [default: 1] + rng: A `BaseDeviate` to use for the random number generation [default: None] + + Returns: + a `PhotonArray` + """ + + if max_flux <= 0: + raise GalSimRangeError("max_flux must be positive", max_flux, 0.0) + + n_per = jnp.clip(jnp.ceil(jnp.abs(image.array) / max_flux), 1).astype(int) + flux_per = (image.array / n_per).ravel() + n_per = n_per.ravel() + flux_per = flux_per.ravel() + inds = jnp.arange(image.array.size) + inds = jnp.repeat(inds, n_per) + yinds, xinds = jnp.unravel_index(inds, image.array.shape) + + xedges = jnp.arange(image.bounds.xmin, image.bounds.xmax + 2) - 0.5 + yedges = jnp.arange(image.bounds.ymin, image.bounds.ymax + 2) - 0.5 + + # now we draw the position within the pixel + ud = UniformDeviate(rng) + photons = cls(n_per.sum()) + photons.x = ud.generate(photons.x) + xedges[xinds] + photons.y = ud.generate(photons.y) + yedges[yinds] + photons.flux = flux_per[inds] + + if image.scale is not None: + photons.scaleXY(image.scale) + + return photons + + def write(self, file_name): + """Write a `PhotonArray` to a FITS file. + + The output file will be a FITS binary table with a row for each photon in the `PhotonArray`. + Columns will include 'id' (sequential from 1 to nphotons), 'x', 'y', and 'flux'. + Additionally, the columns 'dxdz', 'dydz', and 'wavelength' will be included if they are + set for this `PhotonArray` object. + + The file can be read back in with the classmethod `PhotonArray.read`:: + + >>> photons.write('photons.fits') + >>> photons2 = galsim.PhotonArray.read('photons.fits') + + Parameters: + file_name: The file name of the output FITS file. + """ + import numpy as np + from jax_galsim import fits + + cols = [] + cols.append(pyfits.Column(name="id", format="J", array=range(self.size()))) + cols.append(pyfits.Column(name="x", format="D", array=np.array(self.x))) + cols.append(pyfits.Column(name="y", format="D", array=np.array(self.y))) + cols.append(pyfits.Column(name="flux", format="D", array=np.array(self.flux))) + + if self.hasAllocatedAngles(): + cols.append( + pyfits.Column(name="dxdz", format="D", array=np.array(self.dxdz)) + ) + cols.append( + pyfits.Column(name="dydz", format="D", array=np.array(self.dydz)) + ) + + if self.hasAllocatedWavelengths(): + cols.append( + pyfits.Column( + name="wavelength", format="D", array=np.array(self.wavelength) + ) + ) + + if self.hasAllocatedPupil(): + cols.append( + pyfits.Column(name="pupil_u", format="D", array=np.array(self.pupil_u)) + ) + cols.append( + pyfits.Column(name="pupil_v", format="D", array=np.array(self.pupil_v)) + ) + + if self.hasAllocatedTimes(): + cols.append( + pyfits.Column(name="time", format="D", array=np.array(self.time)) + ) + + cols = pyfits.ColDefs(cols) + table = pyfits.BinTableHDU.from_columns(cols) + fits.writeFile(file_name, table) + + @classmethod + def read(cls, file_name): + """Create a `PhotonArray`, reading the photon data from a FITS file. + + The file being read in is not arbitrary. It is expected to be a file that was written + out with the `PhotonArray.write` method.:: + + >>> photons.write('photons.fits') + >>> photons2 = galsim.PhotonArray.read('photons.fits') + + Parameters: + file_name: The file name of the input FITS file. + """ + with pyfits.open(file_name) as fits: + data = fits[1].data + N = len(data) + names = data.columns.names + + photons = cls( + N, + x=jnp.array(data["x"]), + y=jnp.array(data["y"]), + flux=jnp.array(data["flux"]), + ) + if "dxdz" in names: + photons.dxdz = jnp.array(data["dxdz"]) + photons.dydz = jnp.array(data["dydz"]) + if "wavelength" in names: + photons.wavelength = jnp.array(data["wavelength"]) + if "pupil_u" in names: + photons.pupil_u = jnp.array(data["pupil_u"]) + photons.pupil_v = jnp.array(data["pupil_v"]) + if "time" in names: + photons.time = jnp.array(data["time"]) + return photons diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 0ee9b2fe..4d84c80f 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -15,10 +15,43 @@ from jax_galsim.core.utils import ensure_hashable -LAX_FUNCTIONAL_RNG = ( - "The JAX version of the this class is purely function and thus cannot " - "share state with any other version of this class. Also no type checking is done on the inputs." -) +LAX_FUNCTIONAL_RNG = """\ +JAX-GalSim PRNGs have some support linking states, but it may not always and/or may cause issues. + + - Linked states across JIT boundaries or devices will not work. + - Within a single routine linking may work. + - You may encounter errors related to global side effects for some combinations of linked states + and jitted/vmapped routines. +""" + + +@register_pytree_node_class +class _DeviateState: + """This class holds the RNG state for a JAX-GalSim PRNG. + + **This class is not intended to be used directly.** + + Parameters + ---------- + key : jax.random.PRNGKey + The JAX PRNG key made via `jrandom.PRNGKey` or equivalent. + """ + + def __init__(self, key): + self.key = key + + def split_one(self): + self.key, subkey = jrandom.split(self.key) + return subkey + + def tree_flatten(self): + children = (self.key,) + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(children[0]) @_wraps( @@ -52,31 +85,38 @@ def seed(self, seed=0): @_wraps(_galsim.BaseDeviate._seed) def _seed(self, seed=0): _initial_seed = seed or secrets.randbelow(2**31) - self._key = jrandom.PRNGKey(_initial_seed) + self._state.key = jrandom.PRNGKey(_initial_seed) @_wraps( _galsim.BaseDeviate.reset, - lax_description=( - "The JAX version of this method does no type checking. Also, the JAX version of this " - "class cannot be linked to another JAX version of this class so ``reset`` is equivalent " - "to ``seed``. If another ``BaseDeviate`` is supplied, that deviate's current state is used." - ), + lax_description=("The JAX version of this method does no type checking."), ) def reset(self, seed=None): - if isinstance(seed, BaseDeviate): - self._reset(seed) + if isinstance(seed, _DeviateState): + self._state = seed + elif isinstance(seed, BaseDeviate): + self._state = seed._state elif isinstance(seed, jax.Array): - self._key = wrap_key_data(seed) + self._state = _DeviateState(wrap_key_data(seed)) elif isinstance(seed, str): - self._key = wrap_key_data(jnp.array(eval(seed), dtype=jnp.uint32)) + self._state = _DeviateState( + wrap_key_data(jnp.array(eval(seed), dtype=jnp.uint32)) + ) elif isinstance(seed, tuple): - self._key = wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) + self._state = _DeviateState( + wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) + ) else: - self._seed(seed=seed) + _initial_seed = seed or secrets.randbelow(2**31) + self._state = _DeviateState(jrandom.PRNGKey(_initial_seed)) - @_wraps(_galsim.BaseDeviate._reset) - def _reset(self, rng): - self._key = rng._key + @property + def _key(self): + return self._state.key + + @_key.setter + def _key(self, key): + self._state.key = key def serialize(self): return repr(ensure_hashable(jrandom.key_data(self._key))) @@ -109,15 +149,16 @@ def clearCache(self): ), ) def discard(self, n, suppress_warnings=False): - self._key = self.__class__._discard(self._key, n) + self._key, subkeys = self.__class__._discard(self._key, n) + return subkeys - @jax.jit + @partial(jax.jit, static_argnums=(1,)) def _discard(key, n): - def __discard(i, key): - key, subkey = jrandom.split(key) - return key + def _f(key, x): + key, sub_key = jrandom.split(key) + return key, sub_key - return jax.lax.fori_loop(0, n, __discard, key) + return jax.lax.scan(_f, key, None, length=n) @_wraps( _galsim.BaseDeviate.raw, @@ -158,7 +199,7 @@ def __call__(self): @_wraps(_galsim.BaseDeviate.duplicate) def duplicate(self): ret = self.__class__.__new__(self.__class__) - ret._key = self._key + ret._state = _DeviateState(self._state.key) ret._params = self._params.copy() return ret @@ -183,7 +224,7 @@ def tree_flatten(self): """This function flattens the PRNG into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (jrandom.key_data(self._key), self._params) + children = (self._state, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = {} return (children, aux_data) diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py new file mode 100644 index 00000000..1cec712e --- /dev/null +++ b/jax_galsim/sensor.py @@ -0,0 +1,40 @@ +import galsim as _galsim +from .position import PositionI +from .errors import GalSimUndefinedBoundsError + +from jax._src.numpy.util import _wraps + + +@_wraps(_galsim.Sensor) +class Sensor: + def __init__(self): + pass + + @_wraps(_galsim.Sensor.accumulate) + def accumulate(self, photons, image, orig_center=None, resume=False): + if not image.bounds.isDefined(): + raise GalSimUndefinedBoundsError( + "Calling accumulate on image with undefined bounds" + ) + return photons.addTo(image) + + @_wraps(_galsim.Sensor.calculate_pixel_areas) + def calculate_pixel_areas(self, image, orig_center=PositionI(0, 0), use_flux=True): + return 1.0 + + def updateRNG(self, rng): + pass + + def __repr__(self): + return "galsim.Sensor()" + + def __eq__(self, other): + return self is other or ( + isinstance(other, Sensor) and repr(self) == repr(other) + ) # Checks that neither is a subclass + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(repr(self)) diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index eaa9d9d7..1454c0bf 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -7,6 +7,7 @@ from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams from jax_galsim.position import PositionD +from jax_galsim.random import BinomialDeviate @_wraps( @@ -172,6 +173,59 @@ def _drawKImage(self, image, jac=None): image += obj._drawKImage(image, jac) return image + @property + def _positive_flux(self): + pflux_list = jnp.array([obj.positive_flux for obj in self.obj_list]) + return jnp.sum(pflux_list) + + @property + def _negative_flux(self): + nflux_list = jnp.array([obj.negative_flux for obj in self.obj_list]) + return jnp.sum(nflux_list) + + @property + def _flux_per_photon(self): + return self._calculate_flux_per_photon() + + def _shoot(self, photons, rng): + remainingAbsoluteFlux = self.positive_flux + self.negative_flux + fluxPerPhoton = remainingAbsoluteFlux / len(photons) + + remainingN = len(photons) + istart = ( + 0 # The location in the photons array where we assign the component arrays. + ) + + # Get photons from each summand, using BinomialDeviate to randomize + # the distribution of photons among summands + for i, obj in enumerate(self.obj_list): + thisAbsoluteFlux = obj.positive_flux + obj.negative_flux + + # How many photons to shoot from this summand? + thisN = remainingN # All of what's left, if this is the last summand... + if i < len(self.obj_list) - 1: + # otherwise, allocate a randomized fraction of the remaining photons to summand. + bd = BinomialDeviate( + rng, remainingN, thisAbsoluteFlux / remainingAbsoluteFlux + ) + thisN = int(bd()) + if thisN > 0: + thisPA = obj.shoot(thisN, rng) + # Now rescale the photon fluxes so that they are each nominally fluxPerPhoton + # whereas the shoot() routine would have made them each nominally + # thisAbsoluteFlux/thisN + thisPA.scaleFlux(fluxPerPhoton * thisN / thisAbsoluteFlux) + photons.assignAt(istart, thisPA) + istart += thisN + remainingN -= thisN + remainingAbsoluteFlux -= thisAbsoluteFlux + # assert remainingN == 0 + # assert np.isclose(remainingAbsoluteFlux, 0.0) + + # This process produces correlated photons, so mark the resulting array as such. + if len(self.obj_list) > 1: + photons.setCorrelated() + def tree_flatten(self): """This function flattens the GSObject into a list of children nodes that will be traced by JAX and auxiliary static data.""" diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index aa34eec5..4c5ee846 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -380,6 +380,13 @@ def _drawKImage(self, image, jac=None): image = image * self._flux_scaling return image + def _shoot(self, photons, rng): + self._original._shoot(photons, rng) + photons.x, photons.y = self._fwd(photons.x, photons.y) + photons.x += self._offset.x + photons.y += self._offset.y + photons.scaleFlux(self._flux_scaling) + def tree_flatten(self): """This function flattens the Transform into a list of children nodes that will be traced by JAX and auxiliary static data.""" diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 2213ad6e..9d68478c 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -25,7 +25,7 @@ enabled_tests: # correspond to features that are not implemented yet # in jax_galsim allowed_failures: - - "Phot shooting not yet implemented in drawImage" + - "module 'jax_galsim' has no attribute 'DeltaFunction'" - "Real-space convolutions are not implemented" - "Photon shooting convolutions are not implemented" - "module 'jax_galsim' has no attribute 'Airy'" @@ -50,7 +50,6 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'dol_to_lod'" - "module 'jax_galsim.utilities' has no attribute 'nCr'" - "'Image' object has no attribute 'bin'" - - "has no attribute 'shoot'" - "module 'jax_galsim' has no attribute 'integ'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" @@ -60,7 +59,6 @@ allowed_failures: - "ValueError not raised by greatCirclePoint" - "TypeError not raised by __mul__" - "ValueError not raised by CelestialCoord" - - "has no attribute 'drawPhot'" - "'Image' object has no attribute 'FindAdaptiveMom'" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" - " module 'jax_galsim' has no attribute 'fft'" @@ -70,3 +68,10 @@ allowed_failures: - "pad_image not implemented in jax_galsim." - "InterpolatedImages do not support noise padding in jax_galsim." - "module 'jax_galsim' has no attribute 'FittedSIPWCS'" + - "module 'jax_galsim' has no attribute 'Bandpass'" + - "module 'jax_galsim' has no attribute 'Refraction'" + - "module 'jax_galsim' has no attribute 'FRatioAngles'" + - "module 'jax_galsim' has no attribute 'PupilAnnulusSampler'" + - "module 'jax_galsim' has no attribute 'TimeSampler'" + - "object has no attribute 'noise'" + - "module 'jax_galsim' has no attribute 'SED'" diff --git a/tests/jax/galsim/test_draw_jax.py b/tests/jax/galsim/test_draw_jax.py index a9dae4c2..cb93d249 100644 --- a/tests/jax/galsim/test_draw_jax.py +++ b/tests/jax/galsim/test_draw_jax.py @@ -1127,7 +1127,7 @@ def test_shoot(): obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, maxN=100000) image2 += 100 - np.testing.assert_almost_equal(image2.array, image1.array, decimal=12) + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) # Also check that you get the same answer with a smaller maxN. image3 = galsim.ImageF(32,32, init_value=100) @@ -1138,7 +1138,7 @@ def test_shoot(): # Test that shooting with 0.0 flux makes a zero-photons image. image4 = (obj*0).drawImage(method='phot') - np.testing.assert_equal(image4.array, 0) + np.testing.assert_array_equal(image4.array, 0) # Warns if flux is 1 and n_photons not given. psf = galsim.Gaussian(sigma=3) diff --git a/tests/jax/galsim/test_noise_jax.py b/tests/jax/galsim/test_noise_jax.py index 8a1ce29f..c6eb3c97 100644 --- a/tests/jax/galsim/test_noise_jax.py +++ b/tests/jax/galsim/test_noise_jax.py @@ -61,12 +61,10 @@ def test_deviate_noise(): assert noise != noise4 assert noise.rng() == noise2.rng() assert noise == noise2 # Still equal because both chains incremented one place. - # jax does not link RNGs so these are not equal - assert noise != noise3 + assert noise == noise3 noise.rng() assert noise2 != noise3 # This is no longer equal, since only noise.rng is incremented. - # jax does not link RNGs so these are not equal - assert noise != noise3 + assert noise == noise3 assert_raises(TypeError, galsim.DeviateNoise, 53) assert_raises(NotImplementedError, galsim.BaseNoise().getVariance) @@ -247,12 +245,10 @@ def test_gaussian_noise(): assert gn != gn5 assert gn.rng.raw() == gn2.rng.raw() assert gn == gn2 - # jax does not link RNGs - assert gn != gn3 + assert gn == gn3 gn.rng.raw() assert gn != gn2 - # jax does not link RNGs - assert gn != gn3 + assert gn == gn3 @timer @@ -347,12 +343,10 @@ def test_variable_gaussian_noise(): assert vgn != vgn5 assert vgn.rng.raw() == vgn2.rng.raw() assert vgn == vgn2 - # jax does not link RNGs - assert vgn != vgn3 + assert vgn == vgn3 vgn.rng.raw() assert vgn != vgn2 - # jax does not link RNGs - assert vgn != vgn3 + assert vgn == vgn3 assert_raises(TypeError, vgn.applyTo, 23) assert_raises(ValueError, vgn.applyTo, galsim.ImageF(3, 3)) @@ -517,12 +511,10 @@ def test_poisson_noise(): assert pn != pn5 assert pn.rng.raw() == pn2.rng.raw() assert pn == pn2 - # jax does not link RNGs - assert pn != pn3 + assert pn == pn3 pn.rng.raw() assert pn != pn2 - # jax does not link RNGs - assert pn != pn3 + assert pn == pn3 @timer @@ -805,12 +797,10 @@ def test_ccdnoise(): assert ccdnoise != ccdnoise7 assert ccdnoise.rng.raw() == ccdnoise2.rng.raw() assert ccdnoise == ccdnoise2 - # jax does not link RNGs - assert ccdnoise != ccdnoise3 + assert ccdnoise == ccdnoise3 ccdnoise.rng.raw() assert ccdnoise != ccdnoise2 - # jax does not link RNGs - assert ccdnoise != ccdnoise3 + assert ccdnoise == ccdnoise3 @timer diff --git a/tests/jax/galsim/test_photon_array_jax.py b/tests/jax/galsim/test_photon_array_jax.py new file mode 100644 index 00000000..63fe3868 --- /dev/null +++ b/tests/jax/galsim/test_photon_array_jax.py @@ -0,0 +1,1873 @@ +# Copyright (c) 2012-2023 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# + +import unittest +import numpy as np +import os +import sys +import warnings + +# We don't require astroplan. So check if it's installed. +try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + import astroplan + no_astroplan = False +except ImportError: + no_astroplan = True + +import galsim +from galsim_test_helpers import * + +bppath = os.path.join(galsim.meta_data.share_dir, "bandpasses") +sedpath = os.path.join(galsim.meta_data.share_dir, "SEDs") + + +@timer +def test_photon_array(): + """Test the basic methods of PhotonArray class""" + nphotons = 1000 + + # First create from scratch + photon_array = galsim.PhotonArray(nphotons) + assert len(photon_array.x) == nphotons + assert len(photon_array.y) == nphotons + assert len(photon_array.flux) == nphotons + assert not photon_array.hasAllocatedWavelengths() + assert not photon_array.hasAllocatedAngles() + + # Initial values should all be 0 + np.testing.assert_array_equal(photon_array.x, 0.0) + np.testing.assert_array_equal(photon_array.y, 0.0) + np.testing.assert_array_equal(photon_array.flux, 0.0) + + # Check picklability + check_pickle(photon_array) + + # JAX does not support this way of assignement + # # Check assignment via numpy [:] + # photon_array.x[:] = 5 + # photon_array.y[:] = 17 + # photon_array.flux[:] = 23 + # np.testing.assert_array_equal(photon_array.x, 5.) + # np.testing.assert_array_equal(photon_array.y, 17.) + # np.testing.assert_array_equal(photon_array.flux, 23.) + + # Check assignment directly to the attributes + photon_array.x = 25 + photon_array.y = 37 + photon_array.flux = 53 + np.testing.assert_array_equal(photon_array.x, 25.0) + np.testing.assert_array_equal(photon_array.y, 37.0) + np.testing.assert_array_equal(photon_array.flux, 53.0) + + # Now create from shooting a profile + obj = galsim.Exponential(flux=1.7, scale_radius=2.3) + rng = galsim.UniformDeviate(1234) + photon_array = obj.shoot(nphotons, rng) + orig_x = photon_array.x.copy() + orig_y = photon_array.y.copy() + orig_flux = photon_array.flux.copy() + assert len(photon_array.x) == nphotons + assert len(photon_array.y) == nphotons + assert len(photon_array.flux) == nphotons + assert not photon_array.hasAllocatedWavelengths() + assert not photon_array.hasAllocatedAngles() + assert not photon_array.hasAllocatedPupil() + assert not photon_array.hasAllocatedTimes() + + # Check arithmetic ops + photon_array.x *= 5 + photon_array.y += 17 + photon_array.flux /= 23 + np.testing.assert_array_almost_equal(photon_array.x, orig_x * 5.0) + np.testing.assert_array_almost_equal(photon_array.y, orig_y + 17.0) + np.testing.assert_array_almost_equal(photon_array.flux, orig_flux / 23.0) + + # Check picklability again with non-zero values + check_pickle(photon_array) + + # Now assign to the optional arrays + photon_array.dxdz = 0.17 + assert photon_array.hasAllocatedAngles() + assert not photon_array.hasAllocatedWavelengths() + np.testing.assert_array_equal(photon_array.dxdz, 0.17) + np.testing.assert_array_equal(photon_array.dydz, 0.0) + + photon_array.dydz = 0.59 + np.testing.assert_array_equal(photon_array.dxdz, 0.17) + np.testing.assert_array_equal(photon_array.dydz, 0.59) + + # Check shooting negative flux + obj = galsim.Exponential(flux=-1.7, scale_radius=2.3) + rng = galsim.UniformDeviate(1234) + neg_photon_array = obj.shoot(nphotons, rng) + np.testing.assert_array_equal(neg_photon_array.x, orig_x) + np.testing.assert_array_equal(neg_photon_array.y, orig_y) + np.testing.assert_array_equal(neg_photon_array.flux, -orig_flux) + + # Start over to check that assigning to wavelength leaves dxdz, dydz alone. + photon_array = obj.shoot(nphotons, rng) + photon_array.wavelength = 500.0 + assert photon_array.hasAllocatedWavelengths() + assert not photon_array.hasAllocatedAngles() + assert not photon_array.hasAllocatedPupil() + assert not photon_array.hasAllocatedTimes() + np.testing.assert_array_equal(photon_array.wavelength, 500) + + photon_array.dxdz = 0.23 + photon_array.dydz = 0.88 + photon_array.wavelength = 912.0 + assert photon_array.hasAllocatedWavelengths() + assert photon_array.hasAllocatedAngles() + assert not photon_array.hasAllocatedPupil() + assert not photon_array.hasAllocatedTimes() + np.testing.assert_array_equal(photon_array.dxdz, 0.23) + np.testing.assert_array_equal(photon_array.dydz, 0.88) + np.testing.assert_array_equal(photon_array.wavelength, 912) + + # Add pupil coords + photon_array.pupil_u = 6.0 + assert photon_array.hasAllocatedWavelengths() + assert photon_array.hasAllocatedAngles() + assert photon_array.hasAllocatedPupil() + assert not photon_array.hasAllocatedTimes() + np.testing.assert_array_equal(photon_array.dxdz, 0.23) + np.testing.assert_array_equal(photon_array.dydz, 0.88) + np.testing.assert_array_equal(photon_array.wavelength, 912) + np.testing.assert_array_equal(photon_array.pupil_u, 6.0) + np.testing.assert_array_equal(photon_array.pupil_v, 0.0) + + # Add time stamps + photon_array.time = 0.0 + assert photon_array.hasAllocatedWavelengths() + assert photon_array.hasAllocatedAngles() + assert photon_array.hasAllocatedPupil() + assert photon_array.hasAllocatedTimes() + np.testing.assert_array_equal(photon_array.dxdz, 0.23) + np.testing.assert_array_equal(photon_array.dydz, 0.88) + np.testing.assert_array_equal(photon_array.wavelength, 912) + np.testing.assert_array_equal(photon_array.pupil_u, 6.0) + np.testing.assert_array_equal(photon_array.pupil_v, 0.0) + np.testing.assert_array_equal(photon_array.time, 0.0) + + # Check toggling is_corr + assert not photon_array.isCorrelated() + photon_array.setCorrelated() + assert photon_array.isCorrelated() + photon_array.setCorrelated(False) + assert not photon_array.isCorrelated() + photon_array.setCorrelated(True) + assert photon_array.isCorrelated() + + # Check rescaling the total flux + flux = photon_array.flux.sum() + np.testing.assert_almost_equal(photon_array.getTotalFlux(), flux) + photon_array.scaleFlux(17) + np.testing.assert_almost_equal(photon_array.getTotalFlux(), 17 * flux) + photon_array.setTotalFlux(199) + np.testing.assert_almost_equal(photon_array.getTotalFlux(), 199) + photon_array.scaleFlux(-1.7) + np.testing.assert_almost_equal(photon_array.getTotalFlux(), -1.7 * 199) + photon_array.setTotalFlux(-199) + np.testing.assert_almost_equal(photon_array.getTotalFlux(), -199) + + # Check rescaling the positions + x = photon_array.x.copy() + y = photon_array.y.copy() + photon_array.scaleXY(1.9) + np.testing.assert_array_almost_equal(photon_array.x, 1.9 * x) + np.testing.assert_array_almost_equal(photon_array.y, 1.9 * y) + + # Check ways to assign to photons + pa1 = galsim.PhotonArray(50) + pa1.x = photon_array.x[:50] + pa1.y = photon_array.y[:50] + pa1.flux = photon_array.flux[:50] + # for i in range(50): + # pa1.y[i] = photon_array.y[i] + # pa1.flux[0:50] = photon_array.flux[:50] + pa1.dxdz = photon_array.dxdz[:50] + pa1.dydz = photon_array.dydz[:50] + pa1.wavelength = photon_array.wavelength[:50] + pa1.pupil_u = photon_array.pupil_u[:50] + pa1.pupil_v = photon_array.pupil_v[:50] + pa1.time = photon_array.time[:50] + np.testing.assert_array_almost_equal(pa1.x, photon_array.x[:50]) + np.testing.assert_array_almost_equal(pa1.y, photon_array.y[:50]) + np.testing.assert_array_almost_equal(pa1.flux, photon_array.flux[:50]) + np.testing.assert_array_almost_equal(pa1.dxdz, photon_array.dxdz[:50]) + np.testing.assert_array_almost_equal(pa1.dydz, photon_array.dydz[:50]) + np.testing.assert_array_almost_equal(pa1.wavelength, photon_array.wavelength[:50]) + np.testing.assert_array_almost_equal(pa1.pupil_u, photon_array.pupil_u[:50]) + np.testing.assert_array_almost_equal(pa1.pupil_v, photon_array.pupil_v[:50]) + np.testing.assert_array_almost_equal(pa1.time, photon_array.time[:50]) + + # Check assignAt + pa2 = galsim.PhotonArray(100) + pa2.assignAt(0, pa1) + pa2.assignAt(50, pa1) + np.testing.assert_array_almost_equal(pa2.x[:50], pa1.x) + np.testing.assert_array_almost_equal(pa2.y[:50], pa1.y) + np.testing.assert_array_almost_equal(pa2.flux[:50], pa1.flux) + np.testing.assert_array_almost_equal(pa2.dxdz[:50], pa1.dxdz) + np.testing.assert_array_almost_equal(pa2.dydz[:50], pa1.dydz) + np.testing.assert_array_almost_equal(pa2.wavelength[:50], pa1.wavelength) + np.testing.assert_array_almost_equal(pa2.pupil_u[:50], pa1.pupil_u) + np.testing.assert_array_almost_equal(pa2.pupil_v[:50], pa1.pupil_v) + np.testing.assert_array_almost_equal(pa2.time[:50], pa1.time) + np.testing.assert_array_almost_equal(pa2.x[50:], pa1.x) + np.testing.assert_array_almost_equal(pa2.y[50:], pa1.y) + np.testing.assert_array_almost_equal(pa2.flux[50:], pa1.flux) + np.testing.assert_array_almost_equal(pa2.dxdz[50:], pa1.dxdz) + np.testing.assert_array_almost_equal(pa2.dydz[50:], pa1.dydz) + np.testing.assert_array_almost_equal(pa2.wavelength[50:], pa1.wavelength) + np.testing.assert_array_almost_equal(pa2.pupil_u[50:], pa1.pupil_u) + np.testing.assert_array_almost_equal(pa2.pupil_v[50:], pa1.pupil_v) + np.testing.assert_array_almost_equal(pa2.time[50:], pa1.time) + + # Error if it doesn't fit. + assert_raises(ValueError, pa2.assignAt, 90, pa1) + + # Test some trivial usage of makeFromImage + zero = galsim.Image(4, 4, init_value=0) + photons = galsim.PhotonArray.makeFromImage(zero) + print("photons = ", photons) + assert len(photons) == 16 + np.testing.assert_array_equal(photons.flux, 0.0) + + ones = galsim.Image(4, 4, init_value=1) + photons = galsim.PhotonArray.makeFromImage(ones) + print("photons = ", photons) + assert len(photons) == 16 + np.testing.assert_array_almost_equal(photons.flux, 1.0) + + tens = galsim.Image(4, 4, init_value=8) + photons = galsim.PhotonArray.makeFromImage(tens, max_flux=5.0) + print("photons = ", photons) + assert len(photons) == 32 + np.testing.assert_array_almost_equal(photons.flux, 4.0) + + assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=0.0) + assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=-2) + + # Check some other errors + undef = galsim.Image() + assert_raises(galsim.GalSimUndefinedBoundsError, pa2.addTo, undef) + + # Check picklability again with non-zero values for everything + check_pickle(photon_array) + + +@timer +def test_convolve(): + nphotons = 1000000 + + obj = galsim.Gaussian(flux=1.7, sigma=2.3) + rng = galsim.UniformDeviate(1234) + pa1 = obj.shoot(nphotons, rng) + rng2 = rng.duplicate() # Save this state. + pa2 = obj.shoot(nphotons, rng) + + # If not correlated then convolve is deterministic + conv_x = pa1.x + pa2.x + conv_y = pa1.y + pa2.y + conv_flux = pa1.flux * pa2.flux * nphotons + + np.testing.assert_allclose(np.sum(pa1.flux), 1.7) + np.testing.assert_allclose(np.sum(pa2.flux), 1.7) + np.testing.assert_allclose(np.sum(conv_flux), 1.7 * 1.7) + + np.testing.assert_allclose(np.sum(pa1.x**2) / nphotons, 2.3**2, rtol=0.01) + np.testing.assert_allclose(np.sum(pa2.x**2) / nphotons, 2.3**2, rtol=0.01) + np.testing.assert_allclose( + np.sum(conv_x**2) / nphotons, 2.0 * 2.3**2, rtol=0.01 + ) + + np.testing.assert_allclose(np.sum(pa1.y**2) / nphotons, 2.3**2, rtol=0.01) + np.testing.assert_allclose(np.sum(pa2.y**2) / nphotons, 2.3**2, rtol=0.01) + np.testing.assert_allclose( + np.sum(conv_y**2) / nphotons, 2.0 * 2.3**2, rtol=0.01 + ) + + pa3 = galsim.PhotonArray(nphotons) + pa3.assignAt(0, pa1) # copy from pa1 + pa3.convolve(pa2) + np.testing.assert_allclose(pa3.x, conv_x) + np.testing.assert_allclose(pa3.y, conv_y) + np.testing.assert_allclose(pa3.flux, conv_flux) + + # If one of them is correlated, it is still deterministic. + pa3.assignAt(0, pa1) + pa3.setCorrelated() + pa3.convolve(pa2) + np.testing.assert_allclose(pa3.x, conv_x) + np.testing.assert_allclose(pa3.y, conv_y) + np.testing.assert_allclose(pa3.flux, conv_flux) + + pa3.assignAt(0, pa1) + pa3.setCorrelated(False) + pa2.setCorrelated() + pa3.convolve(pa2) + np.testing.assert_allclose(pa3.x, conv_x) + np.testing.assert_allclose(pa3.y, conv_y) + np.testing.assert_allclose(pa3.flux, conv_flux) + + # But if both are correlated, then it's not this simple. + pa3.assignAt(0, pa1) + pa3.setCorrelated() + assert pa3.isCorrelated() + assert pa2.isCorrelated() + pa3.convolve(pa2) + with assert_raises(AssertionError): + np.testing.assert_allclose(pa3.x, conv_x) + with assert_raises(AssertionError): + np.testing.assert_allclose(pa3.y, conv_y) + np.testing.assert_allclose(np.sum(pa3.flux), 1.7 * 1.7) + np.testing.assert_allclose(np.sum(pa3.x**2) / nphotons, 2 * 2.3**2, rtol=0.01) + np.testing.assert_allclose(np.sum(pa3.y**2) / nphotons, 2 * 2.3**2, rtol=0.01) + + # Can also effect the convolution by treating the psf as a PhotonOp + pa3.assignAt(0, pa1) + pa3.setCorrelated() + obj.applyTo(pa3, rng=rng2) + np.testing.assert_allclose(pa3.x, conv_x) + np.testing.assert_allclose(pa3.y, conv_y) + np.testing.assert_allclose(pa3.flux, conv_flux) + + # Error to have different lengths + pa4 = galsim.PhotonArray(50, pa1.x[:50], pa1.y[:50], pa1.flux[:50]) + assert_raises(galsim.GalSimError, pa1.convolve, pa4) + + # Check propagation of dxdz, dydz, wavelength, pupil_u, pupil_v + for attr, checkFn in zip( + ["dxdz", "dydz", "wavelength", "pupil_u", "pupil_v", "time"], + [ + "hasAllocatedAngles", + "hasAllocatedAngles", + "hasAllocatedWavelengths", + "hasAllocatedPupil", + "hasAllocatedPupil", + "hasAllocatedTimes", + ], + ): + pa1 = obj.shoot(nphotons, rng) + pa2 = obj.shoot(nphotons, rng) + assert not getattr(pa1, checkFn)() + assert not getattr(pa1, checkFn)() + data = np.linspace(-0.1, 0.1, nphotons) + setattr(pa1, attr, data) + assert getattr(pa1, checkFn)() + assert not getattr(pa2, checkFn)() + pa1.convolve(pa2) + assert getattr(pa1, checkFn)() + assert not getattr(pa2, checkFn)() + np.testing.assert_array_equal(getattr(pa1, attr), data) + pa2.convolve(pa1) + assert getattr(pa1, checkFn)() + assert getattr(pa2, checkFn)() + np.testing.assert_array_equal(getattr(pa2, attr), data) + + # both have data now... + pa1.convolve(pa2) + np.testing.assert_array_equal(getattr(pa1, attr), data) + np.testing.assert_array_equal(getattr(pa2, attr), data) + + # If the second one has different data, the first takes precedence. + setattr(pa2, attr, data * 2) + pa1.convolve(pa2) + np.testing.assert_array_equal(getattr(pa1, attr), data) + np.testing.assert_array_equal(getattr(pa2, attr), 2 * data) + + +@timer +def test_wavelength_sampler(): + nphotons = 1000 + obj = galsim.Exponential(flux=1.7, scale_radius=2.3) + rng = galsim.UniformDeviate(1234) + + photon_array = obj.shoot(nphotons, rng) + + sed = galsim.SED(os.path.join(sedpath, "CWW_E_ext.sed"), "A", "flambda").thin() + bandpass = galsim.Bandpass(os.path.join(bppath, "LSST_r.dat"), "nm").thin() + + sampler = galsim.WavelengthSampler(sed, bandpass) + sampler.applyTo(photon_array, rng=rng) + + # Note: the underlying functionality of the sampleWavelengths function is tested + # in test_sed.py. So here we are really just testing that the wrapper class is + # properly writing to the photon_array.wavelengths array. + + assert photon_array.hasAllocatedWavelengths() + assert not photon_array.hasAllocatedAngles() + + check_pickle(sampler) + + print("mean wavelength = ", np.mean(photon_array.wavelength)) + print("min wavelength = ", np.min(photon_array.wavelength)) + print("max wavelength = ", np.max(photon_array.wavelength)) + + assert np.min(photon_array.wavelength) > bandpass.blue_limit + assert np.max(photon_array.wavelength) < bandpass.red_limit + + # This is a regression test based on the value at commit 134a119 + np.testing.assert_allclose( + np.mean(photon_array.wavelength), 622.755128, rtol=1.0e-4 + ) + + # If we use a flat SED (in photons/nm), then the mean sampled wavelength should very closely + # match the bandpass effective wavelength. + photon_array2 = galsim.PhotonArray(100000) + sed2 = galsim.SED("1", "nm", "fphotons") + sampler2 = galsim.WavelengthSampler(sed2, bandpass) + sampler2.applyTo(photon_array2, rng=rng) + np.testing.assert_allclose( + np.mean(photon_array2.wavelength), + bandpass.effective_wavelength, + rtol=0, + atol=0.2, # 2 Angstrom accuracy is pretty good + err_msg="Mean sampled wavelength not close to effective_wavelength", + ) + + # If the photon array already has wavelengths set, then it proceeds, but gives a warning. + with assert_warns(galsim.GalSimWarning): + sampler2.applyTo(photon_array2, rng=rng) + np.testing.assert_allclose( + np.mean(photon_array2.wavelength), + bandpass.effective_wavelength, + rtol=0, + atol=0.2, + ) + + # Test that using this as a surface op works properly. + + # First do the shooting and clipping manually. + im1 = galsim.Image(64, 64, scale=1) + im1.setCenter(0, 0) + photon_array.flux[photon_array.wavelength < 600] = 0.0 + photon_array.addTo(im1) + + # Make a dummy surface op that clips any photons with lambda < 600 + class Clip600: + def applyTo(self, photon_array, local_wcs=None, rng=None): + photon_array.flux[photon_array.wavelength < 600] = 0.0 + + # Use (a new) sampler and clip600 as photon_ops in drawImage + im2 = galsim.Image(64, 64, scale=1) + im2.setCenter(0, 0) + clip600 = Clip600() + rng2 = galsim.BaseDeviate(1234) + sampler2 = galsim.WavelengthSampler(sed, bandpass) + obj.drawImage( + im2, + method="phot", + n_photons=nphotons, + use_true_center=False, + photon_ops=[sampler2, clip600], + rng=rng2, + save_photons=True, + ) + print("sum = ", im1.array.sum(), im2.array.sum()) + np.testing.assert_array_equal(im1.array, im2.array) + + # Equivalent version just getting photons back + rng2.seed(1234) + photons = obj.makePhot(n_photons=nphotons, photon_ops=[sampler2, clip600], rng=rng2) + print("phot.x = ", photons.x) + print("im2.photons.x = ", im2.photons.x) + assert photons == im2.photons + + # Base class is invalid to try to use. + op = galsim.PhotonOp() + with assert_raises(NotImplementedError): + op.applyTo(photon_array) + + +@timer +def test_photon_angles(): + """Test the photon_array function""" + # Make a photon array + seed = 12345 + ud = galsim.UniformDeviate(seed) + gal = galsim.Sersic(n=4, half_light_radius=1) + photon_array = gal.shoot(100000, ud) + + # Add the directions (N.B. using the same seed as for generating the photon array + # above. The fact that it is the same does not matter here; the testing routine + # only needs to have a definite seed value so the consistency of the results with + # expectations can be evaluated precisely + fratio = 1.2 + obscuration = 0.2 + + # rng can be None, an existing BaseDeviate, or an integer + for rng in [None, ud, 12345]: + assigner = galsim.FRatioAngles(fratio, obscuration) + assigner.applyTo(photon_array, rng=rng) + + check_pickle(assigner) + + dxdz = photon_array.dxdz + dydz = photon_array.dydz + + phi = np.arctan2(dydz, dxdz) + tantheta = np.sqrt(np.square(dxdz) + np.square(dydz)) + sintheta = np.sin(np.arctan(tantheta)) + + # Check that the values are within the ranges expected + # (The test on phi really can't fail, because it is only testing the range of the + # arctan2 function.) + np.testing.assert_array_less( + -phi, np.pi, "Azimuth angles outside possible range" + ) + np.testing.assert_array_less( + phi, np.pi, "Azimuth angles outside possible range" + ) + + fov_angle = np.arctan(0.5 / fratio) + obscuration_angle = obscuration * fov_angle + np.testing.assert_array_less( + -sintheta, + -np.sin(obscuration_angle), + "Inclination angles outside possible range", + ) + np.testing.assert_array_less( + sintheta, np.sin(fov_angle), "Inclination angles outside possible range" + ) + + # Compare these slopes with the expected distributions (uniform in azimuth + # over all azimiths and uniform in sin(inclination) over the range of + # allowed inclinations + # Only test this for the last one, so we make sure we have a deterministic result. + # (The above tests should be reliable even for the default rng.) + phi_histo, phi_bins = np.histogram(phi, bins=100) + sintheta_histo, sintheta_bins = np.histogram(sintheta, bins=100) + phi_ref = float(np.sum(phi_histo)) / phi_histo.size + sintheta_ref = float(np.sum(sintheta_histo)) / sintheta_histo.size + + chisqr_phi = np.sum(np.square(phi_histo - phi_ref) / phi_ref) / phi_histo.size + chisqr_sintheta = ( + np.sum(np.square(sintheta_histo - sintheta_ref) / sintheta_ref) + / sintheta_histo.size + ) + + print("chisqr_phi = ", chisqr_phi) + print("chisqr_sintheta = ", chisqr_sintheta) + assert 0.9 < chisqr_phi < 1.1, "Distribution in azimuth is not nearly uniform" + assert ( + 0.9 < chisqr_sintheta < 1.1 + ), "Distribution in sin(inclination) is not nearly uniform" + + # Try some invalid inputs + assert_raises(ValueError, galsim.FRatioAngles, fratio=-0.3) + assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=-0.3) + assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=1.0) + assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=1.9) + + +@timer +def test_photon_io(): + """Test the ability to read and write photons to a file""" + nphotons = 1000 + + obj = galsim.Exponential(flux=1.7, scale_radius=2.3) + rng = galsim.UniformDeviate(1234) + image = obj.drawImage(method="phot", n_photons=nphotons, save_photons=True, rng=rng) + photons = image.photons + assert photons.size() == len(photons) == nphotons + + with assert_raises(galsim.GalSimIncompatibleValuesError): + obj.drawImage(method="phot", n_photons=nphotons, save_photons=True, maxN=1.0e5) + + file_name = "output/photons1.dat" + photons.write(file_name) + + photons1 = galsim.PhotonArray.read(file_name) + + assert photons1.size() == nphotons + assert not photons1.hasAllocatedWavelengths() + assert not photons1.hasAllocatedAngles() + assert not photons1.hasAllocatedPupil() + assert not photons1.hasAllocatedTimes() + + np.testing.assert_array_equal(photons1.x, photons.x) + np.testing.assert_array_equal(photons1.y, photons.y) + np.testing.assert_array_equal(photons1.flux, photons.flux) + + sed = galsim.SED(os.path.join(sedpath, "CWW_E_ext.sed"), "nm", "flambda").thin() + bandpass = galsim.Bandpass(os.path.join(bppath, "LSST_r.dat"), "nm").thin() + + wave_sampler = galsim.WavelengthSampler(sed, bandpass) + angle_sampler = galsim.FRatioAngles(1.3, 0.3) + + ops = [wave_sampler, angle_sampler] + for op in ops: + op.applyTo(photons, rng=rng) + + # Directly inject some pupil coordinates and time stamps + photons.pupil_u = np.linspace(0, 1, nphotons) + photons.pupil_v = np.linspace(1, 2, nphotons) + photons.time = np.linspace(0, 30, nphotons) + + file_name = "output/photons2.dat" + photons.write(file_name) + + photons2 = galsim.PhotonArray.read(file_name) + + assert photons2.size() == nphotons + assert photons2.hasAllocatedWavelengths() + assert photons2.hasAllocatedAngles() + assert photons2.hasAllocatedPupil() + assert photons2.hasAllocatedTimes() + + np.testing.assert_array_equal(photons2.x, photons.x) + np.testing.assert_array_equal(photons2.y, photons.y) + np.testing.assert_array_equal(photons2.flux, photons.flux) + np.testing.assert_array_equal(photons2.dxdz, photons.dxdz) + np.testing.assert_array_equal(photons2.dydz, photons.dydz) + np.testing.assert_array_equal(photons2.wavelength, photons.wavelength) + np.testing.assert_array_equal(photons.pupil_u, photons.pupil_u) + np.testing.assert_array_equal(photons.pupil_v, photons.pupil_v) + np.testing.assert_array_equal(photons.time, photons.time) + + +@timer +def test_dcr(): + """Test the dcr surface op""" + # This tests that implementing DCR with the surface op is equivalent to using + # ChromaticAtmosphere. + # We use fairly extreme choices for the parameters to make the comparison easier, so + # we can still get good discrimination of any errors with only 10^6 photons. + zenith_angle = 45 * galsim.degrees # Larger angle has larger DCR. + parallactic_angle = 129 * galsim.degrees # Something random, not near 0 or 180 + pixel_scale = ( + 0.03 # Small pixel scale means shifts are many pixels, rather than a fraction. + ) + alpha = -1.2 # The normal alpha is -0.2, so this is exaggerates the effect. + + bandpass = galsim.Bandpass("LSST_r.dat", "nm") + base_wavelength = bandpass.effective_wavelength + base_wavelength += 500 # This exaggerates the effects fairly substantially. + + sed = galsim.SED("CWW_E_ext.sed", wave_type="ang", flux_type="flambda") + + flux = 1.0e6 + fwhm = 0.3 + base_PSF = galsim.Kolmogorov(fwhm=fwhm) + + # Use ChromaticAtmosphere + # Note, somewhat gratuitous check that ImageI works with dtype=int in config below. + im1 = galsim.ImageI(50, 50, scale=pixel_scale) + star = galsim.DeltaFunction() * sed + star = star.withFlux(flux, bandpass=bandpass) + chrom_PSF = galsim.ChromaticAtmosphere( + base_PSF, + base_wavelength=base_wavelength, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + alpha=alpha, + ) + chrom = galsim.Convolve(star, chrom_PSF) + chrom.drawImage(bandpass, image=im1) + + # Repeat with config + config = { + "psf": { + "type": "ChromaticAtmosphere", + "base_profile": {"type": "Kolmogorov", "fwhm": fwhm}, + "base_wavelength": base_wavelength, + "zenith_angle": zenith_angle, + "parallactic_angle": parallactic_angle, + "alpha": alpha, + }, + "gal": {"type": "DeltaFunction", "flux": flux, "sed": sed}, + "image": { + "xsize": 50, + "ysize": 50, + "pixel_scale": pixel_scale, + "bandpass": bandpass, + "random_seed": 31415, + "dtype": int, + }, + } + im1c = galsim.config.BuildImage(config) + assert im1c == im1 + + # Use PhotonDCR + im2 = galsim.ImageI(50, 50, scale=pixel_scale) + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + alpha=alpha, + ) + achrom = base_PSF.withFlux(flux) + # Because we'll be comparing to config version, get the rng the way it will do it. + rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) + wave_sampler = galsim.WavelengthSampler(sed, bandpass) + photon_ops = [wave_sampler, dcr] + achrom.drawImage(image=im2, method="phot", rng=rng, photon_ops=photon_ops) + + check_pickle(dcr) + + # Repeat with config + config = { + "psf": {"type": "Kolmogorov", "fwhm": fwhm}, + "gal": {"type": "DeltaFunction", "flux": flux}, + "image": { + "xsize": 50, + "ysize": 50, + "pixel_scale": pixel_scale, + "bandpass": bandpass, + "random_seed": 31415, + "dtype": "np.int32", + }, + "stamp": { + "draw_method": "phot", + "photon_ops": [ + {"type": "WavelengthSampler", "sed": sed}, + { + "type": "PhotonDCR", + "base_wavelength": base_wavelength, + "zenith_angle": zenith_angle, + "parallactic_angle": parallactic_angle, + "alpha": alpha, + }, + ], + }, + } + im2c = galsim.config.BuildImage(config) + assert im2c == im2 + + # Should work with fft, but not quite match (because of inexact photon locations). + im3 = galsim.ImageF(50, 50, scale=pixel_scale) + achrom.drawImage(image=im3, method="fft", rng=rng, photon_ops=photon_ops) + printval(im3, im2, show=False) + np.testing.assert_allclose( + im3.array, + im2.array, + atol=0.1 * np.max(im2.array), + err_msg="PhotonDCR on fft image didn't match phot image", + ) + # Moments come out less than 1% different. + res2 = im2.FindAdaptiveMom() + res3 = im3.FindAdaptiveMom() + np.testing.assert_allclose(res3.moments_amp, res2.moments_amp, rtol=1.0e-2) + np.testing.assert_allclose(res3.moments_sigma, res2.moments_sigma, rtol=1.0e-2) + np.testing.assert_allclose( + res3.observed_shape.e1, res2.observed_shape.e1, atol=1.0e-2 + ) + np.testing.assert_allclose( + res3.observed_shape.e2, res2.observed_shape.e2, atol=1.0e-2 + ) + np.testing.assert_allclose( + res3.moments_centroid.x, res2.moments_centroid.x, rtol=1.0e-2 + ) + np.testing.assert_allclose( + res3.moments_centroid.y, res2.moments_centroid.y, rtol=1.0e-2 + ) + + # Repeat with maxN < flux + # Note: Because of the different way this generates the random positions, it's not identical + # to the above run without maxN. Both runs are equally valid realizations of photon + # positions corresponding to the FFT image. But not the same realization. + achrom.drawImage( + image=im3, method="auto", rng=rng, photon_ops=photon_ops, maxN=10**4 + ) + printval(im3, im2, show=False) + np.testing.assert_allclose( + im3.array, + im2.array, + atol=0.2 * np.max(im2.array), + err_msg="PhotonDCR on fft image with maxN didn't match phot image", + ) + res3 = im3.FindAdaptiveMom() + np.testing.assert_allclose(res3.moments_amp, res2.moments_amp, rtol=1.0e-2) + np.testing.assert_allclose(res3.moments_sigma, res2.moments_sigma, rtol=1.0e-2) + np.testing.assert_allclose( + res3.observed_shape.e1, res2.observed_shape.e1, atol=1.0e-2 + ) + np.testing.assert_allclose( + res3.observed_shape.e2, res2.observed_shape.e2, atol=1.0e-2 + ) + np.testing.assert_allclose( + res3.moments_centroid.x, res2.moments_centroid.x, rtol=1.0e-2 + ) + np.testing.assert_allclose( + res3.moments_centroid.y, res2.moments_centroid.y, rtol=1.0e-2 + ) + + # Compare ChromaticAtmosphere image with PhotonDCR image. + printval(im2, im1, show=False) + # tolerace for photon shooting is ~sqrt(flux) = 1.e3 + np.testing.assert_allclose( + im2.array, + im1.array, + atol=1.0e3, + err_msg="PhotonDCR didn't match ChromaticAtmosphere", + ) + + # Use ChromaticAtmosphere in photon_ops + im3 = galsim.ImageI(50, 50, scale=pixel_scale) + photon_ops = [chrom_PSF] + star.drawImage(bandpass, image=im3, method="phot", rng=rng, photon_ops=photon_ops) + printval(im3, im1, show=False) + np.testing.assert_allclose( + im3.array, + im1.array, + atol=1.0e3, + err_msg="ChromaticAtmosphere in photon_ops didn't match", + ) + + # Repeat with thinned bandpass and SED to check that thin still works well. + im3 = galsim.ImageI(50, 50, scale=pixel_scale) + thin = 0.1 # Even higher also works. But this is probably enough. + thin_bandpass = bandpass.thin(thin) + thin_sed = sed.thin(thin) + print("len bp = %d => %d" % (len(bandpass.wave_list), len(thin_bandpass.wave_list))) + print("len sed = %d => %d" % (len(sed.wave_list), len(thin_sed.wave_list))) + wave_sampler = galsim.WavelengthSampler(thin_sed, thin_bandpass) + photon_ops = [wave_sampler, dcr] + achrom.drawImage(image=im3, method="phot", rng=rng, photon_ops=photon_ops) + + printval(im3, im1, show=False) + np.testing.assert_allclose( + im3.array, + im1.array, + atol=1.0e3, + err_msg="thinning factor %f led to 1.e-4 level inaccuracy" % thin, + ) + + # Check scale_unit + im4 = galsim.ImageI(50, 50, scale=pixel_scale / 60) + wave_sampler = galsim.WavelengthSampler(sed, bandpass) + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + scale_unit="arcmin", + alpha=alpha, + ) + photon_ops = [wave_sampler, dcr] + rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) + achrom.dilate(1.0 / 60).drawImage( + image=im4, method="phot", rng=rng, photon_ops=photon_ops + ) + printval(im4, im1, show=False) + np.testing.assert_allclose( + im4.array, + im1.array, + atol=1.0e3, + err_msg="PhotonDCR with scale_unit=arcmin, didn't match", + ) + + galsim.config.RemoveCurrent(config) + del config["stamp"]["photon_ops"][1]["_get"] + config["stamp"]["photon_ops"][1]["scale_unit"] = "arcmin" + config["image"]["pixel_scale"] = pixel_scale / 60 + config["psf"]["fwhm"] = fwhm / 60 + im4c = galsim.config.BuildImage(config) + assert im4c == im4 + + # Check some other valid options + # alpha = 0 means don't do any size scaling. + # obj_coord, HA and latitude are another option for setting the angles + # pressure, temp, and water pressure are settable. + # Also use a non-trivial WCS. + wcs = galsim.FitsWCS("des_data/DECam_00154912_12_header.fits") + image = galsim.Image(50, 50, wcs=wcs) + bandpass = galsim.Bandpass("LSST_r.dat", wave_type="nm").thin(0.1) + base_wavelength = bandpass.effective_wavelength + lsst_lat = galsim.Angle.from_dms("-30:14:23.76") + lsst_long = galsim.Angle.from_dms("-70:44:34.67") + local_sidereal_time = ( + 3.14 * galsim.hours + ) # Not pi. This is the time for this observation. + + im5 = galsim.ImageI(50, 50, wcs=wcs) + obj_coord = wcs.toWorld(im5.true_center) + base_PSF = galsim.Kolmogorov(fwhm=0.9) + achrom = base_PSF.withFlux(flux) + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + obj_coord=obj_coord, + HA=local_sidereal_time - obj_coord.ra, + latitude=lsst_lat, + pressure=72, # default is 69.328 + temperature=290, # default is 293.15 + H2O_pressure=0.9, + ) # default is 1.067 + # alpha=0) # default is 0, so don't need to set it. + wave_sampler = galsim.WavelengthSampler(sed, bandpass) + photon_ops = [wave_sampler, dcr] + rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) + achrom.drawImage(image=im5, method="phot", rng=rng, photon_ops=photon_ops) + + check_pickle(dcr) + + galsim.config.RemoveCurrent(config) + config["psf"]["fwhm"] = 0.9 + config["image"] = { + "xsize": 50, + "ysize": 50, + "wcs": {"type": "Fits", "file_name": "des_data/DECam_00154912_12_header.fits"}, + "bandpass": bandpass, + "random_seed": 31415, + "dtype": "np.int32", + "world_pos": obj_coord, + } + config["stamp"]["photon_ops"][1] = { + "type": "PhotonDCR", + "base_wavelength": base_wavelength, + "HA": local_sidereal_time - obj_coord.ra, + "latitude": "-30:14:23.76 deg", + "pressure": 72, + "temperature": 290, + "H2O_pressure": 0.9, + } + im5c = galsim.config.BuildImage(config) + assert im5c == im5 + + # Also one using zenith_coord = (lst, lat) + config["stamp"]["photon_ops"][1] = { + "type": "PhotonDCR", + "base_wavelength": base_wavelength, + "zenith_coord": { + "type": "RADec", + "ra": local_sidereal_time, + "dec": lsst_lat, + }, + "pressure": 72, + "temperature": 290, + "H2O_pressure": 0.9, + } + im5d = galsim.config.BuildImage(config) + assert im5d == im5 + + im6 = galsim.ImageI(50, 50, wcs=wcs) + star = galsim.DeltaFunction() * sed + star = star.withFlux(flux, bandpass=bandpass) + chrom_PSF = galsim.ChromaticAtmosphere( + base_PSF, + base_wavelength=bandpass.effective_wavelength, + obj_coord=obj_coord, + HA=local_sidereal_time - obj_coord.ra, + latitude=lsst_lat, + pressure=72, + temperature=290, + H2O_pressure=0.9, + alpha=0, + ) + chrom = galsim.Convolve(star, chrom_PSF) + chrom.drawImage(bandpass, image=im6) + + printval(im5, im6, show=False) + np.testing.assert_allclose( + im5.array, im6.array, atol=1.0e3, err_msg="PhotonDCR with alpha=0 didn't match" + ) + + # Use ChromaticAtmosphere in photon_ops + im7 = galsim.ImageI(50, 50, wcs=wcs) + photon_ops = [chrom_PSF] + star.drawImage(bandpass, image=im7, method="phot", rng=rng, photon_ops=photon_ops) + printval(im7, im6, show=False) + np.testing.assert_allclose( + im7.array, + im6.array, + atol=1.0e3, + err_msg="ChromaticAtmosphere in photon_ops didn't match", + ) + + # ChromaticAtmosphere in photon_ops is almost trivially equal to base_psf and dcr in photon_ops. + im8 = galsim.ImageI(50, 50, wcs=wcs) + photon_ops = [base_PSF, dcr] + star.drawImage(bandpass, image=im8, method="phot", rng=rng, photon_ops=photon_ops) + printval(im8, im6, show=False) + np.testing.assert_allclose( + im8.array, + im6.array, + atol=1.0e3, + err_msg="base_psf + dcr in photon_ops didn't match", + ) + + # Including the wavelength sampler with chromatic drawing is not necessary, but is allowed. + # (Mostly in case someone wants to do something a little different w.r.t. wavelength sampling. + photon_ops = [wave_sampler, base_PSF, dcr] + star.drawImage(bandpass, image=im8, method="phot", rng=rng, photon_ops=photon_ops) + printval(im8, im6, show=False) + np.testing.assert_allclose( + im8.array, + im6.array, + atol=1.0e3, + err_msg="wave_sampler,base_psf,dcr in photon_ops didn't match", + ) + + # Also check invalid parameters + zenith_coord = galsim.CelestialCoord(13.54 * galsim.hours, lsst_lat) + assert_raises( + TypeError, + galsim.PhotonDCR, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + ) # base_wavelength is required + assert_raises( + TypeError, + galsim.PhotonDCR, + base_wavelength=500, + parallactic_angle=parallactic_angle, + ) # zenith_angle (somehow) is required + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + zenith_angle=34.4, + parallactic_angle=parallactic_angle, + ) # zenith_angle must be Angle + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + zenith_angle=zenith_angle, + parallactic_angle=34.5, + ) # parallactic_angle must be Angle + assert_raises( + TypeError, galsim.PhotonDCR, 500, obj_coord=obj_coord, latitude=lsst_lat + ) # Missing HA + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + obj_coord=obj_coord, + HA=local_sidereal_time - obj_coord.ra, + ) # Missing latitude + assert_raises( + TypeError, galsim.PhotonDCR, 500, obj_coord=obj_coord + ) # Need either zenith_coord, or (HA,lat) + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + obj_coord=obj_coord, + zenith_coord=zenith_coord, + HA=local_sidereal_time - obj_coord.ra, + ) # Can't have both HA and zenith_coord + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + obj_coord=obj_coord, + zenith_coord=zenith_coord, + latitude=lsst_lat, + ) # Can't have both lat and zenith_coord + assert_raises( + TypeError, + galsim.PhotonDCR, + 500, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + H20_pressure=1.0, + ) # invalid (misspelled) + assert_raises( + ValueError, + galsim.PhotonDCR, + 500, + zenith_angle=zenith_angle, + parallactic_angle=parallactic_angle, + scale_unit="inches", + ) # invalid scale_unit + photons = galsim.PhotonArray(2, flux=1) + assert_raises( + galsim.GalSimError, dcr.applyTo, photons + ) # Requires wavelengths to be set + assert_raises( + galsim.GalSimError, chrom_PSF.applyTo, photons + ) # Requires wavelengths to be set + photons = galsim.PhotonArray(2, flux=1, wavelength=500) + assert_raises(TypeError, dcr.applyTo, photons) # Requires local_wcs + + # Invalid to use dcr without some way of setting wavelengths. + assert_raises( + galsim.GalSimError, achrom.drawImage, im2, method="phot", photon_ops=[dcr] + ) + + +@unittest.skipIf(no_astroplan, "Unable to import astroplan") +@timer +def test_dcr_angles(): + """Check the DCR angle calculations by comparing to astroplan's calculations of the same.""" + # Note: test_chromatic.py and test_sed.py both also test aspects of the dcr module, so + # this particular test could belong in either of them too. But I (MJ) put it here, since + # I wrote it in conjunction with the tests of PhotonDCR to try to make sure that code + # is working properly. + import astropy.time + + # Set up an observation date, time, location, coordinate + # These are arbitrary, so ripped from astroplan's docs + # https://media.readthedocs.org/pdf/astroplan/latest/astroplan.pdf + subaru = astroplan.Observer.at_site("subaru") + time = astropy.time.Time("2015-06-16 12:00:00") + + # Stars that are visible from the north in summer time. + names = [ + "Vega", + "Polaris", + "Altair", + "Regulus", + "Spica", + "Algol", + "Fomalhaut", + "Markab", + "Deneb", + "Mizar", + "Dubhe", + "Sirius", + "Rigel", + "Alderamin", + ] + + for name in names: + try: + star = astroplan.FixedTarget.from_name(name) + except Exception as e: + print("Caught exception trying to make star from name ", name) + print(e) + print("Aborting. (Probably some kind of network problem.)") + return + print(star) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + ap_z = subaru.altaz(time, star).zen + ap_q = subaru.parallactic_angle(time, star) + local_sidereal_time = subaru.local_sidereal_time(time) + print("According to astroplan:") + print(" z,q = ", ap_z.deg, ap_q.deg) + + # Repeat with GalSim + coord = galsim.CelestialCoord( + star.ra.deg * galsim.degrees, star.dec.deg * galsim.degrees + ) + lat = subaru.location.lat.deg * galsim.degrees + ha = local_sidereal_time.deg * galsim.degrees - coord.ra + zenith = galsim.CelestialCoord(local_sidereal_time.deg * galsim.degrees, lat) + + # Two ways to calculate it + # 1. From coord, ha, lat + z, q, _ = galsim.dcr.parse_dcr_angles(obj_coord=coord, HA=ha, latitude=lat) + print("According to GalSim:") + print(" z,q = ", z / galsim.degrees, q / galsim.degrees) + + np.testing.assert_almost_equal( + z.rad, + ap_z.rad, + 2, + "zenith angle doesn't agree with astroplan's calculation.", + ) + + # Unfortunately, at least as of version 0.4, astroplan's parallactic angle calculation + # has a bug. It computes it as the arctan of some value, but doesn't use arctan2. + # So whenever |q| > 90 degrees, it gets it wrong by 180 degrees. Therefore, we only + # test that tan(q) is right. We'll check the quadrant below in test_dcr_moments(). + np.testing.assert_almost_equal( + np.tan(q), + np.tan(ap_q), + 2, + "parallactic angle doesn't agree with astroplan's calculation.", + ) + + # 2. From coord, zenith_coord + z, q, _ = galsim.dcr.parse_dcr_angles(obj_coord=coord, zenith_coord=zenith) + print(" z,q = ", z / galsim.degrees, q / galsim.degrees) + + np.testing.assert_almost_equal( + z.rad, + ap_z.rad, + 2, + "zenith angle doesn't agree with astroplan's calculation.", + ) + np.testing.assert_almost_equal( + np.tan(q), + np.tan(ap_q), + 2, + "parallactic angle doesn't agree with astroplan's calculation.", + ) + + +def test_dcr_moments(): + """Check that DCR gets the direction of the moment changes correct for some simple geometries. + i.e. Basically check the sign conventions used in the DCR code. + """ + # First, the basics. + # 1. DCR shifts blue photons closer to zenith, because the index of refraction larger. + # cf. http://lsstdesc.github.io/chroma/ + # 2. Galsim models profiles as seen from Earth with North up (and therefore East left). + # 3. Hour angle is negative when the object is in the east (soon after rising, say), + # zero when crossing the zenith meridian, and then positive to the west. + + # Use g-band, where the effect is more dramatic across the band than in redder bands. + # Also use a reference wavelength significantly to the red, so there should be a net + # overall shift towards zenith as well as a shear along the line to zenith. + bandpass = galsim.Bandpass("LSST_g.dat", "nm").thin(0.1) + base_wavelength = 600 # > red end of g band + + # Uniform across the band is fine for this. + sed = galsim.SED("1", wave_type="nm", flux_type="fphotons") + rng = galsim.BaseDeviate(31415) + wave_sampler = galsim.WavelengthSampler(sed, bandpass) + + star = galsim.Kolmogorov(fwhm=0.3, flux=1.0e6) # 10^6 photons should be enough. + im = galsim.ImageD( + 50, 50, scale=0.05 + ) # Small pixel scale, so shift is many pixels. + ra = 0 * galsim.degrees # Completely irrelevant here. + lat = -20 * galsim.degrees # Also doesn't really matter much. + + # 1. HA < 0, Dec < lat Spot should be shifted up and right. e2 > 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=-2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("1. HA < 0, Dec < lat: ", moments) + assert moments["My"] > 0 # up + assert moments["Mx"] > 0 # right + assert moments["Mxy"] > 0 # e2 > 0 + + # 2. HA = 0, Dec < lat Spot should be shifted up. e1 < 0, e2 ~= 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=0 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("2. HA = 0, Dec < lat: ", moments) + assert moments["My"] > 0 # up + assert abs(moments["Mx"]) < 0.05 # not left or right + assert moments["Mxx"] < moments["Myy"] # e1 < 0 + assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 + + # 3. HA > 0, Dec < lat Spot should be shifted up and left. e2 < 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("3. HA > 0, Dec < lat: ", moments) + assert moments["My"] > 0 # up + assert moments["Mx"] < 0 # left + assert moments["Mxy"] < 0 # e2 < 0 + + # 4. HA < 0, Dec = lat Spot should be shifted right. e1 > 0, e2 ~= 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=-2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("4. HA < 0, Dec = lat: ", moments) + assert ( + abs(moments["My"]) < 1.0 + ) # not up or down (Actually slightly down in the south.) + assert moments["Mx"] > 0 # right + assert moments["Mxx"] > moments["Myy"] # e1 > 0 + assert ( + abs(moments["Mxy"]) < 2.0 + ) # e2 ~= 0 (Actually slightly negative because of curvature.) + + # 5. HA = 0, Dec = lat Spot should not be shifted. e1 ~= 0, e2 ~= 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=0 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("5. HA = 0, Dec = lat: ", moments) + assert abs(moments["My"]) < 0.05 # not up or down + assert abs(moments["Mx"]) < 0.05 # not left or right + assert abs(moments["Mxx"] - moments["Myy"]) < 0.1 # e1 ~= 0 + assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 + + # 6. HA > 0, Dec = lat Spot should be shifted left. e1 > 0, e2 ~= 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("6. HA > 0, Dec = lat: ", moments) + assert ( + abs(moments["My"]) < 1.0 + ) # not up or down (Actually slightly down in the south.) + assert moments["Mx"] < 0 # left + assert moments["Mxx"] > moments["Myy"] # e1 > 0 + assert ( + abs(moments["Mxy"]) < 2.0 + ) # e2 ~= 0 (Actually slgihtly positive because of curvature.) + + # 7. HA < 0, Dec > lat Spot should be shifted down and right. e2 < 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=-2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("7. HA < 0, Dec > lat: ", moments) + assert moments["My"] < 0 # down + assert moments["Mx"] > 0 # right + assert moments["Mxy"] < 0 # e2 < 0 + + # 8. HA = 0, Dec > lat Spot should be shifted down. e1 < 0, e2 ~= 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=0 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("8. HA = 0, Dec > lat: ", moments) + assert moments["My"] < 0 # down + assert abs(moments["Mx"]) < 0.05 # not left or right + assert moments["Mxx"] < moments["Myy"] # e1 < 0 + assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 + + # 9. HA > 0, Dec > lat Spot should be shifted down and left. e2 > 0. + dcr = galsim.PhotonDCR( + base_wavelength=base_wavelength, + HA=2 * galsim.hours, + latitude=lat, + obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), + ) + photon_ops = [wave_sampler, dcr] + star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) + moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) + print("9. HA > 0, Dec > lat: ", moments) + assert moments["My"] < 0 # down + assert moments["Mx"] < 0 # left + assert moments["Mxy"] > 0 # e2 > 0 + + +@timer +def test_refract(): + ud = galsim.UniformDeviate(57721) + for _ in range(1000): + photon_array = galsim.PhotonArray(1000, flux=1) + photon_array.allocateAngles() + ud.generate(photon_array.dxdz) + ud.generate(photon_array.dydz) + photon_array.dxdz *= 1.2 # -0.6 to 0.6 + photon_array.dydz *= 1.2 + photon_array.dxdz -= 0.6 + photon_array.dydz -= 0.6 + # copy for testing later + dxdz0 = np.array(photon_array.dxdz) + dydz0 = np.array(photon_array.dydz) + index_ratio = ud() * 4 + 0.25 # 0.25 to 4.25 + refract = galsim.Refraction(index_ratio) + refract.applyTo(photon_array) + + check_pickle(refract) + + # Triangle is length 1 in the z direction and length sqrt(dxdz**2+dydz**2) + # in the 'r' direction. + rsqr0 = dxdz0**2 + dydz0**2 + sintheta0 = np.sqrt(rsqr0) / np.sqrt(1 + rsqr0) + # See if total internal reflection applies + w = sintheta0 < index_ratio + np.testing.assert_array_equal(photon_array.dxdz[~w], np.nan) + np.testing.assert_array_equal(photon_array.dydz[~w], np.nan) + np.testing.assert_array_equal(photon_array.flux, np.where(w, 1.0, 0.0)) + + sintheta0 = sintheta0[w] + dxdz0 = dxdz0[w] + dydz0 = dydz0[w] + dxdz1 = photon_array.dxdz[w] + dydz1 = photon_array.dydz[w] + rsqr1 = dxdz1**2 + dydz1**2 + sintheta1 = np.sqrt(rsqr1) / np.sqrt(1 + rsqr1) + # Check Snell's law + np.testing.assert_allclose(sintheta0, index_ratio * sintheta1) + + # Check azimuthal angle stays constant + phi0 = np.arctan2(dydz0, dxdz0) + phi1 = np.arctan2(dydz1, dxdz1) + np.testing.assert_allclose(phi0, phi1) + + # Check plane of refraction is perpendicular to (0,0,1) + np.testing.assert_allclose( + np.dot( + np.cross( + np.stack([dxdz0, dydz0, -np.ones(len(dxdz0))], axis=1), + np.stack([dxdz1, dydz1, -np.ones(len(dxdz1))], axis=1), + ), + [0, 0, 1], + ), + 0.0, + rtol=0, + atol=1e-13, + ) + + # Try a wavelength dependent index_ratio + index_ratio = lambda w: np.where(w < 1, 1.1, 2.2) + photon_array = galsim.PhotonArray(100) + photon_array.allocateWavelengths() + photon_array.allocateAngles() + ud.generate(photon_array.wavelength) + ud.generate(photon_array.dxdz) + ud.generate(photon_array.dydz) + photon_array.dxdz *= 1.2 # -0.6 to 0.6 + photon_array.dydz *= 1.2 + photon_array.dxdz -= 0.6 + photon_array.dydz -= 0.6 + photon_array.wavelength *= 2 # 0 to 2 + dxdz0 = photon_array.dxdz.copy() + dydz0 = photon_array.dydz.copy() + + refract_func = galsim.Refraction(index_ratio=index_ratio) + refract_func.applyTo(photon_array) + dxdz_func = photon_array.dxdz.copy() + dydz_func = photon_array.dydz.copy() + + photon_array.dxdz = dxdz0.copy() + photon_array.dydz = dydz0.copy() + refract11 = galsim.Refraction(index_ratio=1.1) + refract11.applyTo(photon_array) + dxdz11 = photon_array.dxdz.copy() + dydz11 = photon_array.dydz.copy() + + photon_array.dxdz = dxdz0.copy() + photon_array.dydz = dydz0.copy() + refract22 = galsim.Refraction(index_ratio=2.2) + refract22.applyTo(photon_array) + dxdz22 = photon_array.dxdz.copy() + dydz22 = photon_array.dydz.copy() + + w = photon_array.wavelength < 1 + np.testing.assert_allclose(dxdz_func, np.where(w, dxdz11, dxdz22)) + np.testing.assert_allclose(dydz_func, np.where(w, dydz11, dydz22)) + + +@timer +def test_focus_depth(): + bd = galsim.BaseDeviate(1234) + for _ in range(100): + # Test that FocusDepth is additive + photon_array = galsim.PhotonArray(1000) + photon_array2 = galsim.PhotonArray(1000) + photon_array.x = 0.0 + photon_array.y = 0.0 + photon_array2.x = 0.0 + photon_array2.y = 0.0 + galsim.FRatioAngles(1.234, obscuration=0.606).applyTo(photon_array, rng=bd) + photon_array2.dxdz = photon_array.dxdz + photon_array2.dydz = photon_array.dydz + fd1 = galsim.FocusDepth(1.1) + fd2 = galsim.FocusDepth(2.2) + fd3 = galsim.FocusDepth(3.3) + fd1.applyTo(photon_array) + fd2.applyTo(photon_array) + fd3.applyTo(photon_array2) + + check_pickle(fd1) + + np.testing.assert_allclose(photon_array.x, photon_array2.x, rtol=0, atol=1e-15) + np.testing.assert_allclose(photon_array.y, photon_array2.y, rtol=0, atol=1e-15) + # Assuming focus is at x=y=0, then + # intrafocal (depth < 0) => (x > 0 => dxdz < 0) + # extrafocal (depth > 0) => (x > 0 => dxdz > 0) + # We applied an extrafocal operation above, so check for corresponding + # relation between x, dxdz + np.testing.assert_array_less(0, photon_array.x * photon_array.dxdz) + + # transforming by depth and -depth is null + fd4 = galsim.FocusDepth(-3.3) + fd4.applyTo(photon_array) + np.testing.assert_allclose(photon_array.x, 0.0, rtol=0, atol=1e-15) + np.testing.assert_allclose(photon_array.y, 0.0, rtol=0, atol=1e-15) + + # Check that invalid photon array is trapped + pa = galsim.PhotonArray(10) + fd = galsim.FocusDepth(1.0) + with np.testing.assert_raises(galsim.GalSimError): + fd.applyTo(pa) + + # Check that we can infer depth from photon positions before and after... + for _ in range(100): + photon_array = galsim.PhotonArray(1000) + photon_array2 = galsim.PhotonArray(1000) + ud = galsim.UniformDeviate(bd) + ud.generate(photon_array.x) + ud.generate(photon_array.y) + photon_array.x -= 0.5 + photon_array.y -= 0.5 + galsim.FRatioAngles(1.234, obscuration=0.606).applyTo(photon_array, rng=bd) + photon_array2.x = photon_array.x + photon_array2.y = photon_array.y + photon_array2.dxdz = photon_array.dxdz + photon_array2.dydz = photon_array.dydz + depth = ud() - 0.5 + galsim.FocusDepth(depth).applyTo(photon_array2) + np.testing.assert_allclose( + (photon_array2.x - photon_array.x) / photon_array.dxdz, depth + ) + np.testing.assert_allclose( + (photon_array2.y - photon_array.y) / photon_array.dydz, depth + ) + np.testing.assert_allclose(photon_array.dxdz, photon_array2.dxdz) + np.testing.assert_allclose(photon_array.dydz, photon_array2.dydz) + + +@timer +def test_lsst_y_focus(): + # Check that applying reasonable focus depth (from O'Connor++06) indeed leads to smaller spot + # size for LSST y-band. + rng = galsim.BaseDeviate(9876543210) + bandpass = galsim.Bandpass("LSST_y.dat", wave_type="nm") + sed = galsim.SED("1", wave_type="nm", flux_type="flambda") + obj = galsim.Gaussian(fwhm=1e-5) + oversampling = 32 + photon_ops0 = [ + galsim.WavelengthSampler(sed, bandpass), + galsim.FRatioAngles(1.234, 0.606), + galsim.FocusDepth(0.0), + galsim.Refraction(3.9), + ] + img0 = obj.drawImage( + sensor=galsim.SiliconSensor(), + method="phot", + n_photons=100000, + photon_ops=photon_ops0, + scale=0.2 / oversampling, + nx=32 * oversampling, + ny=32 * oversampling, + rng=rng, + ) + T0 = img0.calculateMomentRadius() + T0 *= 10 * oversampling / 0.2 # arcsec => microns + + # O'Connor finds minimum spot size when the focus depth is ~ -12 microns. Our sensor isn't + # necessarily the same as the one there though; our minimum seems to be around -6 microns. + # That could be due to differences in the design of the sensor though. We just use -6 microns + # here, which is still useful to test the sign of the `depth` parameter and the interaction of + # the 4 different surface operators required to produce this effect, and is roughly consistent + # with O'Connor. + + depth1 = -6.0 # microns, negative means surface is intrafocal + depth1 /= 10 # microns => pixels + photon_ops1 = [ + galsim.WavelengthSampler(sed, bandpass), + galsim.FRatioAngles(1.234, 0.606), + galsim.FocusDepth(depth1), + galsim.Refraction(3.9), + ] + img1 = obj.drawImage( + sensor=galsim.SiliconSensor(), + method="phot", + n_photons=100000, + photon_ops=photon_ops1, + scale=0.2 / oversampling, + nx=32 * oversampling, + ny=32 * oversampling, + rng=rng, + ) + T1 = img1.calculateMomentRadius() + T1 *= 10 * oversampling / 0.2 # arcsec => microns + np.testing.assert_array_less(T1, T0) + + +@timer +def test_fromArrays(): + """Check that fromArrays constructor catches errors and ALWAYS copies.""" + + rng = galsim.BaseDeviate(123) + + x = np.empty(1000) + y = np.empty(1000) + flux = np.empty(1000) + + Nsplit = 444 + + pa_batch = galsim.PhotonArray.fromArrays(x, y, flux) + pa_1 = galsim.PhotonArray.fromArrays(x[:Nsplit], y[:Nsplit], flux[:Nsplit]) + pa_2 = galsim.PhotonArray.fromArrays(x[Nsplit:], y[Nsplit:], flux[Nsplit:]) + + assert pa_batch.x is not x + assert pa_batch.y is not y + assert pa_batch.flux is not flux + np.testing.assert_array_equal(pa_batch.x, x) + np.testing.assert_array_equal(pa_batch.y, y) + np.testing.assert_array_equal(pa_batch.flux, flux) + np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) + np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) + np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) + np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) + np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) + np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) + + # Do some manipulation and check views are still equivalent + obj1 = galsim.Gaussian(fwhm=0.1) * 64 + obj2 = galsim.Kolmogorov(fwhm=0.2) * 23 + + obj1._shoot(pa_1, rng) + obj2._shoot(pa_2, rng) + + assert pa_batch.x is x + assert pa_batch.y is y + assert pa_batch.flux is flux + np.testing.assert_array_equal(pa_batch.x, x) + np.testing.assert_array_equal(pa_batch.y, y) + np.testing.assert_array_equal(pa_batch.flux, flux) + np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) + np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) + np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) + np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) + np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) + np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) + + # Add some optional args and apply PhotonOps to the batch this time. + dxdz = np.empty(1000) + dydz = np.empty(1000) + wavelength = np.empty(1000) + pupil_u = np.empty(1000) + pupil_v = np.empty(1000) + time = np.empty(1000) + pa_batch = galsim.PhotonArray.fromArrays( + x, y, flux, dxdz, dydz, wavelength, pupil_u, pupil_v, time + ) + pa_1 = galsim.PhotonArray.fromArrays( + x[:Nsplit], + y[:Nsplit], + flux[:Nsplit], + dxdz[:Nsplit], + dydz[:Nsplit], + wavelength[:Nsplit], + pupil_u[:Nsplit], + pupil_v[:Nsplit], + time[:Nsplit], + ) + pa_2 = galsim.PhotonArray.fromArrays( + x[Nsplit:], + y[Nsplit:], + flux[Nsplit:], + dxdz[Nsplit:], + dydz[Nsplit:], + wavelength[Nsplit:], + pupil_u[Nsplit:], + pupil_v[Nsplit:], + time[Nsplit:], + ) + + sed = galsim.SED("vega.txt", wave_type="nm", flux_type="flambda") + bp = galsim.Bandpass("LSST_r.dat", wave_type="nm") + with assert_warns(galsim.GalSimWarning): + galsim.WavelengthSampler(sed, bp).applyTo(pa_batch, rng=rng) + galsim.FRatioAngles(1.2, 0.61).applyTo(pa_batch, rng=rng) + galsim.TimeSampler(0.0, 30.0).applyTo(pa_batch, rng=rng) + + assert pa_batch.x is x + assert pa_batch.y is y + assert pa_batch.flux is flux + assert pa_batch.dxdz is dxdz + assert pa_batch.dydz is dydz + assert pa_batch.wavelength is wavelength + assert pa_batch.pupil_u is pupil_u + assert pa_batch.pupil_v is pupil_v + assert pa_batch.time is time + np.testing.assert_array_equal(pa_batch.x, x) + np.testing.assert_array_equal(pa_batch.y, y) + np.testing.assert_array_equal(pa_batch.flux, flux) + np.testing.assert_array_equal(pa_batch.dxdz, dxdz) + np.testing.assert_array_equal(pa_batch.dydz, dydz) + np.testing.assert_array_equal(pa_batch.wavelength, wavelength) + np.testing.assert_array_equal(pa_batch.pupil_u, pupil_u) + np.testing.assert_array_equal(pa_batch.pupil_v, pupil_v) + np.testing.assert_array_equal(pa_batch.time, time) + np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) + np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) + np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) + np.testing.assert_array_equal(pa_1.dxdz, pa_batch.dxdz[:Nsplit]) + np.testing.assert_array_equal(pa_1.dydz, pa_batch.dydz[:Nsplit]) + np.testing.assert_array_equal(pa_1.wavelength, pa_batch.wavelength[:Nsplit]) + np.testing.assert_array_equal(pa_1.pupil_u, pa_batch.pupil_u[:Nsplit]) + np.testing.assert_array_equal(pa_1.pupil_v, pa_batch.pupil_v[:Nsplit]) + np.testing.assert_array_equal(pa_1.time, pa_batch.time[:Nsplit]) + np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) + np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) + np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) + np.testing.assert_array_equal(pa_2.dxdz, pa_batch.dxdz[Nsplit:]) + np.testing.assert_array_equal(pa_2.dydz, pa_batch.dydz[Nsplit:]) + np.testing.assert_array_equal(pa_2.wavelength, pa_batch.wavelength[Nsplit:]) + np.testing.assert_array_equal(pa_2.pupil_u, pa_batch.pupil_u[Nsplit:]) + np.testing.assert_array_equal(pa_2.pupil_v, pa_batch.pupil_v[Nsplit:]) + np.testing.assert_array_equal(pa_2.time, pa_batch.time[Nsplit:]) + + # Check the is_corr flag gets set + assert not pa_batch.isCorrelated() + pa_batch = galsim.PhotonArray.fromArrays( + x, y, flux, dxdz, dydz, wavelength, is_corr=True + ) + assert pa_batch.isCorrelated() + + # Check some invalid inputs are caught + with np.testing.assert_raises(TypeError): + galsim.PhotonArray.fromArrays(list(x), y, flux, dxdz, dydz, wavelength) + with np.testing.assert_raises(TypeError): + galsim.PhotonArray.fromArrays( + np.empty(1000, dtype=int), y, flux, dxdz, dydz, wavelength + ) + with np.testing.assert_raises(ValueError): + galsim.PhotonArray.fromArrays(x[:10], y, flux, dxdz, dydz, wavelength) + with np.testing.assert_raises(ValueError): + galsim.PhotonArray.fromArrays( + np.empty(2000)[::2], y, flux, dxdz, dydz, wavelength + ) + + +def test_pupil_annulus_sampler(): + """Check that we get a uniform distribution from PupilAnnulusSampler""" + seed = 54321 + sampler = galsim.PupilAnnulusSampler(1.0, 0.5) + pa = galsim.PhotonArray(1_000_000) + sampler.applyTo(pa, rng=seed) + r = np.hypot(pa.pupil_u, pa.pupil_v) + assert np.min(r) > 0.5 + assert np.max(r) < 1.0 + h, edges = np.histogram( + r, + bins=10, + range=(0.5, 1.0), + ) + areas = np.pi * (edges[1:] ** 2 - edges[:-1] ** 2) + # each bin should have ~100_000 photons, so +/- 0.3%. Test at 1%. + assert np.std(h / areas) / np.mean(h / areas) < 0.01 + + phi = np.arctan2(pa.pupil_v, pa.pupil_u) + phi[phi < 0] += 2 * np.pi + h, edges = np.histogram(phi, bins=10, range=(0.0, 2 * np.pi)) + assert np.std(h) / np.mean(h) < 0.01 + + check_pickle(sampler) + + +def test_time_sampler(): + """Check TimeSampler build arguments""" + seed = 97531 + sampler = galsim.TimeSampler() + assert sampler.t0 == 0 + assert sampler.exptime == 0 + pa = galsim.PhotonArray(1_000_000) + sampler.applyTo(pa, rng=seed) + np.testing.assert_array_equal(pa.time, 0.0) + check_pickle(sampler) + + sampler = galsim.TimeSampler(t0=1.0) + assert sampler.t0 == 1 + assert sampler.exptime == 0 + sampler.applyTo(pa, rng=seed) + np.testing.assert_array_equal(pa.time, 1.0) + check_pickle(sampler) + + sampler = galsim.TimeSampler(exptime=30.0) + assert sampler.t0 == 0 + assert sampler.exptime == 30 + sampler.applyTo(pa, rng=seed) + np.testing.assert_array_less(pa.time, 30) + np.testing.assert_array_less(-pa.time, 0) + check_pickle(sampler) + + sampler = galsim.TimeSampler(t0=10, exptime=30.0) + assert sampler.t0 == 10 + assert sampler.exptime == 30 + sampler.applyTo(pa, rng=seed) + np.testing.assert_array_less(pa.time, 40) + np.testing.assert_array_less(-pa.time, 10) + check_pickle(sampler) + + +def test_setFromImage_crash(): + """Geri Braunlich ran into a seg fault where the photon array was not allocated to be + sufficiently large for the photons it got from an image. + This test reproduces the error for version 2.4.8 for the purpose of fixing it. + + The bug turned out to be that some pixel values were (slightly) negative from the FFT, + and the total flux was estimated as np.sum(image.array). The negative pixels added + negatively to this sum, so the calculated total flux wasn't quite enough to hold all the + required photons. + + The fix was to use the absolute value of the image for this calculation. + """ + # These are (approximately) the specific values for one case where the code used to crash. + prof = galsim.Gaussian(sigma=0.13).withFlux(3972551) + wcs = galsim.JacobianWCS(-0.170, -0.106, 0.106, -0.170) + image = galsim.Image(1000, 1000, wcs=wcs, dtype=float) + + # Start with a simple draw with no photons + im1 = prof.drawImage(image=image.copy()) + + # Now with photon_ops. + # This had been sufficient to trigger the bug, but now photon_ops=[] is the same as None. + im2 = prof.drawImage(image=image.copy(), photon_ops=[], n_subsample=1) + assert im1 == im2 + + # Repeat with a non-empty, but still trivial, photon_ops. + im3 = prof.drawImage( + image=image.copy(), photon_ops=[galsim.FRatioAngles(1.2)], n_subsample=1 + ) + + # They aren't quite identical because of numerical rounding issues from going through + # a sum of fluxes on individual photons. + # In particular, we want to make sure negative pixels stay negative through this process. + assert im1 != im3 + np.testing.assert_allclose(im1.array, im3.array, rtol=1.0e-11) + w = np.where(im1.array != im3.array) + print("diff in ", len(w[0]), "pixels") + assert ( + len(w[0]) < 100 + ) # I find it to be different in only 39 photons on my machine. + + +if __name__ == "__main__": + testfns = [v for k, v in vars().items() if k[:5] == "test_" and callable(v)] + if no_astroplan: + print("Skipping test_dcr_angles, since astroplan not installed.") + testfns.remove(test_dcr_angles) + for testfn in testfns: + testfn() diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 879cdad2..30941daa 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -166,31 +166,26 @@ def test_uniform(): u.raw(), u2.raw(), err_msg="Uniform deviates generate different raw values" ) - # NOTE: these tests differ from galsim since we cannot connect - # generators rng2 = galsim.BaseDeviate(testseed) rng2.discard(4) - rng.discard(4) # new line relative to galsim np.testing.assert_equal( rng.raw(), rng2.raw(), err_msg="BaseDeviates generate different raw values after discard", ) - # NOTE: these tests will never work in galsim since we cannot - # connect RNGs in JAX - # # Check that two connected uniform deviates work correctly together. - # u2 = galsim.UniformDeviate(testseed) - # u.reset(u2) - # testResult2 = (u(), u2(), u()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong uniform random number sequence generated using two uds') - # u.seed(testseed) - # testResult2 = (u2(), u(), u2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong uniform random number sequence generated using two uds after seed') + # Check that two connected uniform deviates work correctly together. + u2 = galsim.UniformDeviate(testseed) + u.reset(u2) + testResult2 = (u(), u2(), u()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong uniform random number sequence generated using two uds') + u.seed(testseed) + testResult2 = (u2(), u(), u2()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong uniform random number sequence generated using two uds after seed') # Check that seeding with the time works (although we cannot check the output). # We're mostly just checking that this doesn't raise an exception. @@ -387,22 +382,21 @@ def test_gaussian(): np.array(testResult), np.array(testResult2), err_msg='Wrong Gaussian random number sequence generated after reset(ud)') - # NOTE jax doesn't allow connected RNGs - # # Check that two connected Gaussian deviates work correctly together. - # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - # g.reset(g2) - # # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. - # testResult2 = (g(), g(), g2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong Gaussian random number sequence generated using two gds') - # g.seed(testseed) - # # For the same reason, after seeding one, we need to manually clear the other's cache: - # g2.clearCache() - # testResult2 = (g2(), g2(), g()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong Gaussian random number sequence generated using two gds after seed') + # Check that two connected Gaussian deviates work correctly together. + g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + g.reset(g2) + # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. + testResult2 = (g(), g(), g2()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated using two gds') + g.seed(testseed) + # For the same reason, after seeding one, we need to manually clear the other's cache: + g2.clearCache() + testResult2 = (g2(), g2(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated using two gds after seed') # Check that seeding with the time works (although we cannot check the output). # We're mostly just checking that this doesn't raise an exception. @@ -590,19 +584,18 @@ def test_binomial(): np.array(testResult), np.array(testResult2), err_msg='Wrong binomial random number sequence generated after reset(ud)') - # NOTE JAX does not support connected RNGs - # # Check that two connected binomial deviates work correctly together. - # b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) - # b.reset(b2) - # testResult2 = (b(), b2(), b()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong binomial random number sequence generated using two bds') - # b.seed(testseed) - # testResult2 = (b2(), b(), b2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong binomial random number sequence generated using two bds after seed') + # Check that two connected binomial deviates work correctly together. + b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) + b.reset(b2) + testResult2 = (b(), b2(), b()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated using two bds') + b.seed(testseed) + testResult2 = (b2(), b(), b2()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated using two bds after seed') # Check that seeding with the time works (although we cannot check the output). # We're mostly just checking that this doesn't raise an exception. @@ -778,19 +771,18 @@ def test_poisson(): np.array(testResult), np.array(testResult2), err_msg='Wrong poisson random number sequence generated after reset(ud)') - # NOTE: jax-galsim doesn't have connected RNGs - # # Check that two connected poisson deviates work correctly together. - # p2 = galsim.PoissonDeviate(testseed, mean=pMean) - # p.reset(p2) - # testResult2 = (p(), p2(), p()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong poisson random number sequence generated using two pds') - # p.seed(testseed) - # testResult2 = (p2(), p(), p2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong poisson random number sequence generated using two pds after seed') + # Check that two connected poisson deviates work correctly together. + p2 = galsim.PoissonDeviate(testseed, mean=pMean) + p.reset(p2) + testResult2 = (p(), p2(), p()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated using two pds') + p.seed(testseed) + testResult2 = (p2(), p(), p2()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated using two pds after seed') # Check that seeding with the time works (although we cannot check the output). # We're mostly just checking that this doesn't raise an exception. @@ -941,56 +933,53 @@ def test_poisson_highmean(): print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) assert v1 == v2 - # NOTE: JAX cannot connect RNGs - # # Check that two connected poisson deviates work correctly together. - # p2 = galsim.PoissonDeviate(testseed, mean=mean) - # p.reset(p2) - # testResult2 = (p(), p(), p2()) - # np.testing.assert_array_equal( - # testResult2, testResult, - # err_msg='Wrong poisson random number sequence generated using two pds') - # p.seed(testseed) - # p2.clearCache() - # testResult2 = (p2(), p2(), p()) - # np.testing.assert_array_equal( - # testResult2, testResult, - # err_msg='Wrong poisson random number sequence generated using two pds after seed') - - # FIXME no noise methods for images in JAX yet + # Check that two connected poisson deviates work correctly together. + p2 = galsim.PoissonDeviate(testseed, mean=mean) + p.reset(p2) + testResult2 = (p(), p(), p2()) + np.testing.assert_array_equal( + testResult2, testResult, + err_msg='Wrong poisson random number sequence generated using two pds') + p.seed(testseed) + p2.clearCache() + testResult2 = (p2(), p2(), p()) + np.testing.assert_array_equal( + testResult2, testResult, + err_msg='Wrong poisson random number sequence generated using two pds after seed') + # Test filling an image - # p.seed(testseed) - # testimage = galsim.ImageD(np.zeros((3, 1))) - # testimage.addNoise(galsim.DeviateNoise(p)) - # np.testing.assert_array_equal( - # testimage.array.flatten(), testResult, - # err_msg='Wrong poisson random number sequence generated when applied to image.') - - # # The PoissonNoise version also subtracts off the mean value - # rng = galsim.BaseDeviate(testseed) - # pn = galsim.PoissonNoise(rng, sky_level=mean) - # testimage.fill(0) - # testimage.addNoise(pn) - # np.testing.assert_array_equal( - # testimage.array.flatten(), np.array(testResult) - mean, - # err_msg='Wrong poisson random number sequence generated using PoissonNoise') - - # # Check PoissonNoise variance: - # np.testing.assert_allclose( - # pn.getVariance(), mean, rtol=1.e-8, - # err_msg="PoissonNoise getVariance returns wrong variance") - # np.testing.assert_allclose( - # pn.sky_level, mean, rtol=1.e-8, - # err_msg="PoissonNoise sky_level returns wrong value") - - # # Check that the noise model really does produce this variance. - # big_im = galsim.Image(2048, 2048, dtype=float) - # big_im.addNoise(pn) - # var = np.var(big_im.array) - # print('variance = ', var) - # print('getVar = ', pn.getVariance()) - # np.testing.assert_allclose( - # var, pn.getVariance(), rtol=rtol_var, - # err_msg='Realized variance for PoissonNoise did not match getVariance()') + p.seed(testseed) + testimage = galsim.ImageD(np.zeros((3, 1))) + testimage.addNoise(galsim.DeviateNoise(p)) + np.testing.assert_array_equal( + testimage.array.flatten(), testResult, + err_msg='Wrong poisson random number sequence generated when applied to image.') + + rng = galsim.BaseDeviate(testseed) + pn = galsim.PoissonNoise(rng, sky_level=mean) + testimage.fill(0) + testimage.addNoise(pn) + np.testing.assert_array_equal( + testimage.array.flatten(), np.array(testResult) - mean, + err_msg='Wrong poisson random number sequence generated using PoissonNoise') + + # Check PoissonNoise variance: + np.testing.assert_allclose( + pn.getVariance(), mean, rtol=1.e-8, + err_msg="PoissonNoise getVariance returns wrong variance") + np.testing.assert_allclose( + pn.sky_level, mean, rtol=1.e-8, + err_msg="PoissonNoise sky_level returns wrong value") + + # Check that the noise model really does produce this variance. + big_im = galsim.Image(2048, 2048, dtype=float) + big_im.addNoise(pn) + var = np.var(big_im.array) + print('variance = ', var) + print('getVar = ', pn.getVariance()) + np.testing.assert_allclose( + var, pn.getVariance(), rtol=rtol_var, + err_msg='Realized variance for PoissonNoise did not match getVariance()') @timer From bd703e2805cc5df6d62a8c0fef4c291793531b6f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 13 Nov 2023 07:14:00 -0600 Subject: [PATCH 02/85] Update jax_galsim/random.py --- jax_galsim/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 4d84c80f..6434a895 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -16,7 +16,7 @@ from jax_galsim.core.utils import ensure_hashable LAX_FUNCTIONAL_RNG = """\ -JAX-GalSim PRNGs have some support linking states, but it may not always and/or may cause issues. +JAX-GalSim PRNGs have some support for linking states, but it may not always work and/or may cause issues. - Linked states across JIT boundaries or devices will not work. - Within a single routine linking may work. From a2d0cfb6b82a847492c40918c122339e63cdb42d Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 13 Nov 2023 07:15:33 -0600 Subject: [PATCH 03/85] Update jax_galsim/random.py --- jax_galsim/random.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 6434a895..19ea7256 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -149,8 +149,7 @@ def clearCache(self): ), ) def discard(self, n, suppress_warnings=False): - self._key, subkeys = self.__class__._discard(self._key, n) - return subkeys + self._key = self.__class__._discard(self._key, n) @partial(jax.jit, static_argnums=(1,)) def _discard(key, n): From bcb1eca54ee55536f6f0cb43e902c051cd76bb71 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 13 Nov 2023 07:16:56 -0600 Subject: [PATCH 04/85] Apply suggestions from code review --- jax_galsim/random.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 19ea7256..d484ccae 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -151,13 +151,13 @@ def clearCache(self): def discard(self, n, suppress_warnings=False): self._key = self.__class__._discard(self._key, n) - @partial(jax.jit, static_argnums=(1,)) + jax.jit def _discard(key, n): - def _f(key, x): - key, sub_key = jrandom.split(key) - return key, sub_key + def __discard(i, key): + key, subkey = jrandom.split(key) + return key - return jax.lax.scan(_f, key, None, length=n) + return jax.lax.fori_loop(0, n, __discard, key) @_wraps( _galsim.BaseDeviate.raw, From 9c311ffdac705478cd4d76d9732cc7968d92e2ef Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 13 Nov 2023 07:17:15 -0600 Subject: [PATCH 05/85] Update jax_galsim/random.py --- jax_galsim/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index d484ccae..f3af980e 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -151,7 +151,7 @@ def clearCache(self): def discard(self, n, suppress_warnings=False): self._key = self.__class__._discard(self._key, n) - jax.jit + @jax.jit def _discard(key, n): def __discard(i, key): key, subkey = jrandom.split(key) From 2e61945e814955a4355b87873490e9561e8f0ac7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 13 Nov 2023 07:18:10 -0600 Subject: [PATCH 06/85] STY blacken --- jax_galsim/photon_array.py | 12 +++++++----- jax_galsim/sensor.py | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index cc907199..ed4969cf 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,17 +1,18 @@ import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps import jax.random as jrng +from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_python_float from jax_galsim.errors import ( - GalSimValueError, - GalSimUndefinedBoundsError, - GalSimRangeError, GalSimIncompatibleValuesError, + GalSimRangeError, + GalSimUndefinedBoundsError, + GalSimValueError, ) -from jax_galsim.random import UniformDeviate, BaseDeviate +from jax_galsim.random import BaseDeviate, UniformDeviate + from ._pyfits import pyfits @@ -560,6 +561,7 @@ def write(self, file_name): file_name: The file name of the output FITS file. """ import numpy as np + from jax_galsim import fits cols = [] diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 1cec712e..8aa9f743 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -1,9 +1,9 @@ import galsim as _galsim -from .position import PositionI -from .errors import GalSimUndefinedBoundsError - from jax._src.numpy.util import _wraps +from .errors import GalSimUndefinedBoundsError +from .position import PositionI + @_wraps(_galsim.Sensor) class Sensor: From a791f136795fb86d15f9cc43884da148a3561e3d Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 14 Nov 2023 16:31:12 -0600 Subject: [PATCH 07/85] ENH finished photon shooting for interpolated images --- jax_galsim/__init__.py | 1 + jax_galsim/convolve.py | 11 +- jax_galsim/deltafunction.py | 86 +++++++++++++++ jax_galsim/gsobject.py | 15 ++- jax_galsim/interpolant.py | 123 +++++++++++++++++++++- jax_galsim/interpolatedimage.py | 17 ++- jax_galsim/photon_array.py | 5 +- tests/GalSim | 2 +- tests/galsim_tests_config.yaml | 3 +- tests/jax/test_interpolatedimage_utils.py | 58 ++++++++++ tests/jax/test_photon_array_jax_custom.py | 36 +++++++ 11 files changed, 346 insertions(+), 11 deletions(-) create mode 100644 jax_galsim/deltafunction.py create mode 100644 tests/jax/test_photon_array_jax_custom.py diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index f15cfcc2..94d0afc8 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -55,6 +55,7 @@ from .sum import Add, Sum from .transform import Transform, Transformation from .convolve import Convolve, Convolution, Deconvolution, Deconvolve +from .deltafunction import DeltaFunction # WCS from .wcs import ( diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 3f34d1b1..6902168f 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -6,6 +6,7 @@ from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams +from jax_galsim.photon_array import PhotonArray @_wraps( @@ -308,7 +309,15 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): raise NotImplementedError("Real-space convolutions are not implemented") def _shoot(self, photons, rng): - raise NotImplementedError("Photon shooting convolutions are not implemented") + self.obj_list[0]._shoot(photons, rng) + # It may be necessary to shuffle when convolving because we do not have a + # gaurantee that the convolvee's photons are uncorrelated, e.g., they might + # both have their negative ones at the end. + # However, this decision is now made by the convolve method. + for obj in self.obj_list[1:]: + p1 = PhotonArray(len(photons)) + obj._shoot(p1, rng) + photons.convolve(p1, rng) def _drawKImage(self, image, jac=None): image = self.obj_list[0]._drawKImage(image, jac) diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py new file mode 100644 index 00000000..63265e28 --- /dev/null +++ b/jax_galsim/deltafunction.py @@ -0,0 +1,86 @@ +import galsim as _galsim +import jax +import jax.numpy as jnp +from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class + +from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue +from jax_galsim.core.utils import ensure_hashable +from jax_galsim.gsobject import GSObject + + +@_wraps(_galsim.DeltaFunction) +@register_pytree_node_class +class DeltaFunction(GSObject): + _opt_params = {"flux": float} + + _mock_inf = ( + 1.0e300 # Some arbitrary very large number to use when we need infinity. + ) + + _has_hard_edges = False + _is_axisymmetric = True + _is_analytic_x = False + _is_analytic_k = True + + def __init__(self, flux=1.0, gsparams=None): + super().__init__(flux=flux, gsparams=gsparams) + + def __hash__(self): + return hash(("galsim.DeltaFunction", ensure_hashable(self.flux), self.gsparams)) + + def __repr__(self): + return "galsim.DeltaFunction(flux=%r, gsparams=%r)" % ( + ensure_hashable(self.flux), + self.gsparams, + ) + + def __str__(self): + s = "galsim.DeltaFunction(" + if self.flux != 1.0: + s += "flux=%s" % self.flux + s += ")" + return s + + @property + def _maxk(self): + return DeltaFunction._mock_inf + + @property + def _stepk(self): + return DeltaFunction._mock_inf + + @property + def _max_sb(self): + return DeltaFunction._mock_inf + + def _xValue(self, pos): + return jax.lax.cond( + jnp.array(pos.x == 0.0, dtype=bool) + & jnp.array(pos.y == 0.0, dtype=bool), + lambda *a: DeltaFunction._mock_inf, + lambda *a: 0.0, + ) + + def _kValue(self, kpos): + # this is a wasteful and fancy way to get the shape to broadcast to + # to match the input kpos + return self.flux + kpos.x * (0.0 + 0.0j) + + def _shoot(self, photons, rng): + flux_per_photon = self.flux / photons.size() + photons.x = 0.0 + photons.y = 0.0 + photons.flux = flux_per_photon + + def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): + _jac = jnp.eye(2) if jac is None else jac + return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling) + + def _drawKImage(self, image, jac=None): + _jac = jnp.eye(2) if jac is None else jac + return draw_by_kValue(self, image, _jac) + + @_wraps(_galsim.DeltaFunction.withFlux) + def withFlux(self, flux): + return DeltaFunction(flux=flux, gsparams=self.gsparams) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index fd8285e0..9b3aa795 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -675,7 +675,7 @@ def drawImage( setup_only=False, ): from jax_galsim.box import Pixel - from jax_galsim.convolve import Convolve + from jax_galsim.convolve import Convolution, Convolve from jax_galsim.image import Image from jax_galsim.wcs import PixelScale @@ -687,6 +687,19 @@ def drawImage( "Setting maxN is incompatible with save_photons=True" ) + # Check that the user isn't convolving by a Pixel already. This is almost always an error. + if method == "auto" and isinstance(self, Convolution): + if any([isinstance(obj, Pixel) for obj in self.obj_list]): + galsim_warn( + "You called drawImage with ``method='auto'`` " + "for an object that includes convolution by a Pixel. " + "This is probably an error. Normally, you should let GalSim " + "handle the Pixel convolution for you. If you want to handle the Pixel " + "convolution yourself, you can use method=no_pixel. Or if you really meant " + "for your profile to include the Pixel and also have GalSim convolve by " + "an _additional_ Pixel, you can suppress this warning by using method=fft." + ) + # Figure out what wcs we are going to use. wcs = self._determine_wcs(scale, wcs, image) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 9ed587e6..03a2ab98 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -9,11 +9,57 @@ import jax.numpy as jnp from galsim.errors import GalSimValueError from jax._src.numpy.util import _wraps +from jax.tree_util import Partial as jax_partial from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si from jax_galsim.core.utils import is_equal_with_arrays +from jax_galsim.errors import GalSimError from jax_galsim.gsparams import GSParams +from jax_galsim.random import UniformDeviate + + +@jax.jit +def _rejection_sample(photons, rng, tot_xrange, xval, pos_flux, neg_flux, max_val): + def _cond_fun(args): + _, _, tot, _, curr = args + return curr < tot + + def _body_fun(args): + arr, sign, tot, ud, curr = args + xloc = (ud() - 0.5) * tot_xrange + yv = ud() * max_val + xloc_val = xval(xloc) + arr, sign, curr = jax.lax.cond( + yv <= jnp.abs(xloc_val), + lambda arr, sign, curr, xloc, xloc_val: ( + arr.at[curr].set(xloc), + sign.at[curr].set(jnp.sign(xloc_val)), + curr + 1, + ), + lambda arr, sign, curr, xloc, xloc_val: (arr, sign, curr), + arr, + sign, + curr, + xloc, + xloc_val, + ) + return arr, sign, tot, ud, curr + + ud = UniformDeviate(rng) + photons.x, _sign_x, _, ud, _ = jax.lax.while_loop( + _cond_fun, + _body_fun, + (jnp.zeros_like(photons.x), jnp.zeros_like(photons.x), photons.size(), ud, 0), + ) + photons.y, _sign_y, _, ud, _ = jax.lax.while_loop( + _cond_fun, + _body_fun, + (jnp.zeros_like(photons.y), jnp.zeros_like(photons.y), photons.size(), ud, 0), + ) + flux_per = (pos_flux + neg_flux) ** 2 / photons.size() + photons.flux = _sign_x * _sign_y * flux_per + return photons, rng @_wraps(_galsim.interpolant.Interpolant) @@ -255,6 +301,11 @@ def urange(self): % self.__class__.__name__ ) + def _shoot(self, photons, rng): + raise NotImplementedError( + "%s does not implement shoot" % self.__class__.__name__ + ) + # subclasses should implement __init__, _xval, _uval, # _unit_integrals, _positive_flux, _negative_flux, urange, and xrange @@ -303,6 +354,11 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 0 + def _shoot(self, photons, rng): + photons.x = 0.0 + photons.y = 0.0 + photons.flux = 1.0 / photons.size() + @_wraps(_galsim.interpolant.Nearest) @register_pytree_node_class @@ -341,6 +397,12 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 1 + def _shoot(self, photons, rng): + ud = UniformDeviate(rng) + photons.x = ud.generate(photons.x) - 0.5 + photons.y = ud.generate(photons.y) - 0.5 + photons.flux = 1.0 / photons.size() + @_wraps(_galsim.interpolant.SincInterpolant) @register_pytree_node_class @@ -408,6 +470,12 @@ def _comp_fluxes(self): self._positive_flux = jax.lax.stop_gradient(jnp.sum(val[val > 0])).item() * 2.0 self._negative_flux = self._positive_flux - 1.0 + def _shoot(self, photons, rng): + raise GalSimError( + "%s does not implement shoot since the " + "kernel is not compact in real-space." % self.__class__.__name__ + ) + @_wraps(_galsim.interpolant.Linear) @register_pytree_node_class @@ -454,6 +522,12 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 2 + def _shoot(self, photons, rng): + ud = UniformDeviate(rng) + photons.x = ud.generate(photons.x) + ud.generate(photons.x) - 1.0 + photons.y = ud.generate(photons.y) + ud.generate(photons.y) - 1.0 + photons.flux = 1.0 / photons.size() + @_wraps(_galsim.interpolant.Cubic) @register_pytree_node_class @@ -519,6 +593,21 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 4 + def _shoot(self, photons, rng): + _photons, _rng = _rejection_sample( + photons, + rng, + self.xrange * 2.0, + jax_partial(self.__class__._xval), + self.positive_flux, + self.negative_flux, + self._xval_noraise(0.0), + ) + photons.x = _photons.x + photons.y = _photons.y + photons.flux = _photons.flux + rng._state = _rng._state + @_wraps(_galsim.interpolant.Quintic) @register_pytree_node_class @@ -614,6 +703,21 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 6 + def _shoot(self, photons, rng): + _photons, _rng = _rejection_sample( + photons, + rng, + self.xrange * 2.0, + jax_partial(self.__class__._xval), + self.positive_flux, + self.negative_flux, + self._xval_noraise(0.0), + ) + photons.x = _photons.x + photons.y = _photons.y + photons.flux = _photons.flux + rng._state = _rng._state + @_wraps(_galsim.interpolant.Lanczos) @register_pytree_node_class @@ -1398,7 +1502,7 @@ def __str__(self): # this is a pure function and we apply JIT ahead of time since this # one is pretty slow @jax.jit - def _xval(x, n, conserve_dc, _K): + def _xval(n, conserve_dc, _K, x): x = jnp.abs(x) def _low(x, n): @@ -1471,7 +1575,7 @@ def _no_dcval(val, x, n, _K): ) def _xval_noraise(self, x): - return Lanczos._xval(x, self._n, self._conserve_dc, self._K_arr) + return Lanczos._xval(self._n, self._conserve_dc, self._K_arr, x) def _raw_uval(u, n): # this function is used in the init and so was causing a recursion depth error @@ -1590,6 +1694,21 @@ def unit_integrals(self, max_len=None): else: return self._unit_integrals_no_conserve_dc[self._n][:n] + def _shoot(self, photons, rng): + _photons, _rng = _rejection_sample( + photons, + rng, + self.xrange * 2.0, + jax_partial(self.__class__._xval, self._n, self._conserve_dc, self._K_arr), + self.positive_flux, + self.negative_flux, + self._xval_noraise(0.0), + ) + photons.x = _photons.x + photons.y = _photons.y + photons.flux = _photons.flux + rng._state = _rng._state + # we apply JIT here to esnure the class init is fast @jax.jit diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index eeafb782..f5f34ebe 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -24,6 +24,7 @@ from jax_galsim.gsparams import GSParams from jax_galsim.image import Image from jax_galsim.interpolant import Quintic +from jax_galsim.photon_array import PhotonArray from jax_galsim.position import PositionD from jax_galsim.random import UniformDeviate from jax_galsim.transform import Transformation @@ -877,10 +878,22 @@ def _shoot(self, photons, rng): ud = UniformDeviate(rng) photons.x = ud.generate(photons.x) + xedges[xinds] photons.y = ud.generate(photons.y) + yedges[yinds] - photons.flux = jnp.sign(img.array.ravel())[inds] * self._flux_per_photon() + photons.flux = ( + jnp.sign(img.array.ravel())[inds] + * self._flux_per_photon() + * (self.positive_flux + self.negative_flux) + / photons.size() + ) + + # accounnt for offset - we add the offset to get to + # image pixels in xValue so we need to subtract it here + photons.x -= self._offset.x + photons.y -= self._offset.y # now we convolve with the x interpolant - raise NotImplementedError("InterpolatedImages do not support photon shooting!") + x_photons = PhotonArray(photons.size()) + self._x_interpolant._shoot(x_photons, rng) + photons.convolve(x_photons) @_wraps(_galsim._InterpolatedImage) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index ed4969cf..3bd9f1d2 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -492,10 +492,12 @@ def addTo(self, image): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) + # the numpy histogram function histograms x along the first dimension and y along the + # along the second dimension. We need the opposite so we swap the inputs xbins = jnp.arange(image.bounds.xmin, image.bounds.xmax + 2) - 0.5 ybins = jnp.arange(image.bounds.ymin, image.bounds.ymax + 2) - 0.5 im = jnp.histogram2d( - self._x, self._y, bins=(xbins, ybins), weights=self._flux, density=False + self._y, self._x, bins=(ybins, xbins), weights=self._flux, density=False )[0] image._array += im return im.sum() @@ -524,7 +526,6 @@ def makeFromImage(cls, image, max_flux=1.0, rng=None): n_per = jnp.clip(jnp.ceil(jnp.abs(image.array) / max_flux), 1).astype(int) flux_per = (image.array / n_per).ravel() n_per = n_per.ravel() - flux_per = flux_per.ravel() inds = jnp.arange(image.array.size) inds = jnp.repeat(inds, n_per) yinds, xinds = jnp.unravel_index(inds, image.array.shape) diff --git a/tests/GalSim b/tests/GalSim index 1ed5131a..3e308a21 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 1ed5131a54b4dbee384fee6b82b5e2e478ef0492 +Subproject commit 3e308a2194f8a3d08e811046634d5f115fc54356 diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 9d68478c..88636ff2 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -14,6 +14,7 @@ enabled_tests: - test_wcs.py - test_box.py - test_interpolatedimage.py + - test_deltafunction.py coord: - test_angle.py - test_angleunit.py @@ -25,9 +26,7 @@ enabled_tests: # correspond to features that are not implemented yet # in jax_galsim allowed_failures: - - "module 'jax_galsim' has no attribute 'DeltaFunction'" - "Real-space convolutions are not implemented" - - "Photon shooting convolutions are not implemented" - "module 'jax_galsim' has no attribute 'Airy'" - "module 'jax_galsim' has no attribute 'Kolmogorov'" - "module 'jax_galsim' has no attribute 'Sersic'" diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index c977eb46..6c0200f0 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -315,3 +315,61 @@ def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): np.testing.assert_allclose(gim.array, im.array) np.testing.assert_allclose(gkim.array, kim.array) np.testing.assert_allclose(gxkim.array, xkim.array) + + +@pytest.mark.parametrize( + "interp", + [ + Nearest(), + Linear(), + Cubic(), + Quintic(), + Lanczos(3, conserve_dc=True), + Lanczos(3, conserve_dc=False), + Lanczos(7), + ], +) +def test_interpolatedimage_interpolant_rejection_sample(interp): + from jax.tree_util import Partial as jax_partial + + from jax_galsim.interpolant import _rejection_sample + from jax_galsim.photon_array import PhotonArray + from jax_galsim.random import BaseDeviate + + rng = BaseDeviate(1234) + + ntot = 1000000 + photons = PhotonArray(ntot) + photons, _ = _rejection_sample( + photons, + rng, + interp.xrange * 2.0, + jax_partial(interp._xval_noraise), + interp.positive_flux, + interp.negative_flux, + interp._xval_noraise(0.0), + ) + + h, bins = jnp.histogram(photons.x, bins=500) + mid = (bins[1:] + bins[:-1]) / 2.0 + dx = bins[1:] - bins[:-1] + yv = ( + jnp.abs(interp._xval_noraise(mid)) + * dx + * ntot + * 1.0 + / (interp.positive_flux + interp.negative_flux) + ) + msk = yv > 100 + fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) + np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") + np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") + + if interp.__class__.__name__ == "Quintic" and False: + import proplot as pplt + + fig, axs = pplt.subplots(figsize=(4, 4)) + axs.hist(photons.x, bins=500, log=False) + axs.plot(mid, yv, color="k") + fig.show() + breakpoint() diff --git a/tests/jax/test_photon_array_jax_custom.py b/tests/jax/test_photon_array_jax_custom.py new file mode 100644 index 00000000..fd9a7846 --- /dev/null +++ b/tests/jax/test_photon_array_jax_custom.py @@ -0,0 +1,36 @@ +import jax_galsim +import numpy as np + + +def test_photon_array_make_from_image_notranspose(): + # this test uses a very assymetric array to ensure there is not a transpose + # error in the code + ref_array = np.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 10.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + [0.04, 0.11, 0.10, 0.01], + ] + ) + image = jax_galsim.Image(ref_array) + + photons = jax_galsim.PhotonArray.makeFromImage(image, max_flux=0.1) + + image2 = jax_galsim.Image(np.zeros_like(ref_array)) + photons.addTo(image2) + + if not np.allclose(image2.array, ref_array) and False: + import proplot as pplt + + fig, axs = pplt.subplots(nrows=1, ncols=3) + axs[0].imshow(ref_array) + axs[1].imshow(image2.array) + axs[2].imshow(image2.array - ref_array) + + import pdb + + pdb.set_trace() + + np.testing.assert_allclose(image2.array, ref_array) From c51c14447db46fdb08040763c6ff8ab61195d903 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 14 Nov 2023 16:33:43 -0600 Subject: [PATCH 08/85] STY blacken --- jax_galsim/deltafunction.py | 3 +-- tests/jax/test_photon_array_jax_custom.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py index 63265e28..2401e47c 100644 --- a/jax_galsim/deltafunction.py +++ b/jax_galsim/deltafunction.py @@ -56,8 +56,7 @@ def _max_sb(self): def _xValue(self, pos): return jax.lax.cond( - jnp.array(pos.x == 0.0, dtype=bool) - & jnp.array(pos.y == 0.0, dtype=bool), + jnp.array(pos.x == 0.0, dtype=bool) & jnp.array(pos.y == 0.0, dtype=bool), lambda *a: DeltaFunction._mock_inf, lambda *a: 0.0, ) diff --git a/tests/jax/test_photon_array_jax_custom.py b/tests/jax/test_photon_array_jax_custom.py index fd9a7846..05281f02 100644 --- a/tests/jax/test_photon_array_jax_custom.py +++ b/tests/jax/test_photon_array_jax_custom.py @@ -1,6 +1,7 @@ -import jax_galsim import numpy as np +import jax_galsim + def test_photon_array_make_from_image_notranspose(): # this test uses a very assymetric array to ensure there is not a transpose From dfbf59fc0a659f3f8829f258f668c5b51393efbd Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 15 Nov 2023 13:36:35 -0600 Subject: [PATCH 09/85] tests of jit --- jax_galsim/core/testing.py | 31 +++++++++++++ tests/jax/test_jitting.py | 95 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 jax_galsim/core/testing.py diff --git a/jax_galsim/core/testing.py b/jax_galsim/core/testing.py new file mode 100644 index 00000000..7935487f --- /dev/null +++ b/jax_galsim/core/testing.py @@ -0,0 +1,31 @@ +from contextlib import contextmanager +from time import perf_counter_ns + + +class TimingResult: + def __init__(self): + self.dt = None + + def __str__(self): + if self.dt is None: + return "- ms" + else: + if self.dt > 10000: + return f"{self.dt/1000} s" + else: + return f"{self.dt} ms" + + +@contextmanager +def time_code_block(msg=None, quiet=False): + tr = TimingResult() + t0 = perf_counter_ns() + yield tr + t1 = perf_counter_ns() + tr.dt = (t1 - t0) / 1e6 + if not quiet: + if msg is not None: + msg = msg + " " + else: + msg = "" + print(f"{msg}time: {tr.dt} ms") diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index d4a64c87..eb1f696b 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -1,7 +1,10 @@ +from functools import partial import jax import jax.numpy as jnp +import numpy as np import jax_galsim as galsim +from jax_galsim.core.testing import time_code_block # Defining jitting identity identity = jax.jit(lambda x: x) @@ -168,3 +171,95 @@ def test_eq(self, other): assert test_eq(identity(g), g) assert test_eq(identity(e), e) + + +def test_jitting_draw_fft(): + def _build_and_draw(hlr, fwhm, jit=True): + gal = galsim.Exponential(half_light_radius=hlr, flux=1000.0) + psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) + final = galsim.Convolve( + [gal, psf], + ) + n = final.getGoodImageSize(0.2).item() + n += 1 + nfft = galsim.Image.good_fft_size(4 * n) + if jit: + return _draw_it_jit(final, n, nfft) + else: + final = final.withGSParams( + minimum_fft_size=128, + maximum_fft_size=128, + ) + return final.drawImage( + nx=n, + ny=n, + scale=0.2, + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit(obj, n, nfft): + obj = obj.withGSParams( + minimum_fft_size=128, + maximum_fft_size=128, + ) + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + ) + + with time_code_block("warmup no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + with time_code_block("no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + + with time_code_block("warmup jit"): + img = _build_and_draw(0.5, 1.0) + with time_code_block("jit"): + img = _build_and_draw(0.5, 1.0) + + np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 0) + + +def test_jitting_draw_phot(): + def _build_and_draw(hlr, fwhm, jit=True): + gal = galsim.Exponential(half_light_radius=hlr, flux=1000.0) + psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) + final = galsim.Convolve( + [gal, psf], + ) + n = final.getGoodImageSize(0.2).item() + n += 1 + n_photons = final._calculate_nphotons(0, False, 0, None)[0] + if jit: + return _draw_it_jit(final, n, n_photons) + else: + return final.drawImage( + nx=n, + ny=n, + scale=0.2, + method="phot", + n_photons=n_photons, + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit(obj, n, nphotons): + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + n_photons=nphotons, + method="phot", + ) + + with time_code_block("warmup no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + with time_code_block("no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + + with time_code_block("warmup jit"): + img = _build_and_draw(0.5, 1.0) + with time_code_block("jit"): + img = _build_and_draw(0.5, 1.0) + + np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 3) From 36ccd9f8d0bdf67cd50678502eb757bdfa5e2a6e Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 15 Nov 2023 13:41:02 -0600 Subject: [PATCH 10/85] STY please the flake8 --- tests/jax/test_jitting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index eb1f696b..dfd36a60 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -1,4 +1,5 @@ from functools import partial + import jax import jax.numpy as jnp import numpy as np From 289001b551be93ad30c22ba0e9ce61b89834b957 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 16 Nov 2023 20:56:56 -0600 Subject: [PATCH 11/85] TST jit w/ photon shooting --- jax_galsim/core/draw.py | 324 ++++++++++++++++++++++++++++++- jax_galsim/gsobject.py | 379 ++++++++++++++++++++++--------------- jax_galsim/photon_array.py | 63 +++--- jax_galsim/sensor.py | 11 ++ jax_galsim/transform.py | 6 +- tests/jax/test_jitting.py | 20 +- 6 files changed, 622 insertions(+), 181 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index a8edfe51..00696da7 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -1,13 +1,17 @@ +import galsim as _galsim import jax import jax.numpy as jnp +from jax._src.numpy.util import _wraps -from jax_galsim import Image, PositionD +from jax_galsim.random import PoissonDeviate def draw_by_xValue( gsobject, image, jacobian=jnp.eye(2), offset=jnp.zeros(2), flux_scaling=1.0 ): """Utility function to draw a real-space GSObject into an Image.""" + from jax_galsim import Image, PositionD + # Applies flux scaling to compensate for pixel scale # See SBProfile.draw() flux_scaling *= image.scale**2 @@ -36,6 +40,8 @@ def draw_by_xValue( def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): + from jax_galsim import Image, PositionD + # Create an array of coordinates coords = jnp.stack(image.get_pixel_centers(), axis=-1) coords = coords * image.scale # Scale by the image pixel scale @@ -52,6 +58,8 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): + from jax_galsim import Image, PositionD + # Create an array of coordinates kcoords = jnp.stack(image.get_pixel_centers(), axis=-1) kcoords = kcoords * image.scale # Scale by the image pixel scale @@ -74,3 +82,317 @@ def phase(kpos): wcs=image.wcs, check_bounds=False, ) + + +def sample_poisson_flux(flux, eta_factor, rng=None): + """Sample the flux according to a Poisson distribution. + + Parameters: + flux: The flux of the GSObject (e.g., ``obj.flux``). + eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). + rng: If provided, a random number generator to use for photon shooting, + which may be any kind of `BaseDeviate` object. If ``rng`` is None, one + will be automatically created, using the time as a seed. + [default: None] + """ + # If we have both positive and negative photons, then the mix of these + # already gives us some variation in the flux value from the variance + # of how many are positive and how many are negative. + # The number of negative photons varies as a binomial distribution. + # = eta * Ntot * g + # = (1-eta) * Ntot * g + # = (1-2eta) * Ntot * g = flux + # Var(F-) = eta * (1-eta) * Ntot * g^2 + # F+ = Ntot * g - F- is not an independent variable, so + # Var(F+ - F-) = Var(Ntot*g - 2*F-) + # = 4 * Var(F-) + # = 4 * eta * (1-eta) * Ntot * g^2 + # = 4 * eta * (1-eta) * flux + # We want the variance to be equal to flux, so we need an extra: + # delta Var = (1 - 4*eta + 4*eta^2) * flux + # = (1-2eta)^2 * flux + absflux = abs(flux) + mean = eta_factor * eta_factor * absflux + pd = PoissonDeviate(rng, mean) + pd_val = pd() - mean + absflux + return pd_val + + +@_wraps( + _galsim.GSObject._calculate_nphotons, + lax_description="""\ +Calculate the number of photons to shoot for photon shooting. + +This routine is pure Python and is not JAX-compatible. + +Parameters: + flux: The flux of the GSObject (e.g., ``obj.flux``). + eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). + max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). + rng: If provided, a random number generator to use for photon shooting, + which may be any kind of `BaseDeviate` object. If ``rng`` is None, one + will be automatically created, using the time as a seed. + [default: None] + max_extra_noise: If provided, the allowed extra noise in each pixel when photon + shooting. This is only relevant if ``n_photons=0``, so the number of + photons is being automatically calculated. In that case, if the image + noise is dominated by the sky background, then you can get away with + using fewer shot photons than the full ``n_photons = flux``. + Essentially each shot photon can have a ``flux > 1``, which increases + the noise in each pixel. The ``max_extra_noise`` parameter specifies + how much extra noise per pixel is allowed because of this approximation. + A typical value for this might be ``max_extra_noise = sky_level / 100`` + where ``sky_level`` is the flux per pixel due to the sky. Note that + this uses a "variance" definition of noise, not a "sigma" definition. + [default: 0.] + poisson_flux: Whether to allow total object flux scaling to vary according to + Poisson statistics for ``n_photons`` samples when photon shooting. + [default: True, unless ``n_photons`` is given, in which case the default + is False] + +""", +) +def calculate_n_photons( + flux, + eta_factor, + max_sb, + rng=None, + max_extra_noise=0, + poisson_flux=True, +): + # For profiles that are positive definite, then N = flux. Easy. + # + # However, some profiles shoot some of their photons with negative flux. This means that + # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the + # fraction of shot photons that have negative flux. + # + # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 + # N^2 = Var(S) = (N+ + N-) = Ntot + # + # So flux = (S/N)^2 = Ntot (1-2eta)^2 + # Ntot = flux / (1-2eta)^2 + # + # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). + # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right + # total flux. + # + # That's all the easy case. The trickier case is when we are sky-background dominated. + # Then we can usually get away with fewer shot photons than the above. In particular, + # if the noise from the photon shooting is much less than the sky noise, then we can + # use fewer shot photons and essentially have each photon have a flux > 1. This is ok + # as long as the additional noise due to this approximation is "much less than" the + # noise we'll be adding to the image for the sky noise. + # + # Let's still have Ntot photons, but now each with a flux of g. And let's look at the + # noise we get in the brightest pixel that has a nominal total flux of Imax. + # + # The number of photons hitting this pixel will be Imax/flux * Ntot. + # The variance of this number is the same thing (Poisson counting). + # So the noise in that pixel is: + # + # N^2 = Imax/flux * Ntot * g^2 + # + # And the signal in that pixel will be: + # + # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so + # g = flux / Ntot(1-2eta) + # N^2 = Imax/Ntot * flux / (1-2eta)^2 + # + # As expected, we see that lowering Ntot will increase the noise in that (and every + # other) pixel. + # The input max_extra_noise parameter is the maximum value of spurious noise we want + # to allow. + # + # So setting N^2 = Imax + nu, we get + # + # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) + # g = (1 - 2eta) * (1 + nu/Imax) + # + # Returns the total flux placed inside the image bounds by photon shooting. + # + + if flux == 0.0: + return 0, 1.0 + + # The _flux_per_photon property is (1-2eta) + # This factor will already be accounted for by the shoot function, so don't include + # that as part of our scaling here. There may be other adjustments though, so g=1 here. + mod_flux = flux / (eta_factor * eta_factor) + g = 1.0 + + # If requested, let the target flux value vary as a Poisson deviate + if poisson_flux: + pd_val = sample_poisson_flux(flux, eta_factor, rng=rng) + ratio = pd_val / abs(flux) + g *= ratio + mod_flux *= ratio + + n_photons = abs(mod_flux) + if max_extra_noise > 0.0: + gfactor = 1.0 + max_extra_noise / abs(max_sb) + n_photons /= gfactor + g *= gfactor + + # Make n_photons an integer. + iN = int(n_photons + 0.5) + + return iN, g + + +# the code below is a jax version of calculate_nphotons +# that I am not sure if we need or not. +# saving in a comment for now + +# def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): +# _n_photons, _g, _rng = jax.lax.cond( +# self.flux == 0.0, +# lambda n_photons, poisson_flux, max_extra_noise, rng: (0, 1.0, rng), +# lambda n_photons, poisson_flux, max_extra_noise, rng: self._calculate_nphotons_nonzero( +# n_photons, poisson_flux, max_extra_noise, rng +# ), +# n_photons, +# poisson_flux, +# max_extra_noise, +# rng, +# ) +# if rng is not None: +# rng._state = _rng._state +# return _n_photons, _g + + +# def _adjust_flux_g_poisson(self, poisson_flux, flux, mod_flux, eta_factor, rng, g): +# from jax_galsim.random import PoissonDeviate + +# # If we have both positive and negative photons, then the mix of these +# # already gives us some variation in the flux value from the variance +# # of how many are positive and how many are negative. +# # The number of negative photons varies as a binomial distribution. +# # = eta * Ntot * g +# # = (1-eta) * Ntot * g +# # = (1-2eta) * Ntot * g = flux +# # Var(F-) = eta * (1-eta) * Ntot * g^2 +# # F+ = Ntot * g - F- is not an independent variable, so +# # Var(F+ - F-) = Var(Ntot*g - 2*F-) +# # = 4 * Var(F-) +# # = 4 * eta * (1-eta) * Ntot * g^2 +# # = 4 * eta * (1-eta) * flux +# # We want the variance to be equal to flux, so we need an extra: +# # delta Var = (1 - 4*eta + 4*eta^2) * flux +# # = (1-2eta)^2 * flux +# absflux = abs(flux) +# mean = eta_factor * eta_factor * absflux +# pd = PoissonDeviate(rng, mean) +# pd_val = pd() - mean + absflux +# ratio = pd_val / absflux +# g *= ratio +# mod_flux *= ratio +# return jnp.abs(mod_flux), g, rng + + +# def _scale_extra_noise(self, max_extra_noise, mod_flux, g, max_sb): +# gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb) +# mod_flux /= gfactor +# g *= gfactor +# return mod_flux, g + + +# def _calculate_nphotons_nonzero(self, n_photons, poisson_flux, max_extra_noise, rng): +# # For profiles that are positive definite, then N = flux. Easy. +# # +# # However, some profiles shoot some of their photons with negative flux. This means that +# # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the +# # fraction of shot photons that have negative flux. +# # +# # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 +# # N^2 = Var(S) = (N+ + N-) = Ntot +# # +# # So flux = (S/N)^2 = Ntot (1-2eta)^2 +# # Ntot = flux / (1-2eta)^2 +# # +# # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). +# # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right +# # total flux. +# # +# # That's all the easy case. The trickier case is when we are sky-background dominated. +# # Then we can usually get away with fewer shot photons than the above. In particular, +# # if the noise from the photon shooting is much less than the sky noise, then we can +# # use fewer shot photons and essentially have each photon have a flux > 1. This is ok +# # as long as the additional noise due to this approximation is "much less than" the +# # noise we'll be adding to the image for the sky noise. +# # +# # Let's still have Ntot photons, but now each with a flux of g. And let's look at the +# # noise we get in the brightest pixel that has a nominal total flux of Imax. +# # +# # The number of photons hitting this pixel will be Imax/flux * Ntot. +# # The variance of this number is the same thing (Poisson counting). +# # So the noise in that pixel is: +# # +# # N^2 = Imax/flux * Ntot * g^2 +# # +# # And the signal in that pixel will be: +# # +# # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so +# # g = flux / Ntot(1-2eta) +# # N^2 = Imax/Ntot * flux / (1-2eta)^2 +# # +# # As expected, we see that lowering Ntot will increase the noise in that (and every +# # other) pixel. +# # The input max_extra_noise parameter is the maximum value of spurious noise we want +# # to allow. +# # +# # So setting N^2 = Imax + nu, we get +# # +# # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) +# # g = (1 - 2eta) * (1 + nu/Imax) +# # +# # Returns the total flux placed inside the image bounds by photon shooting. +# # + +# flux = self.flux + +# # The _flux_per_photon property is (1-2eta) +# # This factor will already be accounted for by the shoot function, so don't include +# # that as part of our scaling here. There may be other adjustments though, so g=1 here. +# eta_factor = self._flux_per_photon +# mod_flux = flux / (eta_factor * eta_factor) +# g = 1.0 + +# # If requested, let the target flux value vary as a Poisson deviate +# mod_flux, g, _rng = jax.lax.cond( +# poisson_flux, +# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: self._adjust_flux_g_poisson( +# poisson_flux, flux, mod_flux, eta_factor, rng, g +# ), +# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: (mod_flux, g, rng), +# poisson_flux, +# flux, +# mod_flux, +# eta_factor, +# rng, +# g, +# ) +# if rng is not None: +# rng._state = _rng._state + +# mod_flux, g = jax.lax.cond( +# max_extra_noise > 0.0, +# lambda max_extra_noise, mod_flux, g, max_sb: self._scale_extra_noise( +# max_extra_noise, mod_flux, g, max_sb +# ), +# lambda max_extra_noise, mod_flux, g, max_sb: (mod_flux, g), +# max_extra_noise, +# mod_flux, +# g, +# self.max_sb, +# ) + +# # Make n_photons an integer and use input if requested +# n_photons = jax.lax.cond( +# n_photons == 0.0, +# lambda n_photons, mod_flux: jnp.ceil(mod_flux).astype(int), +# lambda n_photons, mod_flux: jnp.ceil(n_photons).astype(int), +# n_photons, +# mod_flux, +# ) + +# return n_photons, g, rng diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 9b3aa795..6d893e4d 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1,3 +1,5 @@ +from functools import partial + import galsim as _galsim import jax import jax.numpy as jnp @@ -5,6 +7,7 @@ from jax._src.numpy.util import _wraps import jax_galsim.photon_array as pa +from jax_galsim.core.draw import calculate_n_photons, sample_poisson_flux from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.errors import ( GalSimError, @@ -15,8 +18,9 @@ galsim_warn, ) from jax_galsim.gsparams import GSParams +from jax_galsim.photon_array import PhotonArray from jax_galsim.position import Position, PositionD, PositionI -from jax_galsim.random import BaseDeviate, PoissonDeviate +from jax_galsim.random import BaseDeviate from jax_galsim.sensor import Sensor from jax_galsim.utilities import parse_pos_args @@ -723,7 +727,7 @@ def drawImage( flux_scale /= local_wcs.pixelArea() # Only do the gain here if not photon shooting, since need the number of photons to # reflect that actual photons, not ADU. - if gain != 1 and method != "phot" and sensor is None: + if method != "phot" and sensor is None and gain != 1: flux_scale /= gain # Determine the offset, and possibly fix the centering for even-sized images @@ -1100,105 +1104,24 @@ def _drawKImage( @_wraps(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - # For profiles that are positive definite, then N = flux. Easy. - # - # However, some profiles shoot some of their photons with negative flux. This means that - # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the - # fraction of shot photons that have negative flux. - # - # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 - # N^2 = Var(S) = (N+ + N-) = Ntot - # - # So flux = (S/N)^2 = Ntot (1-2eta)^2 - # Ntot = flux / (1-2eta)^2 - # - # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). - # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right - # total flux. - # - # That's all the easy case. The trickier case is when we are sky-background dominated. - # Then we can usually get away with fewer shot photons than the above. In particular, - # if the noise from the photon shooting is much less than the sky noise, then we can - # use fewer shot photons and essentially have each photon have a flux > 1. This is ok - # as long as the additional noise due to this approximation is "much less than" the - # noise we'll be adding to the image for the sky noise. - # - # Let's still have Ntot photons, but now each with a flux of g. And let's look at the - # noise we get in the brightest pixel that has a nominal total flux of Imax. - # - # The number of photons hitting this pixel will be Imax/flux * Ntot. - # The variance of this number is the same thing (Poisson counting). - # So the noise in that pixel is: - # - # N^2 = Imax/flux * Ntot * g^2 - # - # And the signal in that pixel will be: - # - # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so - # g = flux / Ntot(1-2eta) - # N^2 = Imax/Ntot * flux / (1-2eta)^2 - # - # As expected, we see that lowering Ntot will increase the noise in that (and every - # other) pixel. - # The input max_extra_noise parameter is the maximum value of spurious noise we want - # to allow. - # - # So setting N^2 = Imax + nu, we get - # - # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) - # g = (1 - 2eta) * (1 + nu/Imax) - # - # Returns the total flux placed inside the image bounds by photon shooting. - # - - flux = self.flux - if flux == 0.0: - return 0, 1.0 - - # The _flux_per_photon property is (1-2eta) - # This factor will already be accounted for by the shoot function, so don't include - # that as part of our scaling here. There may be other adjustments though, so g=1 here. - eta_factor = self._flux_per_photon - mod_flux = flux / (eta_factor * eta_factor) - g = 1.0 - - # If requested, let the target flux value vary as a Poisson deviate - if poisson_flux: - # If we have both positive and negative photons, then the mix of these - # already gives us some variation in the flux value from the variance - # of how many are positive and how many are negative. - # The number of negative photons varies as a binomial distribution. - # = eta * Ntot * g - # = (1-eta) * Ntot * g - # = (1-2eta) * Ntot * g = flux - # Var(F-) = eta * (1-eta) * Ntot * g^2 - # F+ = Ntot * g - F- is not an independent variable, so - # Var(F+ - F-) = Var(Ntot*g - 2*F-) - # = 4 * Var(F-) - # = 4 * eta * (1-eta) * Ntot * g^2 - # = 4 * eta * (1-eta) * flux - # We want the variance to be equal to flux, so we need an extra: - # delta Var = (1 - 4*eta + 4*eta^2) * flux - # = (1-2eta)^2 * flux - absflux = abs(flux) - mean = eta_factor * eta_factor * absflux - pd = PoissonDeviate(rng, mean) - pd_val = pd() - mean + absflux - ratio = pd_val / absflux - g *= ratio - mod_flux *= ratio - if n_photons == 0.0: - n_photons = abs(mod_flux) - if max_extra_noise > 0.0: - gfactor = 1.0 + max_extra_noise / abs(self.max_sb) - n_photons /= gfactor - g *= gfactor - - # Make n_photons an integer. - iN = int(n_photons + 0.5) + Ntot, g = calculate_n_photons( + self.flux, + self._flux_per_photon, + self.max_sb, + rng=rng, + max_extra_noise=max_extra_noise, + poisson_flux=poisson_flux, + ) + else: + Ntot = int(n_photons + 0.5) + if poisson_flux: + pd_val = sample_poisson_flux(self.flux, self._flux_per_photon, rng=rng) + g = pd_val / jnp.abs(self.flux) + else: + g = 1.0 - return iN, g + return Ntot, g @_wraps( _galsim.GSObject.makePhot, @@ -1248,8 +1171,13 @@ def makePhot( "Deconvolve objects.\nOriginal error: %r" % (e) ) - if g != 1.0: - photons.scaleFlux(g) + photons = jax.lax.cond( + g == 1.0, + lambda photons, g: photons, + lambda photons, g: photons.scaleFlux(g), + photons, + g, + ) for op in photon_ops: op.applyTo(photons, local_wcs, rng) @@ -1304,12 +1232,13 @@ def drawPhot( Ntot, g = self._calculate_nphotons( n_photons, poisson_flux, max_extra_noise, rng ) - - if gain != 1.0: - g /= gain - - # total flux falling inside image bounds, this will be returned on exit. - added_flux = 0.0 + g = jax.lax.cond( + gain != 1.0, + lambda g, gain: g / gain, + lambda g, gain: g, + g, + gain, + ) if maxN is None: maxN = Ntot @@ -1317,52 +1246,41 @@ def drawPhot( if not add_to_image: image.setZero() - # Nleft is the number of photons remaining to shoot. - Nleft = Ntot - photons = None # Just in case Nleft is already 0. - resume = False - while Nleft > 0: - # Shoot at most maxN at a time - thisN = min(maxN, Nleft) - - try: - photons = self.shoot(thisN, rng) - except (GalSimError, NotImplementedError) as e: - raise GalSimNotImplementedError( - "Unable to draw this GSObject with photon shooting. Perhaps it " - "is a Deconvolve or is a compound including one or more " - "Deconvolve objects.\nOriginal error: %r" % (e) - ) - - if g != 1.0 or thisN != Ntot: - photons.scaleFlux(g * thisN / Ntot) - - if image.scale != 1.0: - photons.scaleXY( - 1.0 / image.scale - ) # Convert x,y to image coords if necessary - - for op in photon_ops: - op.applyTo(photons, local_wcs, rng) - - if image.dtype in (np.float32, np.float64): - added_flux += sensor.accumulate( - photons, image, orig_center, resume=resume - ) - resume = ( - True # Resume from this point if there are any further iterations. - ) - else: - # Need a temporary - from jax_galsim.image import ImageD - - im1 = ImageD(bounds=image.bounds) - added_flux += sensor.accumulate(photons, im1, orig_center) - image += im1 - - Nleft -= thisN + ( + photons, + _rng, + added_flux, + _Nleft, + _image, + _photon_ops, + _sensor, + ) = _draw_phot_while_loop( + PhotonArray(maxN), + rng, + self, + image, + g, + Ntot, + maxN, + photon_ops, + local_wcs, + sensor, + orig_center, + ) + if rng is not None: + rng._state = _rng._state + else: + rng = _rng + for i in range(len(photon_ops)): + photon_ops[i] = _photon_ops[i] + image._array = _image._array + # TODO: how to update the sensor? + if sensor.__class__ is not Sensor: + raise GalSimNotImplementedError( + "Non-default sensors that carry state are not yet supported in jax-galsim." + ) - return added_flux, photons + return added_flux, photons or None # Just in case Nleft is already 0. @_wraps(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): @@ -1409,3 +1327,160 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" return cls(**(children[0]), **aux_data) + + +@partial(jax.jit, static_argnames=("maxN",)) +def _draw_phot_while_loop( + photons, + rng, + obj, + image, + g, + Ntot, + maxN, + photon_ops, + local_wcs, + sensor, + orig_center, +): + def _cond_fun(args): + ( + photons, + rng, + added_flux, + obj, + Nleft, + resume, + image, + g, + photon_ops, + local_wcs, + sensor, + orig_center, + ) = args + return Nleft > 0 + + def _body_fun(args): + ( + photons, + rng, + added_flux, + obj, + Nleft, + resume, + image, + g, + photon_ops, + local_wcs, + sensor, + orig_center, + ) = args + # Shoot at most maxN at a time + thisN = jnp.minimum(maxN, Nleft) + + try: + photons = obj.shoot(maxN, rng) + except (GalSimError, NotImplementedError) as e: + raise GalSimNotImplementedError( + "Unable to draw this GSObject with photon shooting. Perhaps it " + "is a Deconvolve or is a compound including one or more " + "Deconvolve objects.\nOriginal error: %r" % (e) + ) + photons.flux = jnp.where( + jnp.arange(maxN) < thisN, + photons.flux, + 0.0, + ) + + photons = jax.lax.cond( + # weird way to say gain == 1 and thisN == Ntot + jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, + lambda photons, g, thisN, Ntot: photons, + # the factor here is (maxN / thisN) * (thisN / Ntot) = maxN / Ntot + # the first bit is that we drew maxN photons, but only thisN of them are valid + # the second bit is that we only drew thisN photons, but use a total of Ntot photons + lambda photons, g, thisN, Ntot: photons.scaleFlux(g * maxN / Ntot), + photons, + g, + thisN, + Ntot, + ) + + photons = jax.lax.cond( + image.scale != 1.0, + lambda photons, scale: photons.scaleXY( + 1.0 / scale + ), # Convert x,y to image coords if necessary + lambda photons, scale: photons, + photons, + image.scale, + ) + + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + + if image.dtype in (jnp.float32, jnp.float64): + added_flux += sensor.accumulate(photons, image, orig_center, resume=resume) + resume = True # Resume from this point if there are any further iterations. + else: + # Need a temporary + from jax_galsim.image import ImageD + + im1 = ImageD(bounds=image.bounds) + added_flux += sensor.accumulate(photons, im1, orig_center) + image += im1 + + Nleft -= thisN + + return ( + photons, + rng, + added_flux, + obj, + Nleft, + resume, + image, + g, + photon_ops, + local_wcs, + sensor, + orig_center, + ) + + added_flux = jnp.array(0) + Nleft = jnp.array(Ntot) + resume = jnp.array(False) + rng = BaseDeviate(rng) + ( + photons, + rng, + added_flux, + obj, + Nleft, + resume, + image, + g, + photon_ops, + local_wcs, + sensor, + orig_center, + ) = jax.lax.while_loop( + _cond_fun, + _body_fun, + ( + photons, + rng, + added_flux, + obj, + Nleft, + resume, + image, + g, + photon_ops, + local_wcs, + sensor, + orig_center, + ), + ) + + return photons, rng, added_flux, Nleft, image, photon_ops, sensor diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 3bd9f1d2..f0db40d0 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -112,18 +112,17 @@ def _fromArrays( time=None, is_corr=False, ): - ret = cls( - x.shape[0], - x=x, - y=y, - flux=flux, - dxdz=dxdz, - dydz=dydz, - wavelength=wavelength, - pupil_u=pupil_u, - pupil_v=pupil_v, - time=time, - ) + ret = cls.__new__(cls) + ret._N = x.shape[0] + ret._x = x + ret._y = y + ret._flux = flux + ret._dxdz = dxdz + ret._dydz = dydz + ret._wave = wavelength + ret._pupil_u = pupil_u + ret._pupil_v = pupil_v + ret._time = time ret._is_corr = is_corr return ret @@ -140,13 +139,25 @@ def tree_flatten(self): "is_corr": self.isCorrelated(), }, ) - aux_data = None + aux_data = (self._N,) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls._fromArrays(*children[0], **children[1]) + ret = cls.__new__(cls) + ret._N = aux_data[0] + ret._x = children[0][0] + ret._y = children[0][1] + ret._flux = children[0][2] + ret._dxdz = children[1]["dxdz"] + ret._dydz = children[1]["dydz"] + ret._wave = children[1]["wavelength"] + ret._pupil_u = children[1]["pupil_u"] + ret._pupil_v = children[1]["pupil_v"] + ret._time = children[1]["time"] + ret._is_corr = children[1]["is_corr"] + return ret def size(self): """Return the size of the photon array. Equivalent to ``len(self)``.""" @@ -308,6 +319,8 @@ def setTotalFlux(self, flux): """ self.scaleFlux(flux / self.getTotalFlux()) + return self + def scaleFlux(self, scale): """Rescale the photon fluxes by the given factor. @@ -316,6 +329,8 @@ def scaleFlux(self, scale): """ self._flux *= scale + return self + def scaleXY(self, scale): """Scale the photon positions (`x` and `y`) by the given factor. @@ -325,6 +340,8 @@ def scaleXY(self, scale): self._x *= scale self._y *= scale + return self + def assignAt(self, istart, rhs): """Assign the contents of another `PhotonArray` to this one starting at istart.""" if istart + rhs.size() > self.size(): @@ -354,6 +371,8 @@ def assignAt(self, istart, rhs): self.allocateTimes() self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) + return self + def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. @@ -400,6 +419,8 @@ def convolve(self, rhs, rng=None): self._y = self._y + rhs._y[sinds] self._flux = self._flux * rhs._flux[sinds] * self.size() + return self + def __repr__(self): s = "galsim.PhotonArray(%d, x=array(%r), y=array(%r), flux=array(%r)" % ( int(cast_to_python_float(self.size())), @@ -492,15 +513,11 @@ def addTo(self, image): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) - # the numpy histogram function histograms x along the first dimension and y along the - # along the second dimension. We need the opposite so we swap the inputs - xbins = jnp.arange(image.bounds.xmin, image.bounds.xmax + 2) - 0.5 - ybins = jnp.arange(image.bounds.ymin, image.bounds.ymax + 2) - 0.5 - im = jnp.histogram2d( - self._y, self._x, bins=(ybins, xbins), weights=self._flux, density=False - )[0] - image._array += im - return im.sum() + xinds = jnp.floor(self._x - image.bounds.xmin).astype(int) + yinds = jnp.floor(self._y - image.bounds.ymin).astype(int) + image._array = image._array.at[yinds, xinds].add(self._flux) + + return self._flux.sum() @classmethod def makeFromImage(cls, image, max_flux=1.0, rng=None): diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 8aa9f743..19f3693e 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -1,11 +1,13 @@ import galsim as _galsim from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class from .errors import GalSimUndefinedBoundsError from .position import PositionI @_wraps(_galsim.Sensor) +@register_pytree_node_class class Sensor: def __init__(self): pass @@ -38,3 +40,12 @@ def __ne__(self, other): def __hash__(self): return hash(repr(self)) + + def tree_flatten(self): + children = tuple() + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls() diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 4c5ee846..7b5ff2a0 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -402,7 +402,11 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(children[0], **(children[1]), **aux_data) + obj = cls.__new__(cls) + obj._gsparams = aux_data["gsparams"] + obj._propagate_gsparams = aux_data["propagate_gsparams"] + obj._original, obj._params = children + return obj def _Transform( diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index dfd36a60..e38aeb6c 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -5,6 +5,7 @@ import numpy as np import jax_galsim as galsim +from jax_galsim.core.draw import calculate_n_photons from jax_galsim.core.testing import time_code_block # Defining jitting identity @@ -231,9 +232,15 @@ def _build_and_draw(hlr, fwhm, jit=True): ) n = final.getGoodImageSize(0.2).item() n += 1 - n_photons = final._calculate_nphotons(0, False, 0, None)[0] + n_photons = calculate_n_photons( + final.flux, + final._flux_per_photon, + final.max_sb, + poisson_flux=False, + )[0] + gain = 1.0 if jit: - return _draw_it_jit(final, n, n_photons) + return _draw_it_jit(final, n, n_photons, gain) else: return final.drawImage( nx=n, @@ -241,16 +248,21 @@ def _build_and_draw(hlr, fwhm, jit=True): scale=0.2, method="phot", n_photons=n_photons, + poisson_flux=False, + gain=gain, ) @partial(jax.jit, static_argnums=(1, 2)) - def _draw_it_jit(obj, n, nphotons): + def _draw_it_jit(obj, n, nphotons, gain): return obj.drawImage( nx=n, ny=n, scale=0.2, n_photons=nphotons, method="phot", + poisson_flux=False, + gain=gain, + maxN=101, ) with time_code_block("warmup no-jit"): @@ -263,4 +275,4 @@ def _draw_it_jit(obj, n, nphotons): with time_code_block("jit"): img = _build_and_draw(0.5, 1.0) - np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 3) + np.testing.assert_allclose(img.array.sum(), 1000.0) From f7ae8aac573fc48e6dfeea3be3474eaeab791f13 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 05:58:45 -0600 Subject: [PATCH 12/85] REF change repr just a bit --- jax_galsim/angle.py | 6 +++--- jax_galsim/core/utils.py | 32 +++++++++++++++++++++++++------- jax_galsim/photon_array.py | 28 +++++++++++++++------------- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index fbbeaf18..450fe588 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -22,7 +22,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_float_array_scalar, ensure_hashable +from jax_galsim.core.utils import cast_to_array_scalar, ensure_hashable @_wraps(_galsim.AngleUnit) @@ -34,7 +34,7 @@ def __init__(self, value): """ :param value: The measure of the unit in radians. """ - self._value = cast_to_float_array_scalar(value) + self._value = cast_to_array_scalar(value, dtype=float) @property def value(self): @@ -142,7 +142,7 @@ def __init__(self, theta, unit=None): raise TypeError("Invalid unit %s of type %s" % (unit, type(unit))) else: # Normal case - self._rad = cast_to_float_array_scalar(theta) * unit.value + self._rad = cast_to_array_scalar(theta, dtype=float) * unit.value @property def rad(self): diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 62db97e7..cee7e24e 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -13,24 +13,42 @@ def compute_major_minor_from_jacobian(jac): return major, minor -def cast_to_float_array_scalar(x): - """Cast the input to a float array scalar. Works on python floats, iterables and jax arrays. +def cast_to_array_scalar(x, dtype=None): + """Cast the input to an array scalar. Works on python scalars, iterables and jax arrays. For iterables it always takes the first element after a call to .ravel()""" + if dtype is None: + if hasattr(x, "dtype"): + dtype = x.dtype + else: + dtype = float + if isinstance(x, jax.Array): - return jnp.atleast_1d(x).astype(float).ravel()[0] + return jnp.atleast_1d(x).astype(dtype).ravel()[0] elif hasattr(x, "astype"): - return x.astype(float).ravel()[0] + return x.astype(dtype).ravel()[0] else: - return jnp.atleast_1d(jnp.array(x, dtype=float)).ravel()[0] + return jnp.atleast_1d(jnp.array(x, dtype=dtype)).ravel()[0] def cast_to_python_float(x): """Cast the input to a python float. Works on python floats and jax arrays. For jax arrays it always takes the first element after a call to .ravel()""" if isinstance(x, jax.Array): - return cast_to_float_array_scalar(x).item() + return cast_to_array_scalar(x, dtype=float).item() + else: + try: + return float(x) + except Exception: + return x + + +def cast_to_python_int(x): + """Cast the input to a python int. Works on python ints and jax arrays. + For jax arrays it always takes the first element after a call to .ravel()""" + if isinstance(x, jax.Array): + return cast_to_array_scalar(x, dtype=int).item() else: - return float(x) + return int(x) def cast_to_float(x): diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index f0db40d0..20e27bb4 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -4,7 +4,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_python_float +from jax_galsim.core.utils import cast_to_python_int from jax_galsim.errors import ( GalSimIncompatibleValuesError, GalSimRangeError, @@ -422,31 +422,33 @@ def convolve(self, rhs, rng=None): return self def __repr__(self): - s = "galsim.PhotonArray(%d, x=array(%r), y=array(%r), flux=array(%r)" % ( - int(cast_to_python_float(self.size())), - self.x.tolist(), - self.y.tolist(), - self.flux.tolist(), + import numpy as np + + s = "galsim.PhotonArray(%r, x=array(%r), y=array(%r), flux=array(%r)" % ( + cast_to_python_int(self.size()), + np.array(self.x).tolist(), + np.array(self.y).tolist(), + np.array(self.flux).tolist(), ) if self.hasAllocatedAngles(): s += ", dxdz=array(%r), dydz=array(%r)" % ( - self.dxdz.tolist(), - self.dydz.tolist(), + np.array(self.dxdz).tolist(), + np.array(self.dydz).tolist(), ) if self.hasAllocatedWavelengths(): - s += ", wavelength=array(%r)" % (self.wavelength.tolist()) + s += ", wavelength=array(%r)" % (np.array(self.wavelength).tolist()) if self.hasAllocatedPupil(): s += ", pupil_u=array(%r), pupil_v=array(%r)" % ( - self.pupil_u.tolist(), - self.pupil_v.tolist(), + np.array(self.pupil_u).tolist(), + np.array(self.pupil_v).tolist(), ) if self.hasAllocatedTimes(): - s += ", time=array(%r)" % (self.time.tolist()) + s += ", time=array(%r)" % np.array(self.time).tolist() s += ")" return s def __str__(self): - return "galsim.PhotonArray(%d)" % int(cast_to_python_float(self.size())) + return "galsim.PhotonArray(%r)" % cast_to_python_int(self.size()) __hash__ = None From 42c4ddb8618caae957a4125f3a217e9693308e66 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 06:19:04 -0600 Subject: [PATCH 13/85] BUG wrong pixel location --- jax_galsim/photon_array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 20e27bb4..4ebd5534 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -515,8 +515,8 @@ def addTo(self, image): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) - xinds = jnp.floor(self._x - image.bounds.xmin).astype(int) - yinds = jnp.floor(self._y - image.bounds.ymin).astype(int) + xinds = jnp.floor(self._x - image.bounds.xmin + 0.5).astype(int) + yinds = jnp.floor(self._y - image.bounds.ymin + 0.5).astype(int) image._array = image._array.at[yinds, xinds].add(self._flux) return self._flux.sum() From fe73df2d22725f7cae82fbcb0533a7d679479154 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 13:41:55 -0600 Subject: [PATCH 14/85] REF more jit in photon ops --- jax_galsim/interpolant.py | 8 +- jax_galsim/photon_array.py | 317 +++++++++++++++++++++++-------------- jax_galsim/sum.py | 94 ++++++----- tests/GalSim | 2 +- 4 files changed, 265 insertions(+), 156 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 03a2ab98..32551672 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -4,6 +4,8 @@ interpolants themselves (e.g., the coefficients that define the kernel shapes, the integrals of the kernels, etc.) are constants. """ +import math + import galsim as _galsim import jax import jax.numpy as jnp @@ -287,7 +289,7 @@ def xrange(self): @property def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" - return 2 * int(jnp.ceil(self.xrange)) + return 2 * int(math.ceil(self.xrange)) @property def krange(self): @@ -467,7 +469,9 @@ def _comp_fluxes(self): narr = jnp.arange(n) val = (si(jnp.pi * (narr + 1)) - si(jnp.pi * (narr))) / jnp.pi - self._positive_flux = jax.lax.stop_gradient(jnp.sum(val[val > 0])).item() * 2.0 + self._positive_flux = ( + jax.lax.stop_gradient(jnp.sum(jnp.where(val > 0, val, 0.0))) * 2.0 + ) self._negative_flux = self._positive_flux - 1.0 def _shoot(self, photons, rng): diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 4ebd5534..4cfd3219 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp import jax.random as jrng from jax._src.numpy.util import _wraps @@ -48,13 +49,13 @@ def __init__( self._x = jnp.zeros(self._N, dtype=float) self._y = jnp.zeros(self._N, dtype=float) self._flux = jnp.zeros(self._N, dtype=float) - self._dxdz = None - self._dydz = None - self._wave = None - self._pupil_u = None - self._pupil_v = None - self._time = None - self._is_corr = False + self._dxdz = jnp.full(self._N, jnp.nan, dtype=float) + self._dydz = jnp.full(self._N, jnp.nan, dtype=float) + self._wave = jnp.full(self._N, jnp.nan, dtype=float) + self._pupil_u = jnp.full(self._N, jnp.nan, dtype=float) + self._pupil_v = jnp.full(self._N, jnp.nan, dtype=float) + self._time = jnp.full(self._N, jnp.nan, dtype=float) + self._is_corr = jnp.array(False) if x is not None: self.x = x @@ -114,16 +115,34 @@ def _fromArrays( ): ret = cls.__new__(cls) ret._N = x.shape[0] - ret._x = x - ret._y = y - ret._flux = flux - ret._dxdz = dxdz - ret._dydz = dydz - ret._wave = wavelength - ret._pupil_u = pupil_u - ret._pupil_v = pupil_v - ret._time = time - ret._is_corr = is_corr + ret._x = x.copy() + ret._y = y.copy() + ret._flux = flux.copy() + ret._dxdz = ( + dxdz.copy() if dxdz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._dydz = ( + dydz.copy() if dydz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._wave = ( + wavelength.copy() + if wavelength is not None + else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._pupil_u = ( + pupil_u.copy() + if pupil_u is not None + else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._pupil_v = ( + pupil_v.copy() + if pupil_v is not None + else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._time = ( + time.copy() if time is not None else jnp.full(ret._N, jnp.nan, dtype=float) + ) + ret._is_corr = jnp.array(is_corr) return ret def tree_flatten(self): @@ -204,8 +223,13 @@ def dxdz(self): @dxdz.setter def dxdz(self, value): - self.allocateAngles() self._dxdz = self._dxdz.at[:].set(value) + self._dydz = jax.lax.cond( + jnp.any(jnp.isfinite(self._dxdz)) & jnp.all(~jnp.isfinite(self._dydz)), + lambda dydz: jnp.zeros_like(dydz), + lambda dydz: dydz, + self._dydz, + ) @property def dydz(self): @@ -214,8 +238,13 @@ def dydz(self): @dydz.setter def dydz(self, value): - self.allocateAngles() self._dydz = self._dydz.at[:].set(value) + self._dxdz = jax.lax.cond( + jnp.any(jnp.isfinite(self._dydz)) & jnp.all(~jnp.isfinite(self._dxdz)), + lambda dxdz: jnp.zeros_like(dxdz), + lambda dxdz: dxdz, + self._dxdz, + ) @property def wavelength(self): @@ -224,7 +253,6 @@ def wavelength(self): @wavelength.setter def wavelength(self, value): - self.allocateWavelengths() self._wave = self._wave.at[:].set(value) @property @@ -234,8 +262,14 @@ def pupil_u(self): @pupil_u.setter def pupil_u(self, value): - self.allocatePupil() self._pupil_u = self._pupil_u.at[:].set(value) + self._pupil_v = jax.lax.cond( + jnp.any(jnp.isfinite(self._pupil_u)) + & jnp.all(~jnp.isfinite(self._pupil_v)), + lambda pupil_v: jnp.zeros_like(pupil_v), + lambda pupil_v: pupil_v, + self._pupil_v, + ) @property def pupil_v(self): @@ -244,8 +278,14 @@ def pupil_v(self): @pupil_v.setter def pupil_v(self, value): - self.allocatePupil() self._pupil_v = self._pupil_v.at[:].set(value) + self._pupil_u = jax.lax.cond( + jnp.any(jnp.isfinite(self._pupil_v)) + & jnp.all(~jnp.isfinite(self._pupil_u)), + lambda pupil_u: jnp.zeros_like(pupil_u), + lambda pupil_u: pupil_u, + self._pupil_u, + ) @property def time(self): @@ -254,50 +294,43 @@ def time(self): @time.setter def time(self, value): - self.allocateTimes() self._time = self._time.at[:].set(value) def hasAllocatedAngles(self): """Returns whether the arrays for the incidence angles `dxdz` and `dydz` have been allocated. """ - return self._dxdz is not None and self._dydz is not None + return jnp.any(jnp.isfinite(self.dxdz) | jnp.isfinite(self.dydz)) def allocateAngles(self): """Allocate memory for the incidence angles, `dxdz` and `dydz`.""" - if not self.hasAllocatedAngles(): - self._dxdz = jnp.zeros(self._N, dtype=float) - self._dydz = jnp.zeros(self._N, dtype=float) + pass def hasAllocatedWavelengths(self): """Returns whether the `wavelength` array has been allocated.""" - return self._wave is not None + return jnp.any(jnp.isfinite(self.wavelength)) def allocateWavelengths(self): """Allocate the memory for the `wavelength` array.""" - if not self.hasAllocatedWavelengths(): - self._wave = jnp.zeros(self._N, dtype=float) + pass def hasAllocatedPupil(self): """Returns whether the arrays for the pupil coordinates `pupil_u` and `pupil_v` have been allocated. """ - return self._pupil_u is not None and self._pupil_v is not None + return jnp.any(jnp.isfinite(self.pupil_u) | jnp.isfinite(self.pupil_v)) def allocatePupil(self): """Allocate the memory for the pupil coordinates, `pupil_u` and `pupil_v`.""" - if not self.hasAllocatedPupil(): - self._pupil_u = jnp.zeros(self._N, dtype=float) - self._pupil_v = jnp.zeros(self._N, dtype=float) + pass def hasAllocatedTimes(self): """Returns whether the array for the time stamps `time` has been allocated.""" - return self._time is not None + return jnp.any(jnp.isfinite(self.time)) def allocateTimes(self): """Allocate the memory for the time stamps, `time`.""" - if not self.hasAllocatedTimes(): - self._time = jnp.zeros(self._N, dtype=float) + return True def isCorrelated(self): """Returns whether the photons are correlated""" @@ -305,7 +338,7 @@ def isCorrelated(self): def setCorrelated(self, is_corr=True): """Set whether the photons are correlated""" - self._is_corr = is_corr + self._is_corr = jnp.array(is_corr, dtype=bool) def getTotalFlux(self): """Return the total flux of all the photons.""" @@ -352,27 +385,49 @@ def assignAt(self, istart, rhs): self._x = self._x.at[istart : istart + rhs.size()].set(rhs.x) self._y = self._y.at[istart : istart + rhs.size()].set(rhs.y) self._flux = self._flux.at[istart : istart + rhs.size()].set(rhs.flux) - if rhs.hasAllocatedAngles(): - self.allocateAngles() - self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) - self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) - if rhs.hasAllocatedWavelengths(): - self.allocateWavelengths() - self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) - if rhs.hasAllocatedPupil(): - self.allocatePupil() - self._pupil_u = self._pupil_u.at[istart : istart + rhs.size()].set( - rhs.pupil_u - ) - self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set( - rhs.pupil_v - ) - if rhs.hasAllocatedTimes(): - self.allocateTimes() - self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) + self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) + self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) + self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) + self._pupil_u = self._pupil_u.at[istart : istart + rhs.size()].set(rhs.pupil_u) + self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set(rhs.pupil_v) + self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) + + return self + + def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): + """Assign the contents of another `PhotonArray` to this one at locations + where cat_ind == cat_ind_to_assign. + """ + msk = cat_ind_to_assign == cat_inds + self._x = jnp.where(msk, rhs._x, self._x) + self._y = jnp.where(msk, rhs._y, self._y) + self._flux = jnp.where(msk, rhs._flux, self._flux) + + self._dxdz = jnp.where(msk, rhs._dxdz, self._dxdz) + self._dydz = jnp.where(msk, rhs._dydz, self._dydz) + self._wave = jnp.where(msk, rhs._wave, self._wave) + self._pupil_u = jnp.where(msk, rhs._pupil_u, self._pupil_u) + self._pupil_v = jnp.where(msk, rhs._pupil_v, self._pupil_v) + self._time = jnp.where(msk, rhs._time, self._time) return self + @classmethod + def _stack_photon_arrays_to_dict_of_matrices(cls, photon_arrays): + ret = { + "x": jnp.stack([pa.x for pa in photon_arrays]), + "y": jnp.stack([pa.y for pa in photon_arrays]), + "flux": jnp.stack([pa.flux for pa in photon_arrays]), + "is_corr": jnp.stack([pa.isCorrelated() for pa in photon_arrays]), + "dxdz": jnp.stack([pa.dxdz for pa in photon_arrays]), + "dydz": jnp.stack([pa.dydz for pa in photon_arrays]), + "wavelength": jnp.stack([pa.wavelength for pa in photon_arrays]), + "pupil_u": jnp.stack([pa.pupil_u for pa in photon_arrays]), + "pupil_v": jnp.stack([pa.pupil_v for pa in photon_arrays]), + "time": jnp.stack([pa.time for pa in photon_arrays]), + } + return ret + def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. @@ -386,38 +441,82 @@ def convolve(self, rhs, rng=None): "PhotonArray.convolve with unequal size arrays", self_pa=self, rhs=rhs ) - if rhs.isCorrelated() and self.isCorrelated(): - rng = BaseDeviate(rng) - subkey = rng._state.split_one() - sinds = jrng.choice( - subkey, - self.size(), - shape=(self.size(),), - replace=False, - ) - else: - sinds = jnp.arange(self.size()) + rng = BaseDeviate(rng) + rsinds = jrng.choice( + rng._state.split_one(), + self.size(), + shape=(self.size(),), + replace=False, + ) + nrsinds = jnp.arange(self.size()) + + sinds = jax.lax.cond( + jnp.array(self.isCorrelated()) & jnp.array(rhs.isCorrelated()), + lambda nrsinds, rsinds: rsinds, + lambda nrsinds, rsinds: nrsinds, + nrsinds, + rsinds, + ) - if rhs.hasAllocatedAngles() and not self.hasAllocatedAngles(): - self.dxdz = rhs.dxdz[sinds] - self.dydz = rhs.dydz[sinds] + self.dxdz, self.dydz = jax.lax.cond( + rhs.hasAllocatedAngles() & (~self.hasAllocatedAngles()), + lambda self_dxdz, rhs_dxdz, self_dydz, rhs_dydz, sinds: ( + rhs_dxdz.at[sinds].get(), + rhs_dydz.at[sinds].get(), + ), + lambda self_dxdz, rhs_dxdz, self_dydz, rhs_dydz, sinds: ( + self_dxdz, + self_dydz, + ), + self.dxdz, + rhs.dxdz, + self.dydz, + rhs.dydz, + sinds, + ) - if rhs.hasAllocatedWavelengths() and not self.hasAllocatedWavelengths(): - self.wavelength = rhs.wavelength + self.wavelength = jax.lax.cond( + rhs.hasAllocatedWavelengths() & (~self.hasAllocatedWavelengths()), + lambda self_wave, rhs_wave, sinds: rhs_wave.at[sinds].get(), + lambda self_wave, rhs_wave, sinds: self_wave, + self.wavelength, + rhs.wavelength, + sinds, + ) - if rhs.hasAllocatedPupil() and not self.hasAllocatedPupil(): - self.pupil_u = rhs.pupil_u[sinds] - self.pupil_v = rhs.pupil_v[sinds] + self.pupil_u, self.pupil_v = jax.lax.cond( + rhs.hasAllocatedPupil() & (~self.hasAllocatedPupil()), + lambda self_pupil_u, rhs_pupil_u, self_pupil_v, rhs_pupil_v, sinds: ( + rhs_pupil_u.at[sinds].get(), + rhs_pupil_v.at[sinds].get(), + ), + lambda self_pupil_u, rhs_pupil_u, self_pupil_v, rhs_pupil_v, sinds: ( + self_pupil_u, + self_pupil_v, + ), + self.pupil_u, + rhs.pupil_u, + self.pupil_v, + rhs.pupil_v, + sinds, + ) - if rhs.hasAllocatedTimes() and not self.hasAllocatedTimes(): - self.time = rhs.time[sinds] + self.time = jax.lax.cond( + rhs.hasAllocatedTimes() & (~self.hasAllocatedTimes()), + lambda self_time, rhs_time, sinds: rhs_time.at[sinds].get(), + lambda self_time, rhs_time, sinds: self_time, + self.time, + rhs.time, + sinds, + ) - if rhs.isCorrelated(): - self.setCorrelated() + self.setCorrelated( + jnp.array(self.isCorrelated()) | jnp.array(rhs.isCorrelated()) + ) - self._x = self._x + rhs._x[sinds] - self._y = self._y + rhs._y[sinds] - self._flux = self._flux * rhs._flux[sinds] * self.size() + self._x = self._x + rhs._x.at[sinds].get() + self._y = self._y + rhs._y.at[sinds].get() + self._flux = self._flux * rhs._flux.at[sinds].get() * self.size() return self @@ -458,40 +557,12 @@ def __eq__(self, other): and jnp.array_equal(self.x, other.x) and jnp.array_equal(self.y, other.y) and jnp.array_equal(self.flux, other.flux) - and self.hasAllocatedAngles() == other.hasAllocatedAngles() - and self.hasAllocatedWavelengths() == other.hasAllocatedWavelengths() - and self.hasAllocatedPupil() == other.hasAllocatedPupil() - and self.hasAllocatedTimes() == other.hasAllocatedTimes() - and ( - jnp.array_equal(self.dxdz, other.dxdz) - if self.hasAllocatedAngles() - else True - ) - and ( - jnp.array_equal(self.dydz, other.dydz) - if self.hasAllocatedAngles() - else True - ) - and ( - jnp.array_equal(self.wavelength, other.wavelength) - if self.hasAllocatedWavelengths() - else True - ) - and ( - jnp.array_equal(self.pupil_u, other.pupil_u) - if self.hasAllocatedPupil() - else True - ) - and ( - jnp.array_equal(self.pupil_v, other.pupil_v) - if self.hasAllocatedPupil() - else True - ) - and ( - jnp.array_equal(self.time, other.time) - if self.hasAllocatedTimes() - else True - ) + and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) + and jnp.array_equal(self.dydz, other.dydz, equal_nan=True) + and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) + and jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True) + and jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True) + and jnp.array_equal(self.time, other.time, equal_nan=True) ) def __ne__(self, other): @@ -517,7 +588,19 @@ def addTo(self, image): ) xinds = jnp.floor(self._x - image.bounds.xmin + 0.5).astype(int) yinds = jnp.floor(self._y - image.bounds.ymin + 0.5).astype(int) - image._array = image._array.at[yinds, xinds].add(self._flux) + # the jax documentation says that they drop out of bounds indices, + # but the galsim unit tests reveal that withoout the check below, + # the indices are not dropped. + # I think maybe it is only indices beyond the end of the array that are + # dropped and negative indices wrap around + good = ( + (xinds >= 0) + & (xinds < image.array.shape[1]) + & (yinds >= 0) + & (yinds < image.array.shape[0]) + ) + flux = jnp.where(good, self._flux, 0.0) + image._array = image._array.at[yinds, xinds].add(flux) return self._flux.sum() diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 1454c0bf..d74676c5 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp import numpy as np from jax._src.numpy.util import _wraps @@ -6,8 +7,9 @@ from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams +from jax_galsim.photon_array import PhotonArray from jax_galsim.position import PositionD -from jax_galsim.random import BinomialDeviate +from jax_galsim.random import BaseDeviate @_wraps( @@ -188,43 +190,63 @@ def _flux_per_photon(self): return self._calculate_flux_per_photon() def _shoot(self, photons, rng): - remainingAbsoluteFlux = self.positive_flux + self.negative_flux - fluxPerPhoton = remainingAbsoluteFlux / len(photons) - - remainingN = len(photons) - istart = ( - 0 # The location in the photons array where we assign the component arrays. + tot_flux = self.positive_flux + self.negative_flux + fluxes = jnp.array( + [obj.positive_flux + obj.negative_flux for obj in self.obj_list] + ) + rng = BaseDeviate(rng) + key = rng._state.split_one() + cat_inds = jax.random.choice( + key, + len(self.obj_list), + shape=(len(photons),), + replace=True, + p=fluxes / tot_flux, ) + photon_arrays = [obj.shoot(photons.size(), rng=rng) for obj in self.obj_list] + + def _body_fun(i, args): + photons, rng, pa_dict, fluxes, cat_inds, tot_flux = args + thisAbsoluteFlux = jax.lax.dynamic_index_in_dim(fluxes, i, keepdims=False) + pa = PhotonArray._fromArrays( + x=jax.lax.dynamic_index_in_dim(pa_dict["x"], i, keepdims=False), + y=jax.lax.dynamic_index_in_dim(pa_dict["y"], i, keepdims=False), + flux=( + jax.lax.dynamic_index_in_dim(pa_dict["flux"], i, keepdims=False) + * (tot_flux / thisAbsoluteFlux) + ), + dxdz=jax.lax.dynamic_index_in_dim(pa_dict["dxdz"], i, keepdims=False), + dydz=jax.lax.dynamic_index_in_dim(pa_dict["dydz"], i, keepdims=False), + wavelength=jax.lax.dynamic_index_in_dim( + pa_dict["wavelength"], i, keepdims=False + ), + pupil_u=jax.lax.dynamic_index_in_dim( + pa_dict["pupil_u"], i, keepdims=False + ), + pupil_v=jax.lax.dynamic_index_in_dim( + pa_dict["pupil_v"], i, keepdims=False + ), + time=jax.lax.dynamic_index_in_dim(pa_dict["time"], i, keepdims=False), + is_corr=jnp.any(pa_dict["is_corr"], axis=0), + ) + photons._assign_from_categorical_index(cat_inds, i, pa) + return photons, rng, pa_dict, fluxes, cat_inds, tot_flux - # Get photons from each summand, using BinomialDeviate to randomize - # the distribution of photons among summands - for i, obj in enumerate(self.obj_list): - thisAbsoluteFlux = obj.positive_flux + obj.negative_flux - - # How many photons to shoot from this summand? - thisN = remainingN # All of what's left, if this is the last summand... - if i < len(self.obj_list) - 1: - # otherwise, allocate a randomized fraction of the remaining photons to summand. - bd = BinomialDeviate( - rng, remainingN, thisAbsoluteFlux / remainingAbsoluteFlux - ) - thisN = int(bd()) - if thisN > 0: - thisPA = obj.shoot(thisN, rng) - # Now rescale the photon fluxes so that they are each nominally fluxPerPhoton - # whereas the shoot() routine would have made them each nominally - # thisAbsoluteFlux/thisN - thisPA.scaleFlux(fluxPerPhoton * thisN / thisAbsoluteFlux) - photons.assignAt(istart, thisPA) - istart += thisN - remainingN -= thisN - remainingAbsoluteFlux -= thisAbsoluteFlux - # assert remainingN == 0 - # assert np.isclose(remainingAbsoluteFlux, 0.0) - - # This process produces correlated photons, so mark the resulting array as such. - if len(self.obj_list) > 1: - photons.setCorrelated() + _photons, _rng = jax.lax.fori_loop( + 0, + len(self.obj_list), + _body_fun, + ( + photons, + BaseDeviate(rng), + PhotonArray._stack_photon_arrays_to_dict_of_matrices(photon_arrays), + fluxes, + cat_inds, + tot_flux, + ), + )[0:2] + rng._state = _rng._state + photons.assignAt(0, _photons) def tree_flatten(self): """This function flattens the GSObject into a list of children diff --git a/tests/GalSim b/tests/GalSim index 3e308a21..710cca28 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 3e308a2194f8a3d08e811046634d5f115fc54356 +Subproject commit 710cca286c5fcd229d1c309aaf6e5c61ec81f9dc From 869443603d41576cc6b1e7bc481c0fdaefa40191 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 13:44:23 -0600 Subject: [PATCH 15/85] TST add sum to jit test --- tests/jax/test_jitting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index e38aeb6c..bb410b4c 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -225,7 +225,10 @@ def _draw_it_jit(obj, n, nfft): def test_jitting_draw_phot(): def _build_and_draw(hlr, fwhm, jit=True): - gal = galsim.Exponential(half_light_radius=hlr, flux=1000.0) + gal = ( + galsim.Exponential(half_light_radius=hlr, flux=1000.0) + + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) + ) psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) final = galsim.Convolve( [gal, psf], @@ -275,4 +278,4 @@ def _draw_it_jit(obj, n, nphotons, gain): with time_code_block("jit"): img = _build_and_draw(0.5, 1.0) - np.testing.assert_allclose(img.array.sum(), 1000.0) + np.testing.assert_allclose(img.array.sum(), 1100.0) From 026a03d3faff32f637044f1bfdd5f5624aab75d1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 14:18:44 -0600 Subject: [PATCH 16/85] STY blacken --- tests/jax/test_jitting.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index bb410b4c..77f17c38 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -225,10 +225,9 @@ def _draw_it_jit(obj, n, nfft): def test_jitting_draw_phot(): def _build_and_draw(hlr, fwhm, jit=True): - gal = ( - galsim.Exponential(half_light_radius=hlr, flux=1000.0) - + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) - ) + gal = galsim.Exponential( + half_light_radius=hlr, flux=1000.0 + ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) final = galsim.Convolve( [gal, psf], From d9a9d4c6ac2bb4d6f2baa016337fbfb6e9347dbe Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 17 Nov 2023 14:22:12 -0600 Subject: [PATCH 17/85] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index cee7e24e..fc34b818 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -48,7 +48,10 @@ def cast_to_python_int(x): if isinstance(x, jax.Array): return cast_to_array_scalar(x, dtype=int).item() else: - return int(x) + try: + return int(x) + except Exception: + return x def cast_to_float(x): From 7cc93841584e15a66dc8a56f03600b4333392a5c Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 17 Nov 2023 14:23:24 -0600 Subject: [PATCH 18/85] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index fc34b818..f69afaf4 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -50,7 +50,7 @@ def cast_to_python_int(x): else: try: return int(x) - except Exception: + except TypeError: return x From 110442f2a1521e703f6216ac345d3de8976d0bbf Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 17 Nov 2023 14:23:52 -0600 Subject: [PATCH 19/85] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f69afaf4..c499929c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -38,7 +38,7 @@ def cast_to_python_float(x): else: try: return float(x) - except Exception: + except TypeError: return x From 0cb4de9298fce36ba3e0f364d02c55dc977adc9f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 17 Nov 2023 14:30:29 -0600 Subject: [PATCH 20/85] Update jax_galsim/noise.py --- jax_galsim/noise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 96280f48..5d517d19 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -510,7 +510,6 @@ def copy(self, rng=None): # kind of deviate, but just reset it to follow the given rng. dev = self.rng.duplicate() dev.reset(rng) - print(repr(dev), repr(self.rng)) return DeviateNoise(dev) def __repr__(self): From f2ea9b290078aa087d9385c30a6e31ba56efe4c3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Nov 2023 19:33:23 -0600 Subject: [PATCH 21/85] REF simpler --- jax_galsim/photon_array.py | 16 --------- jax_galsim/random.py | 2 +- jax_galsim/sum.py | 74 ++++++++++++++++---------------------- tests/jax/test_jitting.py | 5 +++ 4 files changed, 36 insertions(+), 61 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 4cfd3219..81c75876 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -412,22 +412,6 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): return self - @classmethod - def _stack_photon_arrays_to_dict_of_matrices(cls, photon_arrays): - ret = { - "x": jnp.stack([pa.x for pa in photon_arrays]), - "y": jnp.stack([pa.y for pa in photon_arrays]), - "flux": jnp.stack([pa.flux for pa in photon_arrays]), - "is_corr": jnp.stack([pa.isCorrelated() for pa in photon_arrays]), - "dxdz": jnp.stack([pa.dxdz for pa in photon_arrays]), - "dydz": jnp.stack([pa.dydz for pa in photon_arrays]), - "wavelength": jnp.stack([pa.wavelength for pa in photon_arrays]), - "pupil_u": jnp.stack([pa.pupil_u for pa in photon_arrays]), - "pupil_v": jnp.stack([pa.pupil_v for pa in photon_arrays]), - "time": jnp.stack([pa.time for pa in photon_arrays]), - } - return ret - def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. diff --git a/jax_galsim/random.py b/jax_galsim/random.py index f3af980e..1b4ce1f4 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -96,7 +96,7 @@ def reset(self, seed=None): self._state = seed elif isinstance(seed, BaseDeviate): self._state = seed._state - elif isinstance(seed, jax.Array): + elif isinstance(seed, jax.Array) and seed.shape == (2,): self._state = _DeviateState(wrap_key_data(seed)) elif isinstance(seed, str): self._state = _DeviateState( diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index d74676c5..7879264a 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -7,7 +7,6 @@ from jax_galsim.gsobject import GSObject from jax_galsim.gsparams import GSParams -from jax_galsim.photon_array import PhotonArray from jax_galsim.position import PositionD from jax_galsim.random import BaseDeviate @@ -203,50 +202,37 @@ def _shoot(self, photons, rng): replace=True, p=fluxes / tot_flux, ) - photon_arrays = [obj.shoot(photons.size(), rng=rng) for obj in self.obj_list] - - def _body_fun(i, args): - photons, rng, pa_dict, fluxes, cat_inds, tot_flux = args - thisAbsoluteFlux = jax.lax.dynamic_index_in_dim(fluxes, i, keepdims=False) - pa = PhotonArray._fromArrays( - x=jax.lax.dynamic_index_in_dim(pa_dict["x"], i, keepdims=False), - y=jax.lax.dynamic_index_in_dim(pa_dict["y"], i, keepdims=False), - flux=( - jax.lax.dynamic_index_in_dim(pa_dict["flux"], i, keepdims=False) - * (tot_flux / thisAbsoluteFlux) - ), - dxdz=jax.lax.dynamic_index_in_dim(pa_dict["dxdz"], i, keepdims=False), - dydz=jax.lax.dynamic_index_in_dim(pa_dict["dydz"], i, keepdims=False), - wavelength=jax.lax.dynamic_index_in_dim( - pa_dict["wavelength"], i, keepdims=False - ), - pupil_u=jax.lax.dynamic_index_in_dim( - pa_dict["pupil_u"], i, keepdims=False - ), - pupil_v=jax.lax.dynamic_index_in_dim( - pa_dict["pupil_v"], i, keepdims=False - ), - time=jax.lax.dynamic_index_in_dim(pa_dict["time"], i, keepdims=False), - is_corr=jnp.any(pa_dict["is_corr"], axis=0), - ) + for i, obj in enumerate(self.obj_list): + pa = obj.shoot(photons.size(), rng=rng) + # now we rescale the fluxes of the photons + # the photons start with + # + # flux_per_photon = (obj.positive_flux + obj.negative_flux) / photons.size() + # + # but they should have had a flux per photon of + # + # flux_per_photon = (self.positive_flux + self.negative_flux) / thisN + # = fluxes[i] / thisN + # + # where thisN = jnp.sum(cat_inds == i). We drew photons.size() photons instead + # of thisN, above. so we scale their fluxes by a factor of + # + # _scale_fac = photons.size() / thisN + # + # next we want them to have a total flux of + # + # tot_flux_per_photon = (self.positive_flux + self.negative_flux) / photons.size() + # + # so we scale by a factor of + # + # _scale_fac = tot_flux_per_photon / flux_per_photon_thisN + # + # so we get a total factor of + # + # _scale_fac = tot_flux / fluxes[i] + _scale_fac = tot_flux / fluxes[i] + pa.scaleFlux(_scale_fac) photons._assign_from_categorical_index(cat_inds, i, pa) - return photons, rng, pa_dict, fluxes, cat_inds, tot_flux - - _photons, _rng = jax.lax.fori_loop( - 0, - len(self.obj_list), - _body_fun, - ( - photons, - BaseDeviate(rng), - PhotonArray._stack_photon_arrays_to_dict_of_matrices(photon_arrays), - fluxes, - cat_inds, - tot_flux, - ), - )[0:2] - rng._state = _rng._state - photons.assignAt(0, _photons) def tree_flatten(self): """This function flattens the GSObject into a list of children diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 77f17c38..9b99b0a5 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -269,11 +269,16 @@ def _draw_it_jit(obj, n, nphotons, gain): with time_code_block("warmup no-jit"): img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) + with time_code_block("no-jit"): img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) with time_code_block("warmup jit"): img = _build_and_draw(0.5, 1.0) + np.testing.assert_allclose(img.array.sum(), 1100.0) + with time_code_block("jit"): img = _build_and_draw(0.5, 1.0) From 1416586dc571bda53897c7233e8522eec8c830c9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 06:46:46 -0600 Subject: [PATCH 22/85] DOC added comments --- jax_galsim/sum.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 7879264a..a5f9133b 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -193,6 +193,22 @@ def _shoot(self, photons, rng): fluxes = jnp.array( [obj.positive_flux + obj.negative_flux for obj in self.obj_list] ) + # for a sum of objects, we use a slightly different approach than galsim + # galsim uses a binomial distribution to compute the number of photons per object + # we take an equivalent but different approach in order to use fixed size arrays + # of photons. it means we draw more photons but the code is JIT compilable and a bit simpler + # + # this all works as follows: + # + # - for each photon, we draw from a categorical distribution with probabilities + # proportional to the total absolute fluxes of the objects. + # - we then shoot the photons from each object and rescale the fluxes (see comment below) + # - finally, we get the photons that correspond to this object in the cetegorical distribution + # and assign them to the photons object there is a special private method on the + # PhotonArray that does this assignment + # + # one nice thing about this is that the photons come out pre-shuffled and so we don't have + # to mark them as correlated. rng = BaseDeviate(rng) key = rng._state.split_one() cat_inds = jax.random.choice( @@ -205,30 +221,17 @@ def _shoot(self, photons, rng): for i, obj in enumerate(self.obj_list): pa = obj.shoot(photons.size(), rng=rng) # now we rescale the fluxes of the photons - # the photons start with + # in galsim, photons end up with a flux that is # - # flux_per_photon = (obj.positive_flux + obj.negative_flux) / photons.size() + # fluxes[i] / thisN * tot_flux / photons.size() * thisN / fluxes[i] + # = tot_flux / photons.size() # - # but they should have had a flux per photon of + # our photons start with a flux of # - # flux_per_photon = (self.positive_flux + self.negative_flux) / thisN - # = fluxes[i] / thisN - # - # where thisN = jnp.sum(cat_inds == i). We drew photons.size() photons instead - # of thisN, above. so we scale their fluxes by a factor of - # - # _scale_fac = photons.size() / thisN - # - # next we want them to have a total flux of - # - # tot_flux_per_photon = (self.positive_flux + self.negative_flux) / photons.size() + # flux[i] / photons.size() # # so we scale by a factor of # - # _scale_fac = tot_flux_per_photon / flux_per_photon_thisN - # - # so we get a total factor of - # # _scale_fac = tot_flux / fluxes[i] _scale_fac = tot_flux / fluxes[i] pa.scaleFlux(_scale_fac) From ca3c02cba80c8cf26f7e24ded7f63b1dff3ffe80 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:02:56 -0600 Subject: [PATCH 23/85] TST added test for raising --- tests/jax/test_photon_array_jax_custom.py | 37 ------------- tests/jax/test_photon_shooting_jax.py | 65 +++++++++++++++++++++++ 2 files changed, 65 insertions(+), 37 deletions(-) delete mode 100644 tests/jax/test_photon_array_jax_custom.py create mode 100644 tests/jax/test_photon_shooting_jax.py diff --git a/tests/jax/test_photon_array_jax_custom.py b/tests/jax/test_photon_array_jax_custom.py deleted file mode 100644 index 05281f02..00000000 --- a/tests/jax/test_photon_array_jax_custom.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np - -import jax_galsim - - -def test_photon_array_make_from_image_notranspose(): - # this test uses a very assymetric array to ensure there is not a transpose - # error in the code - ref_array = np.array( - [ - [0.01, 0.08, 0.07, 0.02], - [0.13, 0.38, 10.52, 0.06], - [0.09, 0.41, 0.44, 0.09], - [0.04, 0.11, 0.10, 0.01], - [0.04, 0.11, 0.10, 0.01], - ] - ) - image = jax_galsim.Image(ref_array) - - photons = jax_galsim.PhotonArray.makeFromImage(image, max_flux=0.1) - - image2 = jax_galsim.Image(np.zeros_like(ref_array)) - photons.addTo(image2) - - if not np.allclose(image2.array, ref_array) and False: - import proplot as pplt - - fig, axs = pplt.subplots(nrows=1, ncols=3) - axs[0].imshow(ref_array) - axs[1].imshow(image2.array) - axs[2].imshow(image2.array - ref_array) - - import pdb - - pdb.set_trace() - - np.testing.assert_allclose(image2.array, ref_array) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py new file mode 100644 index 00000000..a7640e1e --- /dev/null +++ b/tests/jax/test_photon_shooting_jax.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest +from jax.tree_util import register_pytree_node_class + +import jax_galsim + + +def test_photon_shooting_jax_make_from_image_notranspose(): + # this test uses a very assymetric array to ensure there is not a transpose + # error in the code + ref_array = np.array( + [ + [0.01, 0.08, 0.07, 0.02], + [0.13, 0.38, 10.52, 0.06], + [0.09, 0.41, 0.44, 0.09], + [0.04, 0.11, 0.10, 0.01], + [0.04, 0.11, 0.10, 0.01], + ] + ) + image = jax_galsim.Image(ref_array) + + photons = jax_galsim.PhotonArray.makeFromImage(image, max_flux=0.1) + + image2 = jax_galsim.Image(np.zeros_like(ref_array)) + photons.addTo(image2) + + if not np.allclose(image2.array, ref_array) and False: + import proplot as pplt + + fig, axs = pplt.subplots(nrows=1, ncols=3) + axs[0].imshow(ref_array) + axs[1].imshow(image2.array) + axs[2].imshow(image2.array - ref_array) + + import pdb + + pdb.set_trace() + + np.testing.assert_allclose(image2.array, ref_array) + + +@register_pytree_node_class +class TestExponential(jax_galsim.Exponential): + def _shoot(self, *args, **kwargs): + raise NotImplementedError("this is a test") + + def tree_flatten(self): + """This function flattens the GSObject into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + children = (self.params,) + # Define auxiliary static data that doesn’t need to be traced + aux_data = {"gsparams": self.gsparams} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(**(children[0]), **aux_data) + + +def test_photon_shooting_jax_raises(): + obj = TestExponential(half_light_radius=1.0, flux=1.0) + with pytest.raises(jax_galsim.errors.GalSimNotImplementedError): + obj.drawImage(nx=33, ny=33, scale=0.2, method="phot", n_photons=1000) From 8a7a8f34789fff8d25d769d6bf7bc277779fdf8d Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:05:11 -0600 Subject: [PATCH 24/85] TST add test of shooting in jit --- tests/jax/test_photon_shooting_jax.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index a7640e1e..cf6f1e7b 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -1,3 +1,4 @@ +import jax import numpy as np import pytest from jax.tree_util import register_pytree_node_class @@ -40,7 +41,7 @@ def test_photon_shooting_jax_make_from_image_notranspose(): @register_pytree_node_class -class TestExponential(jax_galsim.Exponential): +class NoShootingExponential(jax_galsim.Exponential): def _shoot(self, *args, **kwargs): raise NotImplementedError("this is a test") @@ -60,6 +61,14 @@ def tree_unflatten(cls, aux_data, children): def test_photon_shooting_jax_raises(): - obj = TestExponential(half_light_radius=1.0, flux=1.0) + obj = NoShootingExponential(half_light_radius=1.0, flux=1.0) with pytest.raises(jax_galsim.errors.GalSimNotImplementedError): obj.drawImage(nx=33, ny=33, scale=0.2, method="phot", n_photons=1000) + + @jax.jit + def _jitted(): + obj = NoShootingExponential(half_light_radius=1.0, flux=1.0) + return obj.drawImage(nx=33, ny=33, scale=0.2, method="phot", n_photons=1000) + + with pytest.raises(jax_galsim.errors.GalSimNotImplementedError): + _jitted() From 37047c35c368af9f92bb2374474c5616c6e1cab1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:08:38 -0600 Subject: [PATCH 25/85] TST simpler test code --- tests/jax/test_jitting.py | 4 ++-- tests/jax/test_photon_shooting_jax.py | 11 ++--------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 9b99b0a5..64b26010 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -226,8 +226,8 @@ def _draw_it_jit(obj, n, nfft): def test_jitting_draw_phot(): def _build_and_draw(hlr, fwhm, jit=True): gal = galsim.Exponential( - half_light_radius=hlr, flux=1000.0 - ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) + half_light_radius=hlr, flux=1099.0 + ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=1.0) psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) final = galsim.Convolve( [gal, psf], diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index cf6f1e7b..a6c33d55 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -46,18 +46,11 @@ def _shoot(self, *args, **kwargs): raise NotImplementedError("this is a test") def tree_flatten(self): - """This function flattens the GSObject into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - children = (self.params,) - # Define auxiliary static data that doesn’t need to be traced - aux_data = {"gsparams": self.gsparams} - return (children, aux_data) + return super().tree_flatten() @classmethod def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - return cls(**(children[0]), **aux_data) + return super().tree_unflatten(aux_data, children) def test_photon_shooting_jax_raises(): From 65995b2b2548b1e8a90985d221407d0d1a144c1e Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:12:10 -0600 Subject: [PATCH 26/85] BUG make sure tests do not raise even if they fail --- tests/jax/test_photon_shooting_jax.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index a6c33d55..5b9e69b4 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -25,17 +25,18 @@ def test_photon_shooting_jax_make_from_image_notranspose(): image2 = jax_galsim.Image(np.zeros_like(ref_array)) photons.addTo(image2) - if not np.allclose(image2.array, ref_array) and False: - import proplot as pplt + # code for testing + # if not np.allclose(image2.array, ref_array) and False: + # import proplot as pplt - fig, axs = pplt.subplots(nrows=1, ncols=3) - axs[0].imshow(ref_array) - axs[1].imshow(image2.array) - axs[2].imshow(image2.array - ref_array) + # fig, axs = pplt.subplots(nrows=1, ncols=3) + # axs[0].imshow(ref_array) + # axs[1].imshow(image2.array) + # axs[2].imshow(image2.array - ref_array) - import pdb + # import pdb - pdb.set_trace() + # pdb.set_trace() np.testing.assert_allclose(image2.array, ref_array) From 28a6174110b8e811826cf45d54c089a5101fe1cc Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:29:33 -0600 Subject: [PATCH 27/85] DOC added comments --- jax_galsim/exponential.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 79f1dc44..6743550b 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -150,7 +150,33 @@ def withFlux(self, flux): @lazy_property def _shoot_cdf(self): - # store this for later + # Comments on the math here: + # + # We are looking to draw from a distribution that is r * exp(-r). + # This distribution is the radial PDF of an Exponential profile. + # The fact of r comes from the area element r * dr. + # + # We can compute the CDF of this distribution analytically, but we cannot + # invert the CDF in closed form. Thus we invert it numerically using a table. + # + # One final detail is that we want the inversion to be accurate and are using + # linear interpolation. Thus we use a change of variables r = -ln(1 - u) + # to make the CDF more linear. + # + # Putting this all together, we get + # + # r * exp(-r) dr = -ln(1-u) (1-u) dr/du du + # = -ln(1-u) (1-u) * 1 / (1-u) + # = -ln(1-u) + # + # The new range of integration is u = 0 to u = 1. Thus the CDF is + # + # CDF = -int_0^u ln(1-u') du' + # = u - (u - 1) ln(1 - u) + # + # The final detail is that galsim defines a shoot accuracy and draws photons + # between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to + # its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically. _rmax = -jnp.log(self.gsparams.shoot_accuracy) _umax = 1.0 - jnp.exp(-_rmax) _u_cdf = jnp.linspace(0, _umax, 10000) @@ -166,7 +192,10 @@ def _shoot(self, photons, rng): photons.x ) # this does not fill arrays like in galsim so is safe _u_cdf, _cdf = self._shoot_cdf + # this interpolation inverts the CDF u = jnp.interp(u, _cdf, _u_cdf) + # this converts from u (see above) to r and scales by the actual size of + # the object r0. r = -jnp.log(1.0 - u) * self._r0 ang = ( From 05558ae1f6471b5bb452624c262e28b65565ab27 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 07:41:50 -0600 Subject: [PATCH 28/85] DOC added comments --- jax_galsim/interpolant.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 32551672..9cd661b9 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -23,22 +23,65 @@ @jax.jit def _rejection_sample(photons, rng, tot_xrange, xval, pos_flux, neg_flux, max_val): + """Use rejection sampling to generate photons from a given 1D interpolant function. + + We sample both x and y values from the interpolant function. + + Parameters + ---------- + photons : PhotonArray + The photon array to shoot into. + rng : BaseDeviate + The random number generator to use for drawing photons. + tot_xrange : float + The total range of the interpolant function from the most negative + point to the most positive point. The interpolant is assumed to be + symmetric about zero. + xval : callable + The interpolant function. Will only be called with positive values. + pos_flux : float + The total integral under all positive regions of the interpolant function. + neg_flux : float + The absolute value of the total integral under all negative regions of + the interpolant function. + max_val : float + The maximum value of the interpolant function. Usually this is xval(0.0) and + is 1.0. + """ + def _cond_fun(args): + # we stop drawing when we have tot photons + # curr records how many we have currently _, _, tot, _, curr = args return curr < tot def _body_fun(args): arr, sign, tot, ud, curr = args + # arr is the array we are filling with photon positions + # sign is the array of signs of the interpolant function at the photon positions + # tot is the total number of photons to draw + # ud is the random number generator for uniform deviates from 0 to 1 + # curr is the current number of photons drawn + + # we first draw a random x location centered at zero with a + # total range of tot_xrange xloc = (ud() - 0.5) * tot_xrange + + # next we draw a random y value between 0 and max_val yv = ud() * max_val xloc_val = xval(xloc) + + # this cond operator keeps the photon if the y value we drew is + # below the interpolant function at the x location we drew arr, sign, curr = jax.lax.cond( yv <= jnp.abs(xloc_val), + # if we keep it, assign the location, assign the sign, and increment curr lambda arr, sign, curr, xloc, xloc_val: ( arr.at[curr].set(xloc), sign.at[curr].set(jnp.sign(xloc_val)), curr + 1, ), + # otherwise we pass lambda arr, sign, curr, xloc, xloc_val: (arr, sign, curr), arr, sign, @@ -49,6 +92,8 @@ def _body_fun(args): return arr, sign, tot, ud, curr ud = UniformDeviate(rng) + + # we first make the x and y positions photons.x, _sign_x, _, ud, _ = jax.lax.while_loop( _cond_fun, _body_fun, @@ -59,6 +104,8 @@ def _body_fun(args): _body_fun, (jnp.zeros_like(photons.y), jnp.zeros_like(photons.y), photons.size(), ud, 0), ) + # this magic formula comes from looking closely at the galsim code in Interpolant.cpp + # and how things get adjusted down the line OneDimensionalDeviate.cpp flux_per = (pos_flux + neg_flux) ** 2 / photons.size() photons.flux = _sign_x * _sign_y * flux_per return photons, rng From ba61a08ec98f4a003bdcff4275ea7298bccece7e Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 08:52:29 -0600 Subject: [PATCH 29/85] TST make test fail --- tests/jax/test_jitting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 64b26010..9b99b0a5 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -226,8 +226,8 @@ def _draw_it_jit(obj, n, nfft): def test_jitting_draw_phot(): def _build_and_draw(hlr, fwhm, jit=True): gal = galsim.Exponential( - half_light_radius=hlr, flux=1099.0 - ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=1.0) + half_light_radius=hlr, flux=1000.0 + ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) final = galsim.Convolve( [gal, psf], From 8d4bc551ca9c4163ca8943a94c103fe313f81c6f Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 12:09:29 -0600 Subject: [PATCH 30/85] TST added tests of offsets --- jax_galsim/interpolatedimage.py | 5 +- tests/jax/test_photon_shooting_jax.py | 102 ++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f5f34ebe..e4efdbbd 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -886,7 +886,10 @@ def _shoot(self, photons, rng): ) # accounnt for offset - we add the offset to get to - # image pixels in xValue so we need to subtract it here + # image pixels in xValue + # here we generate photons from the image and thus + # so we need to subtract it to get back to get to x as + # it would be input in xVal photons.x -= self._offset.x photons.y -= self._offset.y diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 5b9e69b4..9346a149 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -1,9 +1,11 @@ import jax +import jax.numpy as jnp import numpy as np import pytest from jax.tree_util import register_pytree_node_class import jax_galsim +from jax_galsim.core.testing import time_code_block def test_photon_shooting_jax_make_from_image_notranspose(): @@ -66,3 +68,103 @@ def _jitted(): with pytest.raises(jax_galsim.errors.GalSimNotImplementedError): _jitted() + + +@pytest.mark.parametrize( + "offset", + [ + (0, 0), + (5, 5), + (0, 5), + (0, -5), + (5, 0), + (-5, 0), + (-5, -5), + (-5, 5), + (5, -5), + ], +) +def test_photon_shooting_jax_offset(offset): + gal = jax_galsim.Exponential( + half_light_radius=0.5, flux=1000.0 + ) + jax_galsim.Exponential(half_light_radius=1.0, flux=100.0) + psf = jax_galsim.Gaussian(fwhm=0.9, flux=1.0) + obj = jax_galsim.Convolve([gal, psf]) + img = obj.drawImage(scale=0.2) + iobj = jax_galsim.InterpolatedImage(img, scale=0.2, offset=offset) + + img_fft = iobj.drawImage(nx=33, ny=33, scale=0.2, dtype=np.float64) + + # these magic equations come from the do_shoot routine in the + # GalSim file galsim/tests/galsim_test_helpers.py + rtol = 1e-2 + flux_tot = iobj.flux + flux_max = iobj.max_sb * 0.2**2 + atol = flux_max * rtol * 3 + nphot = int((flux_tot / flux_max / rtol**2).item()) + + with time_code_block(): + img_phot = iobj.drawImage( + nx=33, + ny=33, + scale=0.2, + method="phot", + dtype=np.float64, + n_photons=nphot, + maxN=10000, + rng=jax_galsim.BaseDeviate(1234), + ) + + print( + "fft|phot argmax:", + jnp.argmax(img_fft.array), + jnp.argmax(img_phot.array), + ) + + print( + "fft|phot max:", + jnp.max(img_fft.array), + jnp.max(img_phot.array), + ) + + print( + "fft|phot sum:", + jnp.sum(img_fft.array), + jnp.sum(img_phot.array), + ) + + print( + "fft moments:", + " ".join( + "%s:% 15.7e" % (k, v) + for k, v in jax_galsim.utilities.unweighted_moments(img_fft).items() + ), + ) + print( + "phot moments:", + " ".join( + "%s:% 15.7e" % (k, v) + for k, v in jax_galsim.utilities.unweighted_moments(img_phot).items() + ), + ) + + # code for testing + if not np.allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol): + import proplot as pplt + + fig, axs = pplt.subplots(nrows=1, ncols=3) + axs[0].imshow(img_fft.array, origin="lower") + axs[1].imshow(img_phot.array, origin="lower") + axs[2].imshow(img_fft.array - img_phot.array, origin="lower") + fig.show() + + import pdb + + pdb.set_trace() + + np.testing.assert_almost_equal( + jnp.argmax(img_fft.array), + jnp.argmax(img_phot.array), + ) + + np.testing.assert_allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol) From 802c1c11dc75225453c249565b9e9f4d7b2e3894 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Nov 2023 13:14:04 -0600 Subject: [PATCH 31/85] PERF extra jitting --- jax_galsim/photon_array.py | 41 ++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 81c75876..004f543d 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -570,23 +570,18 @@ def addTo(self, image): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) - xinds = jnp.floor(self._x - image.bounds.xmin + 0.5).astype(int) - yinds = jnp.floor(self._y - image.bounds.ymin + 0.5).astype(int) - # the jax documentation says that they drop out of bounds indices, - # but the galsim unit tests reveal that withoout the check below, - # the indices are not dropped. - # I think maybe it is only indices beyond the end of the array that are - # dropped and negative indices wrap around - good = ( - (xinds >= 0) - & (xinds < image.array.shape[1]) - & (yinds >= 0) - & (yinds < image.array.shape[0]) + + _arr, _flux_sum = _add_photons_to_image( + self._x, + self._y, + self._flux, + image.bounds.xmin, + image.bounds.ymin, + image._array, ) - flux = jnp.where(good, self._flux, 0.0) - image._array = image._array.at[yinds, xinds].add(flux) + image._array = _arr - return self._flux.sum() + return _flux_sum @classmethod def makeFromImage(cls, image, max_flux=1.0, rng=None): @@ -724,3 +719,19 @@ def read(cls, file_name): if "time" in names: photons.time = jnp.array(data["time"]) return photons + + +@jax.jit +def _add_photons_to_image(x, y, flux, xmin, ymin, arr): + xinds = jnp.floor(x - xmin + 0.5).astype(int) + yinds = jnp.floor(y - ymin + 0.5).astype(int) + # the jax documentation says that they drop out of bounds indices, + # but the galsim unit tests reveal that withoout the check below, + # the indices are not dropped. + # I think maybe it is only indices beyond the end of the array that are + # dropped and negative indices wrap around + good = (xinds >= 0) & (xinds < arr.shape[1]) & (yinds >= 0) & (yinds < arr.shape[0]) + _flux = jnp.where(good, flux, 0.0) + _arr = arr.at[yinds, xinds].add(_flux.astype(arr.dtype)) + + return _arr, _flux.sum() From 454f692b9256129ebab87071eac29d0c87c1e1e6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 19 Nov 2023 07:38:43 -0600 Subject: [PATCH 32/85] WIP enable fixed size photon arrays --- jax_galsim/gsobject.py | 13 +-- jax_galsim/photon_array.py | 161 +++++++++++++++++++++++++++++-------- tests/jax/test_jitting.py | 65 +++++++++++++++ 3 files changed, 195 insertions(+), 44 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 6d893e4d..051277ca 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1386,20 +1386,15 @@ def _body_fun(args): "is a Deconvolve or is a compound including one or more " "Deconvolve objects.\nOriginal error: %r" % (e) ) - photons.flux = jnp.where( - jnp.arange(maxN) < thisN, - photons.flux, - 0.0, - ) + # we drew maxN, but only keep thisN of them + photons._num_keep = thisN photons = jax.lax.cond( # weird way to say gain == 1 and thisN == Ntot jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, lambda photons, g, thisN, Ntot: photons, - # the factor here is (maxN / thisN) * (thisN / Ntot) = maxN / Ntot - # the first bit is that we drew maxN photons, but only thisN of them are valid - # the second bit is that we only drew thisN photons, but use a total of Ntot photons - lambda photons, g, thisN, Ntot: photons.scaleFlux(g * maxN / Ntot), + # the factor here is thisN / Ntot since we drew thisN photons, but use a total of Ntot photons + lambda photons, g, thisN, Ntot: photons.scaleFlux(g * thisN / Ntot), photons, g, thisN, diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 004f543d..b7b8722c 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + import galsim as _galsim import jax import jax.numpy as jnp @@ -16,6 +18,20 @@ from ._pyfits import pyfits +_JAX_GALSIM_PHOTON_ARRAY_SIZE = None + + +@contextmanager +def fixed_photon_array_size(size): + """Context manager to temporarily set a fixed size for photon arrays.""" + global _JAX_GALSIM_PHOTON_ARRAY_SIZE + old_size = _JAX_GALSIM_PHOTON_ARRAY_SIZE + _JAX_GALSIM_PHOTON_ARRAY_SIZE = size + try: + yield + finally: + _JAX_GALSIM_PHOTON_ARRAY_SIZE = old_size + @_wraps( _galsim.PhotonArray, @@ -41,20 +57,34 @@ def __init__( pupil_u=None, pupil_v=None, time=None, + _nokeep=None, ): - self._N = N + # self._N = N + self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N + if ( + _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None + and N > _JAX_GALSIM_PHOTON_ARRAY_SIZE + ): + raise GalSimValueError( + f"The given photon array size {N} is larger than " + f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." + ) + if _nokeep is not None: + self._nokeep = _nokeep + else: + self._nokeep = jnp.arange(self._Ntot) >= N # Only x, y, flux are built by default, since these are always required. # The others we leave as None unless/until they are needed. - self._x = jnp.zeros(self._N, dtype=float) - self._y = jnp.zeros(self._N, dtype=float) - self._flux = jnp.zeros(self._N, dtype=float) - self._dxdz = jnp.full(self._N, jnp.nan, dtype=float) - self._dydz = jnp.full(self._N, jnp.nan, dtype=float) - self._wave = jnp.full(self._N, jnp.nan, dtype=float) - self._pupil_u = jnp.full(self._N, jnp.nan, dtype=float) - self._pupil_v = jnp.full(self._N, jnp.nan, dtype=float) - self._time = jnp.full(self._N, jnp.nan, dtype=float) + self._x = jnp.zeros(self._Ntot, dtype=float) + self._y = jnp.zeros(self._Ntot, dtype=float) + self._flux = jnp.zeros(self._Ntot, dtype=float) + self._dxdz = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._dydz = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._wave = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._pupil_u = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._pupil_v = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._time = jnp.full(self._Ntot, jnp.nan, dtype=float) self._is_corr = jnp.array(False) if x is not None: @@ -113,59 +143,78 @@ def _fromArrays( time=None, is_corr=False, ): + if ( + _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None + and x.shape[0] != _JAX_GALSIM_PHOTON_ARRAY_SIZE + ): + raise GalSimValueError( + "The given arrays do not match the expected total size", + x.shape[0], + _JAX_GALSIM_PHOTON_ARRAY_SIZE, + ) + ret = cls.__new__(cls) - ret._N = x.shape[0] + # ret._N = x.shape[0] + ret._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or x.shape[0] ret._x = x.copy() ret._y = y.copy() ret._flux = flux.copy() + ret._nokeep = jnp.arange(ret._Ntot) >= x.shape[0] ret._dxdz = ( - dxdz.copy() if dxdz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + dxdz.copy() + if dxdz is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._dydz = ( - dydz.copy() if dydz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + dydz.copy() + if dydz is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._wave = ( wavelength.copy() if wavelength is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._pupil_u = ( pupil_u.copy() if pupil_u is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._pupil_v = ( pupil_v.copy() if pupil_v is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._time = ( - time.copy() if time is not None else jnp.full(ret._N, jnp.nan, dtype=float) + time.copy() + if time is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._is_corr = jnp.array(is_corr) return ret def tree_flatten(self): children = ( - (self.x, self.y, self.flux), + (self._x, self._y, self._flux, self._nokeep), { - "dxdz": self.dxdz, - "dydz": self.dydz, - "wavelength": self.wavelength, - "pupil_u": self.pupil_u, - "pupil_v": self.pupil_v, - "time": self.time, - "is_corr": self.isCorrelated(), + "dxdz": self._dxdz, + "dydz": self._dydz, + "wavelength": self._wave, + "pupil_u": self._pupil_u, + "pupil_v": self._pupil_v, + "time": self._time, + "is_corr": self._is_corr, }, ) - aux_data = (self._N,) + aux_data = (self._Ntot,) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) - ret._N = aux_data[0] + ret._Ntot = aux_data[0] + ret._nokeep = children[0][3] ret._x = children[0][0] ret._y = children[0][1] ret._flux = children[0][2] @@ -180,10 +229,20 @@ def tree_unflatten(cls, aux_data, children): def size(self): """Return the size of the photon array. Equivalent to ``len(self)``.""" - return self._N + return self._Ntot def __len__(self): - return self._N + return self._Ntot + + @property + def _num_keep(self): + """The number of actual photons in the array.""" + return jnp.sum(~self._nokeep).astype(int) + + @_num_keep.setter + def _num_keep(self, num_keep): + """Set the number of actual photons in the array.""" + self._nokeep = jnp.arange(self._Ntot) >= num_keep @property def x(self): @@ -210,7 +269,13 @@ def y(self, value): @property def flux(self): """The flux of the photons.""" - return self._flux + return jax.lax.cond( + self._Ntot == self._num_keep, + lambda flux, ratio: flux, + lambda flux, ratio: flux * ratio, + jnp.where(self._nokeep, 0.0, self._flux), + self._Ntot / self._num_keep, + ) @flux.setter def flux(self, value): @@ -375,6 +440,22 @@ def scaleXY(self, scale): return self + def _sort_by_nokeep(self): + # now sort things to keep to the left + sinds = jnp.argsort(self._nokeep) + self._x = self._x.at[sinds].get() + self._y = self._y.at[sinds].get() + self._flux = self._flux.at[sinds].get() + self._nokeep = self._nokeep.at[sinds].get() + self._dxdz = self._dxdz.at[sinds].get() + self._dydz = self._dydz.at[sinds].get() + self._wave = self._wave.at[sinds].get() + self._pupil_u = self._pupil_u.at[sinds].get() + self._pupil_v = self._pupil_v.at[sinds].get() + self._time = self._time.at[sinds].get() + + return self + def assignAt(self, istart, rhs): """Assign the contents of another `PhotonArray` to this one starting at istart.""" if istart + rhs.size() > self.size(): @@ -385,6 +466,7 @@ def assignAt(self, istart, rhs): self._x = self._x.at[istart : istart + rhs.size()].set(rhs.x) self._y = self._y.at[istart : istart + rhs.size()].set(rhs.y) self._flux = self._flux.at[istart : istart + rhs.size()].set(rhs.flux) + self._nokeep = self._nokeep.at[istart : istart + rhs.size()].set(rhs._nokeep) self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) @@ -392,7 +474,7 @@ def assignAt(self, istart, rhs): self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set(rhs.pupil_v) self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) - return self + return self._sort_by_nokeep() def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): """Assign the contents of another `PhotonArray` to this one at locations @@ -402,6 +484,7 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): self._x = jnp.where(msk, rhs._x, self._x) self._y = jnp.where(msk, rhs._y, self._y) self._flux = jnp.where(msk, rhs._flux, self._flux) + self._nokeep = jnp.where(msk, rhs._nokeep, self._nokeep) self._dxdz = jnp.where(msk, rhs._dxdz, self._dxdz) self._dydz = jnp.where(msk, rhs._dydz, self._dydz) @@ -410,7 +493,7 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): self._pupil_v = jnp.where(msk, rhs._pupil_v, self._pupil_v) self._time = jnp.where(msk, rhs._time, self._time) - return self + return self._sort_by_nokeep() def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. @@ -428,7 +511,7 @@ def convolve(self, rhs, rng=None): rng = BaseDeviate(rng) rsinds = jrng.choice( rng._state.split_one(), - self.size(), + self._Ntot, shape=(self.size(),), replace=False, ) @@ -436,7 +519,9 @@ def convolve(self, rhs, rng=None): sinds = jax.lax.cond( jnp.array(self.isCorrelated()) & jnp.array(rhs.isCorrelated()), - lambda nrsinds, rsinds: rsinds, + lambda nrsinds, rsinds: rsinds.at[ + jnp.argsort(rhs._nokeep.at[rsinds].get()) + ].get(), lambda nrsinds, rsinds: nrsinds, nrsinds, rsinds, @@ -527,6 +612,7 @@ def __repr__(self): ) if self.hasAllocatedTimes(): s += ", time=array(%r)" % np.array(self.time).tolist() + s += ", _nokeep=array(%r)" % np.array(self._nokeep).tolist() s += ")" return s @@ -541,6 +627,7 @@ def __eq__(self, other): and jnp.array_equal(self.x, other.x) and jnp.array_equal(self.y, other.y) and jnp.array_equal(self.flux, other.flux) + and jnp.array_equal(self._nokeep, other._nokeep) and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) and jnp.array_equal(self.dydz, other.dydz, equal_nan=True) and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) @@ -574,7 +661,7 @@ def addTo(self, image): _arr, _flux_sum = _add_photons_to_image( self._x, self._y, - self._flux, + jnp.where(self._nokeep, 0.0, self._flux) * self._Ntot / self._num_keep, image.bounds.xmin, image.bounds.ymin, image._array, @@ -651,6 +738,9 @@ def write(self, file_name): cols.append(pyfits.Column(name="x", format="D", array=np.array(self.x))) cols.append(pyfits.Column(name="y", format="D", array=np.array(self.y))) cols.append(pyfits.Column(name="flux", format="D", array=np.array(self.flux))) + cols.append( + pyfits.Column(name="_nokeep", format="L", array=np.array(self._nokeep)) + ) if self.hasAllocatedAngles(): cols.append( @@ -708,6 +798,7 @@ def read(cls, file_name): y=jnp.array(data["y"]), flux=jnp.array(data["flux"]), ) + photons._nokeep = jnp.array(data["_nokeep"]) if "dxdz" in names: photons.dxdz = jnp.array(data["dxdz"]) photons.dydz = jnp.array(data["dydz"]) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 9b99b0a5..0b20f9eb 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -7,6 +7,7 @@ import jax_galsim as galsim from jax_galsim.core.draw import calculate_n_photons from jax_galsim.core.testing import time_code_block +from jax_galsim.photon_array import fixed_photon_array_size # Defining jitting identity identity = jax.jit(lambda x: x) @@ -283,3 +284,67 @@ def _draw_it_jit(obj, n, nphotons, gain): img = _build_and_draw(0.5, 1.0) np.testing.assert_allclose(img.array.sum(), 1100.0) + + +def test_jitting_draw_phot_fixed(): + def _build_and_draw(hlr, fwhm, jit=True): + gal = galsim.Exponential( + half_light_radius=hlr, flux=1000.0 + ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) + psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) + final = galsim.Convolve( + [gal, psf], + ) + n = final.getGoodImageSize(0.2).item() + n += 1 + n_photons = calculate_n_photons( + final.flux, + final._flux_per_photon, + final.max_sb, + poisson_flux=False, + )[0] + gain = 1.0 + if jit: + return _draw_it_jit(final, n, n_photons, gain) + else: + with fixed_photon_array_size(2048): + return final.drawImage( + nx=n, + ny=n, + scale=0.2, + method="phot", + n_photons=n_photons, + poisson_flux=False, + gain=gain, + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit(obj, n, nphotons, gain): + with fixed_photon_array_size(2048): + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + n_photons=nphotons, + method="phot", + poisson_flux=False, + gain=gain, + maxN=101, + ) + + with time_code_block("warmup no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("warmup jit"): + img = _build_and_draw(0.5, 1.0) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("jit"): + img = _build_and_draw(0.5, 1.0) + + np.testing.assert_allclose(img.array.sum(), 1100.0) From 3fdd163885afb4c529d76d6c60c7268da158ef0e Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 20 Nov 2023 16:42:05 -0600 Subject: [PATCH 33/85] REF a lot of changes to enable vmap, clean up APIs, etc. --- jax_galsim/core/draw.py | 412 ++++++++++------------ jax_galsim/gsobject.py | 335 +++++++++++------- jax_galsim/interpolant.py | 163 ++------- jax_galsim/photon_array.py | 17 +- jax_galsim/random.py | 31 +- jax_galsim/utilities.py | 53 ++- tests/GalSim | 2 +- tests/jax/galsim/test_draw_jax.py | 58 +-- tests/jax/test_interpolatedimage_utils.py | 19 +- tests/jax/test_jitting.py | 2 +- tests/jax/test_photon_shooting_jax.py | 68 ++++ 11 files changed, 587 insertions(+), 573 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 00696da7..8b632f7d 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -1,7 +1,8 @@ -import galsim as _galsim +from collections import namedtuple + import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +import numpy as np from jax_galsim.random import PoissonDeviate @@ -84,17 +85,147 @@ def phase(kpos): ) -def sample_poisson_flux(flux, eta_factor, rng=None): - """Sample the flux according to a Poisson distribution. +NPhotonsData = namedtuple( + "NPhotonsData", + [ + "n_photons", + "flux", + "flux_per_photon", + "max_sb", + "rng", + "poisson_flux", + "max_extra_noise", + ], +) + + +def calculate_n_photons( + flux, + eta_factor, + max_sb, + rng=None, + max_extra_noise=0, + poisson_flux=True, +): + """ + Calculate the number of photons to shoot for photon shooting. + + This routine is pure Python and is not JAX-compatible. Parameters: flux: The flux of the GSObject (e.g., ``obj.flux``). eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). + max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). rng: If provided, a random number generator to use for photon shooting, which may be any kind of `BaseDeviate` object. If ``rng`` is None, one will be automatically created, using the time as a seed. [default: None] + max_extra_noise: If provided, the allowed extra noise in each pixel when photon + shooting. This is only relevant if ``n_photons=0``, so the number of + photons is being automatically calculated. In that case, if the image + noise is dominated by the sky background, then you can get away with + using fewer shot photons than the full ``n_photons = flux``. + Essentially each shot photon can have a ``flux > 1``, which increases + the noise in each pixel. The ``max_extra_noise`` parameter specifies + how much extra noise per pixel is allowed because of this approximation. + A typical value for this might be ``max_extra_noise = sky_level / 100`` + where ``sky_level`` is the flux per pixel due to the sky. Note that + this uses a "variance" definition of noise, not a "sigma" definition. + [default: 0.] + poisson_flux: Whether to allow total object flux scaling to vary according to + Poisson statistics for ``n_photons`` samples when photon shooting. + [default: True, unless ``n_photons`` is given, in which case the default + is False] + + Returns: + n_photons: The number of photons. + g: The gain to use when shooting the photons. """ + n_photons, g, _ = _calculate_n_photons( + flux, + eta_factor, + max_sb, + rng, + max_extra_noise, + poisson_flux, + ) + return np.atleast_1d(n_photons).ravel()[0], np.atleast_1d(g).ravel()[0] + + +@jax.jit +def get_n_photons(n_photons_data): + _n_photons, g, _rng = jax.lax.cond( + n_photons_data.n_photons == 0.0, + _sample_zero, + _sample_nonzero, + n_photons_data, + ) + if n_photons_data.rng is not None: + n_photons_data.rng._state = _rng._state + return _n_photons, g, n_photons_data.rng + + +def _sample_nonzero(n_photons_data): + g, _rng = jax.lax.cond( + n_photons_data.poisson_flux, + lambda n_photons_data: _sample_poisson_flux( + n_photons_data.flux, n_photons_data.flux_per_photon, n_photons_data.rng + ), + lambda n_photons_data: (1.0, n_photons_data.rng), + n_photons_data, + ) + if n_photons_data.rng is not None: + n_photons_data.rng._state = _rng._state + vals = jnp.int_(n_photons_data.n_photons + 0.5), g, n_photons_data.rng + return vals + + +@jax.jit +def _sample_zero(n_photons_data): + Ntot, g, _rng = _calculate_n_photons( + n_photons_data.flux, + n_photons_data.flux_per_photon, + n_photons_data.max_sb, + rng=n_photons_data.rng, + max_extra_noise=n_photons_data.max_extra_noise, + poisson_flux=n_photons_data.poisson_flux, + ) + return Ntot, g, _rng + + +@jax.jit +def _calculate_n_photons( + flux, + eta_factor, + max_sb, + rng, + max_extra_noise, + poisson_flux, +): + _n_photons, _g, _rng = jax.lax.cond( + flux == 0.0, + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: ( + 0, + 1.0, + rng, + ), + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: _calculate_n_photons_flux_nonzero( + flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng + ), + flux, + eta_factor, + max_sb, + poisson_flux, + max_extra_noise, + rng, + ) + if rng is not None: + rng._state = _rng._state + return _n_photons, _g, rng + + +@jax.jit +def _sample_poisson_flux(flux, eta_factor, rng): # If we have both positive and negative photons, then the mix of these # already gives us some variation in the flux value from the variance # of how many are positive and how many are negative. @@ -111,54 +242,29 @@ def sample_poisson_flux(flux, eta_factor, rng=None): # We want the variance to be equal to flux, so we need an extra: # delta Var = (1 - 4*eta + 4*eta^2) * flux # = (1-2eta)^2 * flux - absflux = abs(flux) + absflux = jnp.abs(flux) mean = eta_factor * eta_factor * absflux pd = PoissonDeviate(rng, mean) pd_val = pd() - mean + absflux - return pd_val - - -@_wraps( - _galsim.GSObject._calculate_nphotons, - lax_description="""\ -Calculate the number of photons to shoot for photon shooting. - -This routine is pure Python and is not JAX-compatible. - -Parameters: - flux: The flux of the GSObject (e.g., ``obj.flux``). - eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). - max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). - rng: If provided, a random number generator to use for photon shooting, - which may be any kind of `BaseDeviate` object. If ``rng`` is None, one - will be automatically created, using the time as a seed. - [default: None] - max_extra_noise: If provided, the allowed extra noise in each pixel when photon - shooting. This is only relevant if ``n_photons=0``, so the number of - photons is being automatically calculated. In that case, if the image - noise is dominated by the sky background, then you can get away with - using fewer shot photons than the full ``n_photons = flux``. - Essentially each shot photon can have a ``flux > 1``, which increases - the noise in each pixel. The ``max_extra_noise`` parameter specifies - how much extra noise per pixel is allowed because of this approximation. - A typical value for this might be ``max_extra_noise = sky_level / 100`` - where ``sky_level`` is the flux per pixel due to the sky. Note that - this uses a "variance" definition of noise, not a "sigma" definition. - [default: 0.] - poisson_flux: Whether to allow total object flux scaling to vary according to - Poisson statistics for ``n_photons`` samples when photon shooting. - [default: True, unless ``n_photons`` is given, in which case the default - is False] - -""", -) -def calculate_n_photons( - flux, - eta_factor, - max_sb, - rng=None, - max_extra_noise=0, - poisson_flux=True, + return pd_val / absflux, rng + + +def _adjust_flux_g_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g): + ratio, rng = _sample_poisson_flux(flux, eta_factor, rng) + g *= ratio + mod_flux *= ratio + return jnp.abs(mod_flux), g, rng + + +def _scale_extra_noise(max_extra_noise, mod_flux, g, max_sb): + gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb) + mod_flux /= gfactor + g *= gfactor + return mod_flux, g + + +def _calculate_n_photons_flux_nonzero( + flux, flux_per_photon, max_sb, poisson_flux, max_extra_noise, rng ): # For profiles that are positive definite, then N = flux. Easy. # @@ -211,188 +317,40 @@ def calculate_n_photons( # Returns the total flux placed inside the image bounds by photon shooting. # - if flux == 0.0: - return 0, 1.0 - # The _flux_per_photon property is (1-2eta) # This factor will already be accounted for by the shoot function, so don't include # that as part of our scaling here. There may be other adjustments though, so g=1 here. + eta_factor = flux_per_photon mod_flux = flux / (eta_factor * eta_factor) g = 1.0 # If requested, let the target flux value vary as a Poisson deviate - if poisson_flux: - pd_val = sample_poisson_flux(flux, eta_factor, rng=rng) - ratio = pd_val / abs(flux) - g *= ratio - mod_flux *= ratio - - n_photons = abs(mod_flux) - if max_extra_noise > 0.0: - gfactor = 1.0 + max_extra_noise / abs(max_sb) - n_photons /= gfactor - g *= gfactor - - # Make n_photons an integer. - iN = int(n_photons + 0.5) - - return iN, g - - -# the code below is a jax version of calculate_nphotons -# that I am not sure if we need or not. -# saving in a comment for now - -# def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): -# _n_photons, _g, _rng = jax.lax.cond( -# self.flux == 0.0, -# lambda n_photons, poisson_flux, max_extra_noise, rng: (0, 1.0, rng), -# lambda n_photons, poisson_flux, max_extra_noise, rng: self._calculate_nphotons_nonzero( -# n_photons, poisson_flux, max_extra_noise, rng -# ), -# n_photons, -# poisson_flux, -# max_extra_noise, -# rng, -# ) -# if rng is not None: -# rng._state = _rng._state -# return _n_photons, _g - - -# def _adjust_flux_g_poisson(self, poisson_flux, flux, mod_flux, eta_factor, rng, g): -# from jax_galsim.random import PoissonDeviate - -# # If we have both positive and negative photons, then the mix of these -# # already gives us some variation in the flux value from the variance -# # of how many are positive and how many are negative. -# # The number of negative photons varies as a binomial distribution. -# # = eta * Ntot * g -# # = (1-eta) * Ntot * g -# # = (1-2eta) * Ntot * g = flux -# # Var(F-) = eta * (1-eta) * Ntot * g^2 -# # F+ = Ntot * g - F- is not an independent variable, so -# # Var(F+ - F-) = Var(Ntot*g - 2*F-) -# # = 4 * Var(F-) -# # = 4 * eta * (1-eta) * Ntot * g^2 -# # = 4 * eta * (1-eta) * flux -# # We want the variance to be equal to flux, so we need an extra: -# # delta Var = (1 - 4*eta + 4*eta^2) * flux -# # = (1-2eta)^2 * flux -# absflux = abs(flux) -# mean = eta_factor * eta_factor * absflux -# pd = PoissonDeviate(rng, mean) -# pd_val = pd() - mean + absflux -# ratio = pd_val / absflux -# g *= ratio -# mod_flux *= ratio -# return jnp.abs(mod_flux), g, rng - - -# def _scale_extra_noise(self, max_extra_noise, mod_flux, g, max_sb): -# gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb) -# mod_flux /= gfactor -# g *= gfactor -# return mod_flux, g - - -# def _calculate_nphotons_nonzero(self, n_photons, poisson_flux, max_extra_noise, rng): -# # For profiles that are positive definite, then N = flux. Easy. -# # -# # However, some profiles shoot some of their photons with negative flux. This means that -# # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the -# # fraction of shot photons that have negative flux. -# # -# # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 -# # N^2 = Var(S) = (N+ + N-) = Ntot -# # -# # So flux = (S/N)^2 = Ntot (1-2eta)^2 -# # Ntot = flux / (1-2eta)^2 -# # -# # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). -# # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right -# # total flux. -# # -# # That's all the easy case. The trickier case is when we are sky-background dominated. -# # Then we can usually get away with fewer shot photons than the above. In particular, -# # if the noise from the photon shooting is much less than the sky noise, then we can -# # use fewer shot photons and essentially have each photon have a flux > 1. This is ok -# # as long as the additional noise due to this approximation is "much less than" the -# # noise we'll be adding to the image for the sky noise. -# # -# # Let's still have Ntot photons, but now each with a flux of g. And let's look at the -# # noise we get in the brightest pixel that has a nominal total flux of Imax. -# # -# # The number of photons hitting this pixel will be Imax/flux * Ntot. -# # The variance of this number is the same thing (Poisson counting). -# # So the noise in that pixel is: -# # -# # N^2 = Imax/flux * Ntot * g^2 -# # -# # And the signal in that pixel will be: -# # -# # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so -# # g = flux / Ntot(1-2eta) -# # N^2 = Imax/Ntot * flux / (1-2eta)^2 -# # -# # As expected, we see that lowering Ntot will increase the noise in that (and every -# # other) pixel. -# # The input max_extra_noise parameter is the maximum value of spurious noise we want -# # to allow. -# # -# # So setting N^2 = Imax + nu, we get -# # -# # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) -# # g = (1 - 2eta) * (1 + nu/Imax) -# # -# # Returns the total flux placed inside the image bounds by photon shooting. -# # - -# flux = self.flux - -# # The _flux_per_photon property is (1-2eta) -# # This factor will already be accounted for by the shoot function, so don't include -# # that as part of our scaling here. There may be other adjustments though, so g=1 here. -# eta_factor = self._flux_per_photon -# mod_flux = flux / (eta_factor * eta_factor) -# g = 1.0 - -# # If requested, let the target flux value vary as a Poisson deviate -# mod_flux, g, _rng = jax.lax.cond( -# poisson_flux, -# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: self._adjust_flux_g_poisson( -# poisson_flux, flux, mod_flux, eta_factor, rng, g -# ), -# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: (mod_flux, g, rng), -# poisson_flux, -# flux, -# mod_flux, -# eta_factor, -# rng, -# g, -# ) -# if rng is not None: -# rng._state = _rng._state - -# mod_flux, g = jax.lax.cond( -# max_extra_noise > 0.0, -# lambda max_extra_noise, mod_flux, g, max_sb: self._scale_extra_noise( -# max_extra_noise, mod_flux, g, max_sb -# ), -# lambda max_extra_noise, mod_flux, g, max_sb: (mod_flux, g), -# max_extra_noise, -# mod_flux, -# g, -# self.max_sb, -# ) - -# # Make n_photons an integer and use input if requested -# n_photons = jax.lax.cond( -# n_photons == 0.0, -# lambda n_photons, mod_flux: jnp.ceil(mod_flux).astype(int), -# lambda n_photons, mod_flux: jnp.ceil(n_photons).astype(int), -# n_photons, -# mod_flux, -# ) - -# return n_photons, g, rng + mod_flux, g, _rng = jax.lax.cond( + poisson_flux, + lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: _adjust_flux_g_poisson( + poisson_flux, flux, mod_flux, eta_factor, rng, g + ), + lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: (mod_flux, g, rng), + poisson_flux, + flux, + mod_flux, + eta_factor, + rng, + g, + ) + if rng is not None: + rng._state = _rng._state + + mod_flux, g = jax.lax.cond( + max_extra_noise > 0.0, + lambda max_extra_noise, mod_flux, g, max_sb: _scale_extra_noise( + max_extra_noise, mod_flux, g, max_sb + ), + lambda max_extra_noise, mod_flux, g, max_sb: (mod_flux, g), + max_extra_noise, + mod_flux, + g, + max_sb, + ) + + return jnp.ceil(mod_flux).astype(int), g, rng diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 051277ca..6d32fd6f 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -7,13 +7,12 @@ from jax._src.numpy.util import _wraps import jax_galsim.photon_array as pa -from jax_galsim.core.draw import calculate_n_photons, sample_poisson_flux +from jax_galsim.core.draw import NPhotonsData, get_n_photons from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.errors import ( GalSimError, GalSimIncompatibleValuesError, GalSimNotImplementedError, - GalSimRangeError, GalSimValueError, galsim_warn, ) @@ -666,7 +665,7 @@ def drawImage( center=None, use_true_center=True, offset=None, - n_photons=0.0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1104,32 +1103,35 @@ def _drawKImage( @_wraps(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - if n_photons == 0.0: - Ntot, g = calculate_n_photons( - self.flux, - self._flux_per_photon, - self.max_sb, - rng=rng, - max_extra_noise=max_extra_noise, - poisson_flux=poisson_flux, - ) - else: - Ntot = int(n_photons + 0.5) - if poisson_flux: - pd_val = sample_poisson_flux(self.flux, self._flux_per_photon, rng=rng) - g = pd_val / jnp.abs(self.flux) - else: - g = 1.0 - - return Ntot, g + npd = NPhotonsData( + n_photons=n_photons, + poisson_flux=poisson_flux, + max_extra_noise=max_extra_noise, + rng=rng, + flux=self.flux, + flux_per_photon=self._flux_per_photon, + max_sb=self.max_sb, + ) + n_photons, g, _rng = get_n_photons(npd) + if rng is not None: + rng._state = _rng._state + return n_photons, g @_wraps( _galsim.GSObject.makePhot, - lax_description="The JAX-GalSim version of `makePhot` does not support the deprecated surface_ops argument.", + lax_description="""\ +The JAX-GalSim version of `makePhot` + + - does not support the deprecated surface_ops argument + - does little to no error checking on the inputs + - uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain +""", ) def makePhot( self, - n_photons=0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1143,24 +1145,17 @@ def makePhot( depr("surface_ops", 2.3, "photon_ops") photon_ops = surface_ops - # Make sure the type of n_photons is correct and has a valid value: - if not n_photons >= 0.0: - raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) - if poisson_flux is None: # If n_photons is given, poisson_flux = False - poisson_flux = n_photons == 0.0 + poisson_flux = n_photons is None - # Check that either n_photons is set to something or flux is set to something - if n_photons == 0.0 and self.flux == 1.0: - galsim_warn( - "Warning: drawImage for object with flux == 1, area == 1, and " - "exptime == 1, but n_photons == 0. This will only shoot a single photon." + if n_photons is not None: + Ntot = int(n_photons + 0.5) + _, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng ) - - Ntot, g = self._calculate_nphotons( - n_photons, poisson_flux, max_extra_noise, rng - ) + else: + Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) try: photons = self.shoot(Ntot, rng) @@ -1186,14 +1181,24 @@ def makePhot( @_wraps( _galsim.GSObject.drawPhot, - lax_description="The JAX-GalSim version of `drawPhot` does not support the deprecated surface_ops argument.", + lax_description="""\ +The JAX-GalSim version of `drawPhot` + + - does not support the deprecated surface_ops argument + - does little to no error checking on the inputs + - uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain + - the maxN option requires the use of fixed photon array sizes or a fixed + number of photons +""", ) def drawPhot( self, image, gain=1.0, add_to_image=False, - n_photons=0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1203,20 +1208,9 @@ def drawPhot( orig_center=PositionI(0, 0), local_wcs=None, ): - # Make sure the type of n_photons is correct and has a valid value: - if not n_photons >= 0.0: - raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) - + # If n_photons is given and poisson_flux is None, poisson_flux = False if poisson_flux is None: - # If n_photons is given, poisson_flux = False - poisson_flux = n_photons == 0.0 - - # Check that either n_photons is set to something or flux is set to something - if n_photons == 0.0 and self.flux == 1.0: - galsim_warn( - "Warning: drawImage for object with flux == 1, area == 1, and " - "exptime == 1, but n_photons == 0. This will only shoot a single photon." - ) + poisson_flux = n_photons is None # Make sure the image is set up to have unit pixel scale and centered at 0,0. if image.wcs is None or not image.wcs._isPixelScale: @@ -1229,9 +1223,14 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") - Ntot, g = self._calculate_nphotons( - n_photons, poisson_flux, max_extra_noise, rng - ) + if n_photons is not None: + Ntot = int(n_photons + 0.5) + _, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng + ) + else: + Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) + g = jax.lax.cond( gain != 1.0, lambda g, gain: g / gain, @@ -1240,40 +1239,64 @@ def drawPhot( gain, ) - if maxN is None: - maxN = Ntot - if not add_to_image: image.setZero() - ( - photons, - _rng, - added_flux, - _Nleft, - _image, - _photon_ops, - _sensor, - ) = _draw_phot_while_loop( - PhotonArray(maxN), - rng, - self, - image, - g, - Ntot, - maxN, - photon_ops, - local_wcs, - sensor, - orig_center, - ) + if maxN is None: + ( + added_flux, + _image, + _sensor, + _photon_ops, + _rng, + _, + photons, + ) = _draw_phot_while_loop_shoot( + Ntot, + Ntot, + Ntot, + self, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + False, + 0.0, + ) + else: + ( + photons, + _rng, + added_flux, + _Nleft, + _image, + _photon_ops, + _sensor, + ) = _draw_phot_while_loop( + PhotonArray(maxN), + rng, + self, + image, + g, + Ntot, + maxN, + photon_ops, + local_wcs, + sensor, + orig_center, + ) if rng is not None: rng._state = _rng._state else: rng = _rng for i in range(len(photon_ops)): photon_ops[i] = _photon_ops[i] + image._array = _image._array + # TODO: how to update the sensor? if sensor.__class__ is not Sensor: raise GalSimNotImplementedError( @@ -1285,13 +1308,13 @@ def drawPhot( @_wraps(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): photons = pa.PhotonArray(n_photons) - if n_photons == 0: - # It's ok to shoot 0, but downstream can have problems with it, so just stop now. - return photons - if rng is None: - rng = BaseDeviate() - self._shoot(photons, rng) + if photons._x.shape[0] > 0: + _rng = BaseDeviate(rng) + self._shoot(photons, _rng) + if rng is not None: + rng._state = _rng._state + return photons @_wraps(_galsim.GSObject._shoot) @@ -1329,6 +1352,71 @@ def tree_unflatten(cls, aux_data, children): return cls(**(children[0]), **aux_data) +def _draw_phot_while_loop_shoot( + maxN, + thisN, + Ntot, + obj, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + resume, + added_flux, +): + try: + photons = obj.shoot(maxN, rng) + except (GalSimError, NotImplementedError) as e: + raise GalSimNotImplementedError( + "Unable to draw this GSObject with photon shooting. Perhaps it " + "is a Deconvolve or is a compound including one or more " + "Deconvolve objects.\nOriginal error: %r" % (e) + ) + # we drew maxN, but only keep thisN of them + photons._num_keep = thisN + + photons = jax.lax.cond( + # weird way to say gain == 1 and thisN == Ntot + jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, + lambda photons, g, thisN, Ntot: photons, + # the factor here is thisN / Ntot since we drew thisN photons, but use a total of Ntot photons + lambda photons, g, thisN, Ntot: photons.scaleFlux(g * thisN / Ntot), + photons, + g, + thisN, + Ntot, + ) + + photons = jax.lax.cond( + image.scale != 1.0, + lambda photons, scale: photons.scaleXY( + 1.0 / scale + ), # Convert x,y to image coords if necessary + lambda photons, scale: photons, + photons, + image.scale, + ) + + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + + if image.dtype in (jnp.float32, jnp.float64): + added_flux += sensor.accumulate(photons, image, orig_center, resume=resume) + resume = True # Resume from this point if there are any further iterations. + else: + # Need a temporary + from jax_galsim.image import ImageD + + im1 = ImageD(bounds=image.bounds) + added_flux += sensor.accumulate(photons, im1, orig_center) + image += im1 + + return added_flux, image, sensor, photon_ops, rng, resume, photons + + @partial(jax.jit, static_argnames=("maxN",)) def _draw_phot_while_loop( photons, @@ -1378,67 +1466,44 @@ def _body_fun(args): # Shoot at most maxN at a time thisN = jnp.minimum(maxN, Nleft) - try: - photons = obj.shoot(maxN, rng) - except (GalSimError, NotImplementedError) as e: - raise GalSimNotImplementedError( - "Unable to draw this GSObject with photon shooting. Perhaps it " - "is a Deconvolve or is a compound including one or more " - "Deconvolve objects.\nOriginal error: %r" % (e) - ) - # we drew maxN, but only keep thisN of them - photons._num_keep = thisN - - photons = jax.lax.cond( - # weird way to say gain == 1 and thisN == Ntot - jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, - lambda photons, g, thisN, Ntot: photons, - # the factor here is thisN / Ntot since we drew thisN photons, but use a total of Ntot photons - lambda photons, g, thisN, Ntot: photons.scaleFlux(g * thisN / Ntot), - photons, - g, + ( + _added_flux, + _image, + _sensor, + _photon_ops, + _rng, + _resume, + _photons, + ) = _draw_phot_while_loop_shoot( + maxN, thisN, Ntot, + obj, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + resume, + added_flux, ) - photons = jax.lax.cond( - image.scale != 1.0, - lambda photons, scale: photons.scaleXY( - 1.0 / scale - ), # Convert x,y to image coords if necessary - lambda photons, scale: photons, - photons, - image.scale, - ) - - for op in photon_ops: - op.applyTo(photons, local_wcs, rng) - - if image.dtype in (jnp.float32, jnp.float64): - added_flux += sensor.accumulate(photons, image, orig_center, resume=resume) - resume = True # Resume from this point if there are any further iterations. - else: - # Need a temporary - from jax_galsim.image import ImageD - - im1 = ImageD(bounds=image.bounds) - added_flux += sensor.accumulate(photons, im1, orig_center) - image += im1 - Nleft -= thisN return ( - photons, - rng, - added_flux, + _photons, + _rng, + _added_flux, obj, Nleft, - resume, - image, + _resume, + _image, g, - photon_ops, + _photon_ops, local_wcs, - sensor, + _sensor, orig_center, ) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 9cd661b9..4f398401 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -11,7 +11,6 @@ import jax.numpy as jnp from galsim.errors import GalSimValueError from jax._src.numpy.util import _wraps -from jax.tree_util import Partial as jax_partial from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si @@ -19,96 +18,7 @@ from jax_galsim.errors import GalSimError from jax_galsim.gsparams import GSParams from jax_galsim.random import UniformDeviate - - -@jax.jit -def _rejection_sample(photons, rng, tot_xrange, xval, pos_flux, neg_flux, max_val): - """Use rejection sampling to generate photons from a given 1D interpolant function. - - We sample both x and y values from the interpolant function. - - Parameters - ---------- - photons : PhotonArray - The photon array to shoot into. - rng : BaseDeviate - The random number generator to use for drawing photons. - tot_xrange : float - The total range of the interpolant function from the most negative - point to the most positive point. The interpolant is assumed to be - symmetric about zero. - xval : callable - The interpolant function. Will only be called with positive values. - pos_flux : float - The total integral under all positive regions of the interpolant function. - neg_flux : float - The absolute value of the total integral under all negative regions of - the interpolant function. - max_val : float - The maximum value of the interpolant function. Usually this is xval(0.0) and - is 1.0. - """ - - def _cond_fun(args): - # we stop drawing when we have tot photons - # curr records how many we have currently - _, _, tot, _, curr = args - return curr < tot - - def _body_fun(args): - arr, sign, tot, ud, curr = args - # arr is the array we are filling with photon positions - # sign is the array of signs of the interpolant function at the photon positions - # tot is the total number of photons to draw - # ud is the random number generator for uniform deviates from 0 to 1 - # curr is the current number of photons drawn - - # we first draw a random x location centered at zero with a - # total range of tot_xrange - xloc = (ud() - 0.5) * tot_xrange - - # next we draw a random y value between 0 and max_val - yv = ud() * max_val - xloc_val = xval(xloc) - - # this cond operator keeps the photon if the y value we drew is - # below the interpolant function at the x location we drew - arr, sign, curr = jax.lax.cond( - yv <= jnp.abs(xloc_val), - # if we keep it, assign the location, assign the sign, and increment curr - lambda arr, sign, curr, xloc, xloc_val: ( - arr.at[curr].set(xloc), - sign.at[curr].set(jnp.sign(xloc_val)), - curr + 1, - ), - # otherwise we pass - lambda arr, sign, curr, xloc, xloc_val: (arr, sign, curr), - arr, - sign, - curr, - xloc, - xloc_val, - ) - return arr, sign, tot, ud, curr - - ud = UniformDeviate(rng) - - # we first make the x and y positions - photons.x, _sign_x, _, ud, _ = jax.lax.while_loop( - _cond_fun, - _body_fun, - (jnp.zeros_like(photons.x), jnp.zeros_like(photons.x), photons.size(), ud, 0), - ) - photons.y, _sign_y, _, ud, _ = jax.lax.while_loop( - _cond_fun, - _body_fun, - (jnp.zeros_like(photons.y), jnp.zeros_like(photons.y), photons.size(), ud, 0), - ) - # this magic formula comes from looking closely at the galsim code in Interpolant.cpp - # and how things get adjusted down the line OneDimensionalDeviate.cpp - flux_per = (pos_flux + neg_flux) ** 2 / photons.size() - photons.flux = _sign_x * _sign_y * flux_per - return photons, rng +from jax_galsim.utilities import lazy_property @_wraps(_galsim.interpolant.Interpolant) @@ -350,9 +260,31 @@ def urange(self): % self.__class__.__name__ ) + @lazy_property + def _shoot_cdf(self): + x = jnp.linspace(-self.xrange, self.xrange, 10000) + px = jnp.abs(self._xval_noraise(jnp.abs(x))) + cdfx = jnp.cumsum(px) + cdfx /= cdfx[-1] + return x, cdfx + def _shoot(self, photons, rng): - raise NotImplementedError( - "%s does not implement shoot" % self.__class__.__name__ + x, cdfx = self._shoot_cdf + ud = UniformDeviate(rng) + ux = ud.generate(photons.x) + uy = ud.generate(photons.y) + photons.x = jnp.interp(ux, cdfx, x) + photons.y = jnp.interp(uy, cdfx, x) + if photons.size() > 0: + flux_per_photon = ( + self.positive_flux + self.negative_flux + ) ** 2 / photons.size() + else: + flux_per_photon = 0.0 + photons.flux = ( + flux_per_photon + * jnp.sign(self._xval_noraise(photons.x)) + * jnp.sign(self._xval_noraise(photons.y)) ) # subclasses should implement __init__, _xval, _uval, @@ -644,21 +576,6 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 4 - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - @_wraps(_galsim.interpolant.Quintic) @register_pytree_node_class @@ -754,21 +671,6 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 6 - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - @_wraps(_galsim.interpolant.Lanczos) @register_pytree_node_class @@ -1745,21 +1647,6 @@ def unit_integrals(self, max_len=None): else: return self._unit_integrals_no_conserve_dc[self._n][:n] - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval, self._n, self._conserve_dc, self._K_arr), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - # we apply JIT here to esnure the class init is fast @jax.jit diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index b7b8722c..da72dfbb 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -61,14 +61,15 @@ def __init__( ): # self._N = N self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N - if ( - _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None - and N > _JAX_GALSIM_PHOTON_ARRAY_SIZE - ): - raise GalSimValueError( - f"The given photon array size {N} is larger than " - f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." - ) + # if ( + # _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None + # and isinstance(N, int) + # and N > _JAX_GALSIM_PHOTON_ARRAY_SIZE + # ): + # raise GalSimValueError( + # f"The given photon array size {N} is larger than " + # f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." + # ) if _nokeep is not None: self._nokeep = _nokeep else: diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 1b4ce1f4..d45a9169 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -22,6 +22,19 @@ - Within a single routine linking may work. - You may encounter errors related to global side effects for some combinations of linked states and jitted/vmapped routines. + +Seeding the JAX-GalSim PRNG can be done in a few ways: + + - pass seed=None (This is equivalent to passing seed=0) + - pass an integer seed (This method will throw errors if the integer is traced by JAX.) + - pass another JAX-GalSim PRNG + - pass a JAX PRNG key made via `jax.random.key`. + +**JAX PRNG keys made via `jax.random.PRNGKey` are not supported.** + +When using JAX-GalSim PRNGs and JIT, you should always return the PRNG from the function +and then set the state on input PRNG via `prng.reset(new_prng)`. This will ensure that the +PRNG state is propagated correctly outside the JITed code. """ @@ -33,8 +46,8 @@ class _DeviateState: Parameters ---------- - key : jax.random.PRNGKey - The JAX PRNG key made via `jrandom.PRNGKey` or equivalent. + key : key data with dtype `jax.dtypes.prng_key` + The JAX PRNG key made via `jrandom.key` """ def __init__(self, key): @@ -79,13 +92,13 @@ def generates_in_pairs(self): _galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.", ) - def seed(self, seed=0): + def seed(self, seed=None): self._seed(seed=seed) @_wraps(_galsim.BaseDeviate._seed) - def _seed(self, seed=0): + def _seed(self, seed=None): _initial_seed = seed or secrets.randbelow(2**31) - self._state.key = jrandom.PRNGKey(_initial_seed) + self._state.key = jrandom.key(_initial_seed) @_wraps( _galsim.BaseDeviate.reset, @@ -96,8 +109,10 @@ def reset(self, seed=None): self._state = seed elif isinstance(seed, BaseDeviate): self._state = seed._state - elif isinstance(seed, jax.Array) and seed.shape == (2,): - self._state = _DeviateState(wrap_key_data(seed)) + elif hasattr(seed, "dtype") and jax.dtypes.issubdtype( + seed.dtype, jax.dtypes.prng_key + ): + self._state = _DeviateState(seed) elif isinstance(seed, str): self._state = _DeviateState( wrap_key_data(jnp.array(eval(seed), dtype=jnp.uint32)) @@ -108,7 +123,7 @@ def reset(self, seed=None): ) else: _initial_seed = seed or secrets.randbelow(2**31) - self._state = _DeviateState(jrandom.PRNGKey(_initial_seed)) + self._state = _DeviateState(jrandom.key(_initial_seed)) @property def _key(self): diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 6f2e549a..180fd7ee 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps +from jax.tree_util import tree_flatten from jax_galsim.errors import GalSimIncompatibleValuesError, GalSimValueError from jax_galsim.position import PositionD, PositionI @@ -11,27 +12,51 @@ printoptions = _galsim.utilities.printoptions +def has_tracers(x): + """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + for item in tree_flatten(x)[0]: + if isinstance(item, jax.core.Tracer): + return True + return False + + @_wraps( _galsim.utilities.lazy_property, lax_description=( "The LAX version of this decorator uses an `_workspace` attribute " "attached to the object so that the cache can easily be discarded " - "for certain operations." + "for certain operations. It also will not cache jax.core.Tracer objects " + "in order to avoid side-effects in jit/grad/vmap transformations." ), ) -def lazy_property(func): - attname = func.__name__ + "_cached" - - @property - @functools.wraps(func) - def _func(self): - if not hasattr(self, "_workspace"): - self._workspace = {} - if attname not in self._workspace: - self._workspace[attname] = func(self) - return self._workspace[attname] - - return _func +def lazy_property(func_=None, cache_jax_tracers=False): + # see https://stackoverflow.com/a/57268935 + def _decorator(func): + attname = func.__name__ + "_cached" + + @property + @functools.wraps(func) + def wrapper(self): + if not hasattr(self, "_workspace"): + self._workspace = {} + if attname not in self._workspace: + val = func(self) + if cache_jax_tracers or (not has_tracers(val)): + self._workspace[attname] = val + else: + val = self._workspace[attname] + return val + + return wrapper + + if callable(func_): + return _decorator(func_) + elif func_ is None: + return _decorator + else: + raise RuntimeWarning( + "Positional arguments are not supported for the lazy_property decorator" + ) @_wraps(_galsim.utilities.parse_pos_args) diff --git a/tests/GalSim b/tests/GalSim index 710cca28..9e8d6565 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 710cca286c5fcd229d1c309aaf6e5c61ec81f9dc +Subproject commit 9e8d6565e88260586911339d1b3d8f32a7a8e1ba diff --git a/tests/jax/galsim/test_draw_jax.py b/tests/jax/galsim/test_draw_jax.py index cb93d249..fe4656e7 100644 --- a/tests/jax/galsim/test_draw_jax.py +++ b/tests/jax/galsim/test_draw_jax.py @@ -1115,19 +1115,22 @@ def test_shoot(): # in exact arithmetic. We had an assert there which blew up in a not very nice way. obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352) obj = obj.withFlux(100001) - image1 = galsim.ImageF(32,32, init_value=100) + # JAX-Galsim adjusts the images to double here + image1 = galsim.ImageD(32,32, init_value=100) rng = galsim.BaseDeviate(1234) obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng, maxN=100000) # The test here is really just that it doesn't crash. # But let's do something to check correctness. - image2 = galsim.ImageF(32,32) + # JAX-Galsim adjusts the images to double here + image2 = galsim.ImageD(32,32) rng = galsim.BaseDeviate(1234) obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, maxN=100000) image2 += 100 - np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) + # with double, we get the same result to 10 decimal places + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10) # Also check that you get the same answer with a smaller maxN. image3 = galsim.ImageF(32,32, init_value=100) @@ -1141,13 +1144,15 @@ def test_shoot(): np.testing.assert_array_equal(image4.array, 0) # Warns if flux is 1 and n_photons not given. + # JAX-GalSim doesn't warn in this case psf = galsim.Gaussian(sigma=3) - with assert_warns(galsim.GalSimWarning): - psf.drawImage(method='phot') - with assert_warns(galsim.GalSimWarning): - psf.drawPhot(image4) - with assert_warns(galsim.GalSimWarning): - psf.makePhot() + # with assert_warns(galsim.GalSimWarning): + # psf.drawImage(method='phot') + # with assert_warns(galsim.GalSimWarning): + # psf.drawPhot(image4) + # with assert_warns(galsim.GalSimWarning): + # psf.makePhot() + # With n_photons=1, it's fine. psf.drawImage(method='phot', n_photons=1) psf.drawPhot(image4, n_photons=1) @@ -1204,23 +1209,24 @@ def test_drawImage_area_exptime(): msg = "obj.drawImage(method='phot') unexpectedly produced equal images with different rng" assert not np.allclose(im5.array, im4.array), msg - # Shooting with flux=1 raises a warning. - obj1 = obj.withFlux(1) - with assert_warns(galsim.GalSimWarning): - obj1.drawImage(method='phot') - # But not if we explicitly tell it to shoot 1 photon - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) - # Likewise for makePhot - with assert_warns(galsim.GalSimWarning): - obj1.makePhot() - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) - # And drawPhot - with assert_warns(galsim.GalSimWarning): - obj1.drawPhot(im1) - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) + # JAX-GalSim doesn't raise for these things + # # Shooting with flux=1 raises a warning. + # obj1 = obj.withFlux(1) + # with assert_warns(galsim.GalSimWarning): + # obj1.drawImage(method='phot') + # # But not if we explicitly tell it to shoot 1 photon + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) + # # Likewise for makePhot + # with assert_warns(galsim.GalSimWarning): + # obj1.makePhot() + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) + # # And drawPhot + # with assert_warns(galsim.GalSimWarning): + # obj1.drawPhot(im1) + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) @timer diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 6c0200f0..4ecaf22f 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -329,10 +329,7 @@ def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): Lanczos(7), ], ) -def test_interpolatedimage_interpolant_rejection_sample(interp): - from jax.tree_util import Partial as jax_partial - - from jax_galsim.interpolant import _rejection_sample +def test_interpolatedimage_interpolant_sample(interp): from jax_galsim.photon_array import PhotonArray from jax_galsim.random import BaseDeviate @@ -340,15 +337,7 @@ def test_interpolatedimage_interpolant_rejection_sample(interp): ntot = 1000000 photons = PhotonArray(ntot) - photons, _ = _rejection_sample( - photons, - rng, - interp.xrange * 2.0, - jax_partial(interp._xval_noraise), - interp.positive_flux, - interp.negative_flux, - interp._xval_noraise(0.0), - ) + interp._shoot(photons, rng) h, bins = jnp.histogram(photons.x, bins=500) mid = (bins[1:] + bins[:-1]) / 2.0 @@ -368,8 +357,8 @@ def test_interpolatedimage_interpolant_rejection_sample(interp): if interp.__class__.__name__ == "Quintic" and False: import proplot as pplt - fig, axs = pplt.subplots(figsize=(4, 4)) - axs.hist(photons.x, bins=500, log=False) + fig, axs = pplt.subplots(figsize=(6, 6)) + axs.hist(photons.x, bins=500, log=True) axs.plot(mid, yv, color="k") fig.show() breakpoint() diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 0b20f9eb..6a427f94 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -240,7 +240,7 @@ def _build_and_draw(hlr, fwhm, jit=True): final._flux_per_photon, final.max_sb, poisson_flux=False, - )[0] + )[0].item() gain = 1.0 if jit: return _draw_it_jit(final, n, n_photons, gain) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 9346a149..bb2dd3b2 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -1,3 +1,4 @@ +import galsim as _galsim import jax import jax.numpy as jnp import numpy as np @@ -6,6 +7,7 @@ import jax_galsim from jax_galsim.core.testing import time_code_block +from jax_galsim.photon_array import fixed_photon_array_size def test_photon_shooting_jax_make_from_image_notranspose(): @@ -168,3 +170,69 @@ def test_photon_shooting_jax_offset(offset): ) np.testing.assert_allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol) + + +def test_photon_shooting_jax_vmapping(): + n_stamps = 100 + rng = np.random.RandomState(1234) + shifts = jnp.array(rng.uniform(-1, 1, size=(n_stamps, 2))) + hlrs = jnp.array(rng.uniform(0.1, 1.0, size=(n_stamps,))) + fwhms = jnp.array(rng.uniform(0.9, 1.0, size=(n_stamps,))) + fluxes = jnp.array(rng.uniform(100, 1000, size=(n_stamps,))) + rng = jax_galsim.BaseDeviate(1234) + seeds = [] + for i in range(n_stamps): + seeds.append(jax.random.key(i + 1)) + max_n_phot = 2048 + seeds = jnp.array(seeds) + + @jax.jit + def _draw(hlr, fwhm, shift, flux, seed): + obj = jax_galsim.Convolve( + [ + jax_galsim.Exponential(half_light_radius=hlr, flux=flux).shift(*shift), + jax_galsim.Gaussian(fwhm=fwhm, flux=1.0), + ] + ) + with fixed_photon_array_size(max_n_phot): + return obj.drawImage( + nx=33, + ny=33, + scale=0.2, + method="phot", + rng=jax_galsim.BaseDeviate(seed), + ) + + with time_code_block("one warmup"): + img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) + with time_code_block("one"): + img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) + print(img.array.shape, img.bounds, img.array.sum(), fluxes[0]) + + _vmap_draw = jax.jit(jax.vmap(_draw, in_axes=(0, 0, 0, 0, 0))) + with time_code_block("vmap warmup"): + imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) + with time_code_block("vmap"): + imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) + print(imgs.array.shape) + + np.testing.assert_allclose(img.array.sum(), imgs.array[0].sum()) + + def _draw_galsim(hlr, fwhm, shift, flux, seed): + obj = _galsim.Convolve( + [ + _galsim.Exponential(half_light_radius=hlr, flux=flux).shift(*shift), + _galsim.Gaussian(fwhm=fwhm, flux=1.0), + ] + ) + return obj.drawImage( + nx=33, + ny=33, + scale=0.2, + method="phot", + rng=_galsim.BaseDeviate(seed), + ) + + with time_code_block("galsim"): + for i in range(n_stamps): + _draw_galsim(hlrs[i], fwhms[i], shifts[i], fluxes[i], i + 1) From 9be573f3f77c12fe941105455d3f010be04752ed Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 20 Nov 2023 17:04:50 -0600 Subject: [PATCH 34/85] ENH use higher order integration method --- jax_galsim/interpolant.py | 7 ++++++- tests/jax/test_interpolatedimage_utils.py | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 4f398401..fb220232 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -264,7 +264,12 @@ def urange(self): def _shoot_cdf(self): x = jnp.linspace(-self.xrange, self.xrange, 10000) px = jnp.abs(self._xval_noraise(jnp.abs(x))) - cdfx = jnp.cumsum(px) + dx = x[1] - x[0] + # cumulative trapezoidal rule + # see scipy.integrate.cumulative_trapezoidal + cdfx = jnp.concatenate( + [jnp.array([0]), jnp.cumsum((px[1:] + px[:-1]) * 0.5 * dx)] + ) cdfx /= cdfx[-1] return x, cdfx diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 4ecaf22f..c76278ef 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -354,11 +354,12 @@ def test_interpolatedimage_interpolant_sample(interp): np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") - if interp.__class__.__name__ == "Quintic" and False: + if interp.__class__.__name__ in ["Quintic", "Lanczos"] and False: import proplot as pplt fig, axs = pplt.subplots(figsize=(6, 6)) axs.hist(photons.x, bins=500, log=True) axs.plot(mid, yv, color="k") + axs.format(title=interp.__class__.__name__) fig.show() breakpoint() From 8f3331e079fcb3c12bfb6f77561758109af4c860 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 26 Nov 2023 07:33:06 -0500 Subject: [PATCH 35/85] TST enable rest of galsim test suit --- jax_galsim/__init__.py | 1 + jax_galsim/angle.py | 8 +- jax_galsim/bounds.py | 370 +- jax_galsim/core/draw.py | 6 +- jax_galsim/core/utils.py | 89 +- jax_galsim/gsobject.py | 74 +- jax_galsim/image.py | 170 +- jax_galsim/interpolatedimage.py | 9 +- jax_galsim/photon_array.py | 19 +- jax_galsim/position.py | 8 +- jax_galsim/sensor.py | 11 +- jax_galsim/utilities.py | 10 +- jax_galsim/wcs.py | 8 +- tests/Coord | 2 +- tests/GalSim | 2 +- tests/conftest.py | 40 +- tests/galsim_tests_config.yaml | 76 +- tests/jax/galsim/test_draw_jax.py | 1720 ------ tests/jax/galsim/test_image_jax.py | 4850 ----------------- tests/jax/galsim/test_noise_jax.py | 843 --- tests/jax/galsim/test_photon_array_jax.py | 1873 ------- tests/jax/galsim/test_random_jax.py | 2002 ------- tests/jax/galsim/test_shear_jax.py | 377 -- tests/jax/galsim/test_shear_position_jax.py | 220 - tests/jax/galsim/test_wcs_jax.py | 4090 -------------- tests/jax/test_api.py | 112 + tests/jax/test_image_wrapping.py | 191 +- .../jax/{galsim => }/test_interpolant_jax.py | 0 .../{test_metacal.py => test_metacal_jax.py} | 0 tests/jax/test_ref_impl.py | 72 + tests/jax/test_temporary_tests.py | 188 - 31 files changed, 1058 insertions(+), 16383 deletions(-) delete mode 100644 tests/jax/galsim/test_draw_jax.py delete mode 100644 tests/jax/galsim/test_image_jax.py delete mode 100644 tests/jax/galsim/test_noise_jax.py delete mode 100644 tests/jax/galsim/test_photon_array_jax.py delete mode 100644 tests/jax/galsim/test_random_jax.py delete mode 100644 tests/jax/galsim/test_shear_jax.py delete mode 100644 tests/jax/galsim/test_shear_position_jax.py delete mode 100644 tests/jax/galsim/test_wcs_jax.py rename tests/jax/{galsim => }/test_interpolant_jax.py (100%) rename tests/jax/{test_metacal.py => test_metacal_jax.py} (100%) create mode 100644 tests/jax/test_ref_impl.py delete mode 100644 tests/jax/test_temporary_tests.py diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 94d0afc8..55613102 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -44,6 +44,7 @@ ImageS, ImageUI, ImageUS, + _Image, ) # GSObject diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 450fe588..3126a6da 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -22,7 +22,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_array_scalar, ensure_hashable +from jax_galsim.core.utils import cast_to_float, ensure_hashable @_wraps(_galsim.AngleUnit) @@ -34,7 +34,9 @@ def __init__(self, value): """ :param value: The measure of the unit in radians. """ - self._value = cast_to_array_scalar(value, dtype=float) + if isinstance(value, AngleUnit): + raise TypeError("Cannot construct AngleUnit from another AngleUnit") + self._value = 1.0 * cast_to_float(value) # this will cause an exception if things are not numeric @property def value(self): @@ -142,7 +144,7 @@ def __init__(self, theta, unit=None): raise TypeError("Invalid unit %s of type %s" % (unit, type(unit))) else: # Normal case - self._rad = cast_to_array_scalar(theta, dtype=float) * unit.value + self._rad = cast_to_float(theta) * unit.value @property def rad(self): diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a61f483d..2c0658b4 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,24 +1,37 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable +from jax_galsim.core.utils import ( + cast_to_float, + cast_to_int, + ensure_hashable, + has_tracers, +) from jax_galsim.position import Position, PositionD, PositionI +from jax_galsim.errors import GalSimUndefinedBoundsError # The reason for avoid these tests is that they are not easy to do for jitted code. @_wraps( _galsim.Bounds, - lax_description=( - "The JAX implementation will not test whether the bounds are valid." - "This is defined as always true." - "It will also not test whether BoundsI is indeed initialized with integers." - ), + lax_description="""\ +"The JAX implementation of galsim.Bounds + + - will not always test for properly defined bounds, especially in jitted code + - will not test whether BoundsI is indeed initialized with integers during vmap/jit/grad transforms +""", ) @register_pytree_node_class -class Bounds(_galsim.Bounds): +class Bounds(object): + def __init__(self): + raise NotImplementedError( + "Cannot instantiate the base class. " "Use either BoundsD or BoundsI." + ) + def _parse_args(self, *args, **kwargs): if len(kwargs) == 0: if len(args) == 4: @@ -26,7 +39,7 @@ def _parse_args(self, *args, **kwargs): self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: self._isdefined = False - self.xmin = self.xmax = self.ymin = self.ymax = 0 + self.xmin = self.xmax = self.ymin = self.ymax = jnp.nan elif len(args) == 1: if isinstance(args[0], Bounds): self._isdefined = True @@ -81,47 +94,107 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - @property - def true_center(self): - """The central position of the `Bounds` as a `PositionD`. + @_wraps(_galsim.Bounds.isDefined) + def isDefined(self, _static=False): + if _static: + return self._isdefined and np.all(self.xmin <= self.xmax) and np.all( + self.ymin <= self.ymax + ) + else: + return ( + jnp.isfinite(self.xmin) + & jnp.isfinite(self.xmax) + & jnp.isfinite(self.ymin) + & jnp.isfinite(self.ymax) + & (self.xmin <= self.xmax) + & (self.ymin <= self.ymax) + ) + + def area(self): + """Return the area of the enclosed region. + + The area is a bit different for integer-type `BoundsI` and float-type `BoundsD` instances. + For floating point types, it is simply ``(xmax-xmin)*(ymax-ymin)``. However, for integer + types, we add 1 to each size to correctly count the number of pixels being described by the + bounding box. + """ + return self._area() + + def withBorder(self, dx, dy=None): + """Return a new `Bounds` object that expands the current bounds by the specified width. - This is always (xmax + xmin)/2., (ymax + ymin)/2., even for integer `BoundsI`, where - this may not necessarily be an integer `PositionI`. + If two arguments are given, then these are separate dx and dy borders. """ - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( + self._check_scalar(dx, "dx") + if dy is None: + dy = dx + else: + self._check_scalar(dy, "dy") + return self.__class__( + self.xmin - dx, self.xmax + dx, self.ymin - dy, self.ymax + dy + ) + + @property + def origin(self): + "The lower left position of the `Bounds`." + return self._pos_class(self.xmin, self.ymin) + + @property + @_wraps( + _galsim.Bounds.center, + lax_description="The JAX implementation of galsim.Bounds.center does not raise for undefined bounds.", + ) + def center(self): + if not self.isDefined(_static=True): + raise GalSimUndefinedBoundsError( + "center is invalid for an undefined Bounds" + ) + return self._center + + @property + @_wraps( + _galsim.Bounds.true_center, + lax_description="The JAX implementation of galsim.Bounds.true_center does not raise for undefined bounds.", + ) + def true_center(self): + if not self.isDefined(_static=True): + raise GalSimUndefinedBoundsError( "true_center is invalid for an undefined Bounds" ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) @_wraps(_galsim.Bounds.includes) - def includes(self, *args): + def includes(self, *args, _static=False): if len(args) == 1: if isinstance(args[0], Bounds): b = args[0] return ( - self.isDefined() - and b.isDefined() - and self.xmin <= b.xmin - and self.xmax >= b.xmax - and self.ymin <= b.ymin - and self.ymax >= b.ymax + self.isDefined(_static=_static) + & b.isDefined(_static=_static) + & (self.xmin <= b.xmin) + & (self.xmax >= b.xmax) + & (self.ymin <= b.ymin) + & (self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( - self.isDefined() - and self.xmin <= p.x <= self.xmax - and self.ymin <= p.y <= self.ymax + self.isDefined(_static=_static) + & (self.xmin <= p.x) + & (p.x <= self.xmax) + & (self.ymin <= p.y) + & (p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args return ( - self.isDefined() - and self.xmin <= float(x) <= self.xmax - and self.ymin <= float(y) <= self.ymax + self.isDefined(_static=_static) + & (self.xmin <= x) + & (x <= self.xmax) + & (self.ymin <= y) + & (y <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -139,42 +212,108 @@ def expand(self, factor_x, factor_y=None): dy = jnp.ceil(dy) return self.withBorder(dx, dy) + def getXMin(self): + "Get the value of xmin." + return self.xmin + + def getXMax(self): + "Get the value of xmax." + return self.xmax + + def getYMin(self): + "Get the value of ymin." + return self.ymin + + def getYMax(self): + "Get the value of ymax." + return self.ymax + + def shift(self, delta): + """Shift the `Bounds` instance by a supplied `Position`. + + Examples: + + The shift method takes either a `PositionI` or `PositionD` instance, which must match + the type of the `Bounds` instance:: + + >>> bounds = BoundsI(1,32,1,32) + >>> bounds = bounds.shift(galsim.PositionI(3, 2)) + >>> bounds = BoundsD(0, 37.4, 0, 49.9) + >>> bounds = bounds.shift(galsim.PositionD(3.9, 2.1)) + """ + if not isinstance(delta, self._pos_class): + raise TypeError("delta must be a %s instance" % self._pos_class) + return self.__class__( + self.xmin + delta.x, + self.xmax + delta.x, + self.ymin + delta.y, + self.ymax + delta.y, + ) + def __and__(self, other): if not isinstance(other, self.__class__): raise TypeError("other must be a %s instance" % self.__class__.__name__) - if not self.isDefined() or not other.isDefined(): - return self.__class__() - else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) + # NaNs always propagate, so if either is undefined, the result is undefined + return self.__class__( + jnp.maximum(self.xmin, other.xmin), + jnp.minimum(self.xmax, other.xmax), + jnp.maximum(self.ymin, other.ymin), + jnp.minimum(self.ymax, other.ymax), + ) def __add__(self, other): if isinstance(other, self.__class__): - if not other.isDefined(): - return self - elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return other + # galsim logic is + # if not other.isDefined(): + # return self + # elif self.isDefined(): + # xmin = jnp.minimum(self.xmin, other.xmin) + # xmax = jnp.maximum(self.xmax, other.xmax) + # ymin = jnp.minimum(self.ymin, other.ymin) + # ymax = jnp.maximum(self.ymax, other.ymax) + # return self.__class__(xmin, xmax, ymin, ymax) + # else: + # return other + return self.__class__( + jax.lax.cond( + ~jnp.any(other.isDefined()), + lambda: BoundsD(self), + lambda: BoundsD( + jax.lax.cond( + jnp.any(self.isDefined()), + lambda: BoundsD( + jnp.minimum(self.xmin, other.xmin), + jnp.maximum(self.xmax, other.xmax), + jnp.minimum(self.ymin, other.ymin), + jnp.maximum(self.ymax, other.ymax), + ), + lambda: BoundsD(other), + ) + ), + ) + ) elif isinstance(other, self._pos_class): - if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return self.__class__(other) + # the galsim logic is + # if self.isDefined(): + # xmin = jnp.minimum(self.xmin, other.x) + # xmax = jnp.maximum(self.xmax, other.x) + # ymin = jnp.minimum(self.ymin, other.y) + # ymax = jnp.maximum(self.ymax, other.y) + # return self.__class__(xmin, xmax, ymin, ymax) + # else: + # return self.__class__(other) + return self.__class__( + jax.lax.cond( + jnp.any(self.isDefined()), + lambda: BoundsD( + jnp.minimum(self.xmin, other.x), + jnp.maximum(self.xmax, other.x), + jnp.minimum(self.ymin, other.y), + jnp.maximum(self.ymax, other.y), + ), + lambda: BoundsD(other), + ) + ) else: raise TypeError( "other must be either a %s or a %s" @@ -182,7 +321,7 @@ def __add__(self, other): ) def __repr__(self): - if self.isDefined(): + if self.isDefined(_static=True): return "galsim.%s(xmin=%r, xmax=%r, ymin=%r, ymax=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -194,7 +333,7 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + if self.isDefined(_static=True): return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -216,14 +355,34 @@ def __hash__(self): ) ) + def _getinitargs(self): + if self.isDefined(_static=True): + return (self.xmin, self.xmax, self.ymin, self.ymax) + else: + return () + + def __eq__(self, other): + return self is other or ( + isinstance(other, self.__class__) + and ( + ( + np.array_equal(self.xmin, other.xmin, equal_nan=True) + and np.array_equal(self.xmax, other.xmax, equal_nan=True) + and np.array_equal(self.ymin, other.ymin, equal_nan=True) + and np.array_equal(self.ymax, other.ymax, equal_nan=True) + ) + or ((not self.isDefined(_static=True)) and (not other.isDefined(_static=True))) + ) + ) + + def __ne__(self, other): + return not self.__eq__(other) + def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - children = (self.xmin, self.xmax, self.ymin, self.ymax) - else: - children = tuple() + children = (self.xmin, self.xmax, self.ymin, self.ymax) # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data) @@ -245,15 +404,21 @@ def from_galsim(cls, galsim_bounds): "galsim_bounds must be either a %s or a %s" % (_galsim.BoundsD.__name__, _galsim.BoundsI.__name__) ) - return _cls( - galsim_bounds.xmin, - galsim_bounds.xmax, - galsim_bounds.ymin, - galsim_bounds.ymax, - ) + if not galsim_bounds.isDefined(): + return _cls() + else: + return _cls( + galsim_bounds.xmin, + galsim_bounds.xmax, + galsim_bounds.ymin, + galsim_bounds.ymax, + ) -@_wraps(_galsim.BoundsD) +@_wraps( + _galsim.BoundsD, + lax_description="The JAX implementation of galsim.BoundsD does not always check for float values.", +) @register_pytree_node_class class BoundsD(Bounds): _pos_class = PositionD @@ -280,20 +445,49 @@ def _check_scalar(self, x, name): raise TypeError("%s must be a float value" % name) def _area(self): - return (self.xmax - self.xmin) * (self.ymax - self.ymin) + return jax.lax.cond( + jnp.any(self.isDefined()), + lambda xmin, xmax, ymin, ymax: (xmax - xmin) * (ymax - ymin), + lambda xmin, xmax, ymin, ymax: jnp.zeros_like(xmin), + self.xmin, + self.xmax, + self.ymin, + self.ymax, + ) @property def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) -@_wraps(_galsim.BoundsI) +@_wraps( + _galsim.BoundsI, + lax_description="The JAX implementation of galsim.BoundsI does not always check for integer values.", +) @register_pytree_node_class class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) + + # best-effort error checking + raise_notint = False + try: + bnds = (self.xmin, self.xmax, self.ymin, self.ymax) + if not has_tracers(bnds) and np.all(np.isfinite(bnds)) & np.any( + (self.xmin != np.floor(self.xmin)) + | (self.xmax != np.floor(self.xmax)) + | (self.ymin != np.floor(self.ymin)) + | (self.ymax != np.floor(self.ymax)) + ): + raise_notint = True + except Exception: + pass + + if raise_notint: + raise TypeError("BoundsI must be initialized with integer values") + self.xmin = cast_to_int(self.xmin) self.xmax = cast_to_int(self.xmax) self.ymin = cast_to_int(self.ymin) @@ -313,19 +507,35 @@ def _check_scalar(self, x, name): pass raise TypeError("%s must be an integer value" % name) - def numpyShape(self): + def numpyShape(self, _static=False): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if self.isDefined(): - return self.ymax - self.ymin + 1, self.xmax - self.xmin + 1 + if _static: + if self.isDefined(_static=True): + return (self.ymax - self.ymin + 1, self.xmax - self.xmin + 1) + else: + return (0, 0) else: - return 0, 0 + return jax.lax.cond( + jnp.any(self.isDefined()), + lambda xmin, xmax, ymin, ymax: (ymax - ymin + 1, xmax - xmin + 1), + lambda xmin, xmax, ymin, ymax: (jnp.zeros_like(xmin), jnp.zeros_like(xmin)), + self.xmin, + self.xmax, + self.ymin, + self.ymax, + ) def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if not self.isDefined(): - return 0 - else: - return (self.xmax - self.xmin + 1) * (self.ymax - self.ymin + 1) + return jax.lax.cond( + jnp.any(self.isDefined()), + lambda xmin, xmax, ymin, ymax: (xmax - xmin + 1) * (ymax - ymin + 1), + lambda xmin, xmax, ymin, ymax: jnp.zeros_like(xmin), + self.xmin, + self.xmax, + self.ymin, + self.ymax, + ) @property def _center(self): diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 8b632f7d..f54fa08c 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -37,7 +37,7 @@ def draw_by_xValue( im = (im * flux_scaling).astype(image.dtype) # Return an image - return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): @@ -55,7 +55,7 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): im = (im).astype(image.dtype) # Return an image - return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): @@ -81,7 +81,7 @@ def phase(kpos): array=image.array * im_phase, bounds=image.bounds, wcs=image.wcs, - check_bounds=False, + _check_bounds=False, ) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c499929c..db0d7fe2 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -2,6 +2,21 @@ import jax import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten + + +def has_tracers(x): + """Return True if the input item is a JAX tracer, False otherwise.""" + for item in tree_flatten(x)[0]: + if ( + isinstance(item, jax.core.Tracer) + or type(item) is object + # or isinstance(item, jax.core.ShapedArray) + # or isinstance(item, str) + ): + return True + return False @jax.jit @@ -13,7 +28,7 @@ def compute_major_minor_from_jacobian(jac): return major, minor -def cast_to_array_scalar(x, dtype=None): +def _cast_to_array_scalar(x, dtype=None): """Cast the input to an array scalar. Works on python scalars, iterables and jax arrays. For iterables it always takes the first element after a call to .ravel()""" if dtype is None: @@ -34,53 +49,60 @@ def cast_to_python_float(x): """Cast the input to a python float. Works on python floats and jax arrays. For jax arrays it always takes the first element after a call to .ravel()""" if isinstance(x, jax.Array): - return cast_to_array_scalar(x, dtype=float).item() + return _cast_to_array_scalar(x, dtype=float).item() else: try: return float(x) except TypeError: return x + except ValueError as e: + # we let NaNs through + if " NaN " in str(e): + return x + else: + raise e def cast_to_python_int(x): """Cast the input to a python int. Works on python ints and jax arrays. For jax arrays it always takes the first element after a call to .ravel()""" if isinstance(x, jax.Array): - return cast_to_array_scalar(x, dtype=int).item() + return _cast_to_array_scalar(x, dtype=int).item() else: try: return int(x) except TypeError: return x + except ValueError as e: + # we let NaNs through + if " NaN " in str(e): + return x + else: + raise e def cast_to_float(x): """Cast the input to a float. Works on python floats and jax arrays.""" - if isinstance(x, jax.Array): - return x.astype(float) - elif hasattr(x, "astype"): - return x.astype(float) - else: + try: + return float(x) + except Exception: try: - return float(x) - except TypeError as e: - # needed so that tests of jax_galsim.angle pass - if "AngleUnit" in str(e): - raise e - + return jnp.asarray(x, dtype=float) + except Exception: return x def cast_to_int(x): """Cast the input to an int. Works on python floats/ints and jax arrays.""" - if isinstance(x, jax.Array): - return x.astype(int) - elif hasattr(x, "astype"): - return x.astype(int) - else: + try: + return int(x) + except Exception: try: - return int(x) - except TypeError: + if not jnp.any(jnp.isnan(x)): + return jnp.asarray(x, dtype=int) + else: + return x + except Exception: return x @@ -132,26 +154,41 @@ def is_equal_with_arrays(x, y): return x == y +def _convert_to_numpy_nan(x): + """Convert input to numpy.nan if it is a NaN, otherwise return it unchanged + so that we get consistent hashing.""" + try: + if np.isnan(x): + return np.nan + else: + return x + except Exception: + return x + + def _recurse_list_to_tuple(x): if isinstance(x, list): return tuple(_recurse_list_to_tuple(v) for v in x) else: - return x + return _convert_to_numpy_nan(x) def ensure_hashable(v): """Ensure that the input is hashable. If it is a jax array, - convert it to a possibly nested tuple or python scalar.""" + convert it to a possibly nested tuple or python scalar. + + All NaNs are converted to numpy.nan to get consistent hashing. + """ if isinstance(v, jax.Array): try: if len(v.shape) > 0: return _recurse_list_to_tuple(v.tolist()) else: - return v.item() + return _convert_to_numpy_nan(v.item()) except Exception: - return v + return _convert_to_numpy_nan(v) else: - return v + return _convert_to_numpy_nan(v) @partial(jax.jit, static_argnames=("niter",)) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 6d32fd6f..6dafc034 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -445,7 +445,7 @@ def _setup_image( ) # Resize the given image if necessary - if not image.bounds.isDefined(): + if not image.bounds.isDefined(_static=True): # Can't add to image if need to resize if add_to_image: raise _galsim.GalSimIncompatibleValuesError( @@ -478,7 +478,7 @@ def _setup_image( ny=ny, bounds=bounds, ) - if not bounds.isDefined(): + if not bounds.isDefined(_static=True): raise _galsim.GalSimValueError( "Cannot use undefined bounds", bounds ) @@ -515,7 +515,7 @@ def _local_wcs(self, wcs, image, offset, center, use_true_center, new_bounds): bounds = new_bounds else: bounds = image.bounds - if not bounds.isDefined(): + if not bounds.isDefined(_static=True): raise _galsim.GalSimIncompatibleValuesError( "Cannot provide non-local wcs with automatically sized image", wcs=wcs, @@ -556,7 +556,7 @@ def _parse_center(self, center): def _get_new_bounds(self, image, nx, ny, bounds, center): from jax_galsim.bounds import BoundsI - if image is not None and image.bounds.isDefined(): + if image is not None and image.bounds.isDefined(_static=True): return image.bounds elif nx is not None and ny is not None: b = BoundsI(1, nx, 1, ny) @@ -568,7 +568,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center): ) ) return b - elif bounds is not None and bounds.isDefined(): + elif bounds is not None and bounds.isDefined(_static=True): return bounds else: return BoundsI() @@ -576,7 +576,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center): def _adjust_offset(self, new_bounds, offset, center, use_true_center): # Note: this assumes self is in terms of image coordinates. if center is not None: - if new_bounds.isDefined(): + if new_bounds.isDefined(_static=True): offset += center - new_bounds.center else: # Then will be created as even sized image. @@ -590,7 +590,7 @@ def _adjust_offset(self, new_bounds, offset, center, use_true_center): # Also, remember that numpy's shape is ordered as [y,x] dx = offset.x dy = offset.y - shape = new_bounds.numpyShape() + shape = new_bounds.numpyShape(_static=True) dx -= 0.5 * ((shape[1] + 1) % 2) dy -= 0.5 * ((shape[0] + 1) % 2) @@ -690,6 +690,13 @@ def drawImage( "Setting maxN is incompatible with save_photons=True" ) + if method not in ("auto", "fft", "real_space", "phot", "no_pixel", "sb"): + raise GalSimValueError( + "Invalid method name", + method, + ("auto", "fft", "real_space", "phot", "no_pixel", "sb"), + ) + # Check that the user isn't convolving by a Pixel already. This is almost always an error. if method == "auto" and isinstance(self, Convolution): if any([isinstance(obj, Pixel) for obj in self.obj_list]): @@ -703,6 +710,45 @@ def drawImage( "an _additional_ Pixel, you can suppress this warning by using method=fft." ) + if method != "phot": + if n_photons is not None: + raise GalSimIncompatibleValuesError( + "n_photons is only relevant for method='phot'", + method=method, + sensor=sensor, + n_photons=n_photons, + ) + if poisson_flux is not None: + raise GalSimIncompatibleValuesError( + "poisson_flux is only relevant for method='phot'", + method=method, + sensor=sensor, + poisson_flux=poisson_flux, + ) + + if method != "phot" and sensor is None: + if rng is not None: + raise GalSimIncompatibleValuesError( + "rng is only relevant for method='phot' or when using a sensor", + method=method, + sensor=sensor, + rng=rng, + ) + if maxN is not None: + raise GalSimIncompatibleValuesError( + "maxN is only relevant for method='phot' or when using a sensor", + method=method, + sensor=sensor, + maxN=maxN, + ) + if save_photons: + raise GalSimIncompatibleValuesError( + "save_photons is only valid for method='phot' or when using a sensor", + method=method, + sensor=sensor, + save_photons=save_photons, + ) + # Figure out what wcs we are going to use. wcs = self._determine_wcs(scale, wcs, image) @@ -726,7 +772,7 @@ def drawImage( flux_scale /= local_wcs.pixelArea() # Only do the gain here if not photon shooting, since need the number of photons to # reflect that actual photons, not ADU. - if method != "phot" and sensor is None and gain != 1: + if method != "phot" and sensor is None: flux_scale /= gain # Determine the offset, and possibly fix the centering for even-sized images @@ -785,9 +831,9 @@ def drawImage( local_wcs, ) else: - if sensor is not None: + if sensor is not None or photon_ops: raise NotImplementedError( - "Sensor not yet implemented in drawImage for method != 'phot'." + "Sensor/photon_ops not yet implemented in drawImage for method != 'phot'." ) if prof.is_analytic_x: @@ -872,7 +918,7 @@ def drawFFT_makeKImage(self, image): jnp.array( [ jnp.max(jnp.abs(jnp.array(image.bounds._getinitargs()))) * 2, - jnp.max(jnp.array(image.bounds.numpyShape())), + jnp.max(jnp.array(image.bounds.numpyShape(_static=True))), ] ) ) @@ -938,7 +984,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image): ) kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,)) real_image_arr = jnp.fft.fftshift( - jnp.fft.irfft2(kimg_shift, breal.numpyShape()) + jnp.fft.irfft2(kimg_shift, breal.numpyShape(_static=True)) ) real_image = Image( bounds=breal, array=real_image_arr, dtype=image.dtype, wcs=image.wcs @@ -1023,7 +1069,7 @@ def drawKImage( dk = self.stepk else: dk = scale - if image is not None and image.bounds.isDefined(): + if image is not None and image.bounds.isDefined(_static=True): dx = np.pi / (max(image.array.shape) // 2 * dk) elif scale is None or scale <= 0: dx = self.nyquist_scale @@ -1035,7 +1081,7 @@ def drawKImage( # If the profile needs to be constructed from scratch, the _setup_image function will # do that, but only if the profile is in image coordinates for the real space image. # So make that profile. - if image is None or not image.bounds.isDefined(): + if image is None or not image.bounds.isDefined(_static=True): real_prof = PixelScale(dx).profileToImage(self) dtype = np.complex128 if image is None else image.dtype image = real_prof._setup_image( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 026e7778..a7a265d2 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -9,6 +9,7 @@ from jax_galsim.position import PositionI from jax_galsim.utilities import parse_pos_args from jax_galsim.wcs import BaseWCS, PixelScale +from jax_galsim.errors import GalSimImmutableError @_wraps( @@ -54,6 +55,9 @@ class Image(object): valid_dtypes = _valid_dtypes def __init__(self, *args, **kwargs): + # this one is pecific to jax-galsim and is used to disable bounds checking + _check_bounds = kwargs.pop("_check_bounds", True) + # Parse the args, kwargs ncol = None nrow = None @@ -70,15 +74,21 @@ def __init__(self, *args, **kwargs): elif len(args) == 1: if isinstance(args[0], np.ndarray): array = jnp.array(args[0]) - array, xmin, ymin = self._get_xmin_ymin(array, kwargs) + array, xmin, ymin = self._get_xmin_ymin( + array, kwargs, check_bounds=_check_bounds + ) elif isinstance(args[0], jnp.ndarray): array = args[0] - array, xmin, ymin = self._get_xmin_ymin(array, kwargs) + array, xmin, ymin = self._get_xmin_ymin( + array, kwargs, check_bounds=_check_bounds + ) elif isinstance(args[0], BoundsI): bounds = args[0] elif isinstance(args[0], (list, tuple)): array = jnp.array(args[0]) - array, xmin, ymin = self._get_xmin_ymin(array, kwargs) + array, xmin, ymin = self._get_xmin_ymin( + array, kwargs, check_bounds=_check_bounds + ) elif isinstance(args[0], Image): image = args[0] else: @@ -88,9 +98,8 @@ def __init__(self, *args, **kwargs): else: if "array" in kwargs: array = kwargs.pop("array") - check_bounds = kwargs.pop("check_bounds", True) array, xmin, ymin = self._get_xmin_ymin( - array, kwargs, check_bounds=check_bounds + array, kwargs, check_bounds=_check_bounds ) elif "bounds" in kwargs: bounds = kwargs.pop("bounds") @@ -107,6 +116,7 @@ def __init__(self, *args, **kwargs): init_value = kwargs.pop("init_value", None) scale = kwargs.pop("scale", None) wcs = kwargs.pop("wcs", None) + self._is_const = kwargs.pop("make_const", False) # Check that we got them all if kwargs: @@ -118,11 +128,6 @@ def __init__(self, *args, **kwargs): # remove it since we used it kwargs.pop("copy", None) - if "make_const" in kwargs.keys(): - raise TypeError( - "'make_const' is not a valid keyword argument for the JAX-GalSim version of the Image constructor" - ) - if kwargs: raise TypeError( "Image constructor got unexpected keyword arguments: %s", kwargs @@ -183,7 +188,9 @@ def __init__(self, *args, **kwargs): elif bounds is not None: if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype) + self._array = self._make_empty( + bounds.numpyShape(_static=True), dtype=self._dtype + ) self._bounds = bounds if init_value: self._array = self._array + init_value @@ -218,7 +225,6 @@ def __init__(self, *args, **kwargs): self._dtype = dtype self._array = image.array.astype(self._dtype) else: - # TODO: remove this possiblity of creating an empty image. self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype) self._bounds = BoundsI() if init_value is not None: @@ -255,7 +261,9 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if check_bounds: + # we use the static bounds check here since we cannot raise errors in jitted + # code anyways + if check_bounds and b.isDefined(_static=True): # We need to disable this when jitting if b.xmax - b.xmin + 1 != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -269,7 +277,11 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): array=array, bounds=b, ) - if b.isDefined(): + + # this statement in JAX would change array sizes and so is not supported + # so instead we check the static property set on construction for whether + # the bounds are defined + if b.isDefined(_static=True): xmin = b.xmin ymin = b.ymin else: @@ -290,9 +302,11 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined(): + if self.bounds.isDefined(_static=True): s += ", array=\n%r" % np.array(self.array) s += ", wcs=%r" % self.wcs + if self.isconst: + s += ", make_const=True" s += ")" return s @@ -334,7 +348,7 @@ def array(self): @property def isconst(self): """Whether the `Image` is constant. I.e. modifying its values is an error.""" - return True + return self._is_const @property def iscomplex(self): @@ -468,7 +482,11 @@ def get_pixel_centers(self): def _make_empty(self, shape, dtype): """Helper function to make an empty numpy array of the given shape.""" - return jnp.zeros(shape=shape, dtype=dtype) + if np.prod(shape) == 0: + # galsim forces degenrate images to have at least 1 pixel + return jnp.zeros(shape=(1, 1), dtype=dtype) + else: + return jnp.zeros(shape=shape, dtype=dtype) def resize(self, bounds, wcs=None): """Resize the image to have a new bounds (must be a `BoundsI` instance) @@ -483,9 +501,13 @@ def resize(self, bounds, wcs=None): wcs: If provided, also update the wcs to the given value. [default: None, which means keep the existing wcs] """ + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) + self._array = self._make_empty( + shape=bounds.numpyShape(_static=True), dtype=self.dtype + ) self._bounds = bounds if wcs is not None: self.wcs = wcs @@ -497,11 +519,11 @@ def subImage(self, bounds): """ if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(): + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) - if not self.bounds.includes(bounds): + if not self.bounds.includes(bounds, _static=True): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) @@ -521,19 +543,21 @@ def setSubImage(self, bounds, rhs): This is equivalent to self[bounds] = rhs """ + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(): + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(bounds): + if not self.bounds.includes(bounds, _static=True): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") - if bounds.numpyShape() != rhs.bounds.numpyShape(): + if bounds.numpyShape(_static=True) != rhs.bounds.numpyShape(_static=True): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, @@ -685,7 +709,7 @@ def _wrap(self, bounds, hermx, hermy): lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.", ) def calculate_fft(self): - if not self.bounds.isDefined(): + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) @@ -737,7 +761,7 @@ def calculate_fft(self): @_wraps(_galsim.Image.calculate_inverse_fft) def calculate_inverse_fft(self): - if not self.bounds.isDefined(): + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) @@ -749,7 +773,7 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if not self.bounds.includes(0, 0): + if not self.bounds.includes(0, 0, _static=True): raise _galsim.GalSimBoundsError( "calculate_inverse_fft requires that the image includes (0,0)", PositionI(0, 0), @@ -816,9 +840,11 @@ def good_fft_size(cls, input_size): def copyFrom(self, rhs): """Copy the contents of another image""" + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") - if self.bounds.numpyShape() != rhs.bounds.numpyShape(): + if self.bounds.numpyShape(_static=True) != rhs.bounds.numpyShape(_static=True): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, @@ -828,7 +854,12 @@ def copyFrom(self, rhs): @_wraps( _galsim.Image.view, - lax_description="Contrary to GalSim, this will create a copy of the orginal image.", + lax_description="""\ +Contrary to GalSim, view + + - will create a copy of the orginal image + - will not check for undefined bounds +""", ) def view( self, @@ -861,7 +892,9 @@ def view( dtype = dtype if dtype else self.dtype # If currently empty, just return a new empty image. - if not self.bounds.isDefined(): + # we use the static bounds check set at construction + # since the dynamic one in JAX would change array shape + if not self.bounds.isDefined(_static=True): return Image(wcs=wcs, dtype=dtype) # Recast the array type if necessary @@ -935,11 +968,11 @@ def __call__(self, *args, **kwargs): @_wraps(_galsim.Image.getValue) def getValue(self, x, y): - if not self.bounds.isDefined(): + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(x, y): + if not self.bounds.includes(x, y, _static=True): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), @@ -955,14 +988,16 @@ def _getValue(self, x, y): @_wraps(_galsim.Image.setValue) def setValue(self, *args, **kwargs): - if not self.bounds.isDefined(): + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if not self.bounds.includes(pos, _static=True): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) @@ -981,14 +1016,16 @@ def _setValue(self, x, y, value): @_wraps(_galsim.Image.addValue) def addValue(self, *args, **kwargs): - if not self.bounds.isDefined(): + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if not self.bounds.includes(pos, _static=True): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) @@ -1011,7 +1048,9 @@ def fill(self, value): Parameter: value: The value to set all the pixels to. """ - if not self.bounds.isDefined(): + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) @@ -1023,6 +1062,8 @@ def _fill(self, value): def setZero(self): """Set all pixel values to zero.""" + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) self._fill(0) def invertSelf(self): @@ -1031,7 +1072,9 @@ def invertSelf(self): Note: any pixels whose value is 0 originally are ignored. They remain equal to 0 on the output, rather than turning into inf. """ - if not self.bounds.isDefined(): + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) + if not self.bounds.isDefined(_static=True): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) @@ -1052,6 +1095,8 @@ def replaceNegative(self, replace_value=0): Parameters: replace_value: The value with which to replace any negative pixels. [default: 0] """ + if self.isconst: + raise GalSimImmutableError("Cannot modify an immutable Image", self) self._array = self.array.at[self.array < 0].set(replace_value) def __eq__(self, other): @@ -1068,7 +1113,8 @@ def __eq__(self, other): and self.bounds == other.bounds and self.wcs == other.wcs and ( - not self.bounds.isDefined() or jnp.array_equal(self.array, other.array) + not self.bounds.isDefined(_static=True) + or jnp.array_equal(self.array, other.array) ) and self.isconst == other.isconst ) @@ -1076,11 +1122,43 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @_wraps(_galsim.Image.transpose) + def transpose(self): + bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + return _Image(self.array.T, bT, None) + + @_wraps(_galsim.Image.flip_lr) + def flip_lr(self): + return _Image(self.array.at[:, ::-1].get(), self._bounds, None) + + @_wraps(_galsim.Image.flip_ud) + def flip_ud(self): + return _Image(self.array.at[::-1, :].get(), self._bounds, None) + + @_wraps(_galsim.Image.rot_cw) + def rot_cw(self): + bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + return _Image(self.array.T.at[::-1, :].get(), bT, None) + + @_wraps(_galsim.Image.rot_ccw) + def rot_ccw(self): + bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + return _Image(self.array.T.at[:, ::-1].get(), bT, None) + + @_wraps(_galsim.Image.rot_180) + def rot_180(self): + """Return a version of the image rotated 180 degrees. + + Note: The returned image will have an undefined wcs. + If you care about the wcs, you will need to set it yourself. + """ + return _Image(self.array.at[::-1, ::-1].get(), self._bounds, None) + def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing children = (self.array, self.wcs) - aux_data = {"dtype": self.dtype, "bounds": self.bounds} + aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} return (children, aux_data) @classmethod @@ -1091,6 +1169,7 @@ def tree_unflatten(cls, aux_data, children): obj.wcs = children[1] obj._bounds = aux_data["bounds"] obj._dtype = aux_data["dtype"] + obj._is_const = aux_data["isconst"] return obj @classmethod @@ -1106,6 +1185,19 @@ def from_galsim(cls, galsim_image): return im +@_wraps(_galsim._Image) +def _Image(array, bounds, wcs): + ret = Image.__new__(Image) + ret.wcs = wcs + ret._dtype = array.dtype.type + if ret._dtype in Image._alias_dtypes: + ret._dtype = Image._alias_dtypes[ret._dtype] + array = array.astype(ret._dtype) + ret._array = array + ret._bounds = bounds + return ret + + # These are essentially aliases for the regular Image with the correct dtype def ImageUS(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.uint16)""" diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index e4efdbbd..3be49772 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -61,6 +61,7 @@ def __dir__(cls): - the pad_image options - depixelize - most of the type checks and dtype casts done by galsim + - the image bounds are defined """ ), ) @@ -426,7 +427,7 @@ def __init__( ) # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor - if not image.bounds.isDefined(): + if not image.bounds.isDefined(_static=True): raise GalSimUndefinedBoundsError( "Supplied image does not have bounds defined." ) @@ -594,7 +595,7 @@ def _image(self): # Store the image as an attribute and make sure we don't change the original image # in anything we do here. (e.g. set scale, etc.) if self._jax_aux_data["depixelize"]: - # FIXME: no depixelize in jax_galsim + # TODO: no depixelize in jax_galsim # self._image = image.view(dtype=np.float64).depixelize(self._x_interpolant) raise NotImplementedError( "InterpolatedImages do not support 'depixelize' in jax_galsim." @@ -801,7 +802,7 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): im = (im * flux_scaling).astype(image.dtype) # Return an image - return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) def _drawKImage(self, image, jac=None): jacobian = jnp.eye(2) if jac is None else jac @@ -826,7 +827,7 @@ def _drawKImage(self, image, jac=None): im = (im).astype(image.dtype) # Return an image - return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) + return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) @lazy_property def _pos_neg_fluxes(self): diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index da72dfbb..4024f337 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -640,21 +640,12 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + @_wraps( + _galsim.PhotonArray.addTo, + lax_description="The JAX equivalent of galsim.PhotonArray.addTo may not raise for undefined bounds.", + ) def addTo(self, image): - """Add flux of photons to an image by binning into pixels. - - Photons in this `PhotonArray` are binned into the pixels of the input - `Image` and their flux summed into the pixels. The `Image` is assumed to represent - surface brightness, so photons' fluxes are divided by image pixel area. - Photons past the edges of the image are discarded. - - Parameters: - image: The `Image` to which the photons' flux will be added. - - Returns: - the total flux of photons the landed inside the image bounds. - """ - if not image.bounds.isDefined(): + if not image.bounds.isDefined(_static=True): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index f89d9c4f..01b264f5 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -181,8 +181,8 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # Force conversion to float type in this case - self.x = cast_to_float(self.x) - self.y = cast_to_float(self.y) + self.x = 1.0 * cast_to_float(self.x) + self.y = 1.0 * cast_to_float(self.y) def _check_scalar(self, other, op): try: @@ -206,8 +206,8 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # inputs must be ints - self.x = cast_to_int(self.x) - self.y = cast_to_int(self.y) + self.x = 1 * cast_to_int(self.x) + self.y = 1 * cast_to_int(self.y) def _check_scalar(self, other, op): try: diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 19f3693e..47d5a940 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -2,8 +2,8 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from .errors import GalSimUndefinedBoundsError -from .position import PositionI +from jax_galsim.position import PositionI +from jax_galsim.errors import GalSimUndefinedBoundsError @_wraps(_galsim.Sensor) @@ -12,9 +12,12 @@ class Sensor: def __init__(self): pass - @_wraps(_galsim.Sensor.accumulate) + @_wraps( + _galsim.Sensor.accumulate, + lax_description="The JAX equivalent of galsim.Sensor.accumulate does not raise for undefined bounds.", + ) def accumulate(self, photons, image, orig_center=None, resume=False): - if not image.bounds.isDefined(): + if not image.bounds.isDefined(_static=True): raise GalSimUndefinedBoundsError( "Calling accumulate on image with undefined bounds" ) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 180fd7ee..ecd962bd 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -4,22 +4,14 @@ import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps -from jax.tree_util import tree_flatten from jax_galsim.errors import GalSimIncompatibleValuesError, GalSimValueError from jax_galsim.position import PositionD, PositionI +from jax_galsim.core.utils import has_tracers printoptions = _galsim.utilities.printoptions -def has_tracers(x): - """Return True if the data is equal, False otherwise. Handles jax.Array types.""" - for item in tree_flatten(x)[0]: - if isinstance(item, jax.core.Tracer): - return True - return False - - @_wraps( _galsim.utilities.lazy_property, lax_description=( diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index afcc2079..fc01bfea 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -837,6 +837,8 @@ class PixelScale(LocalWCS): _isPixelScale = True def __init__(self, scale): + if isinstance(scale, BaseWCS): + raise TypeError("Cannot initialize PixelScale from a BaseWCS") self._params = {"scale": scale} self._color = None @@ -1630,8 +1632,7 @@ def readFromFitsHeader(header, suppress_warning=True): """ from . import fits - # FIXME: Enable FitsWCS - # from .fitswcs import FitsWCS + from .fitswcs import FitsWCS if not isinstance(header, fits.FitsHeader): header = fits.FitsHeader(header) xmin = header.get("GS_XMIN", 1) @@ -1644,9 +1645,8 @@ def readFromFitsHeader(header, suppress_warning=True): wcs_type = eval("jax_galsim." + wcs_name, gdict) wcs = wcs_type._readHeader(header) else: - raise NotImplementedError("FitsWCS is not implemented for jax_galsim.") # If we aren't told which type to use, this should find something appropriate - # wcs = FitsWCS(header=header, suppress_warning=suppress_warning) + wcs = FitsWCS(header=header, suppress_warning=suppress_warning) if xmin != 1 or ymin != 1: # ds9 always assumes the image has an origin at (1,1), so convert back to actual diff --git a/tests/Coord b/tests/Coord index 1d40ee30..d70a77fa 160000 --- a/tests/Coord +++ b/tests/Coord @@ -1 +1 @@ -Subproject commit 1d40ee30c1a49131f9e93cbf23869bc2f9adedb5 +Subproject commit d70a77fa33eb6d490278fd3f160062cdeb7e7a47 diff --git a/tests/GalSim b/tests/GalSim index 9e8d6565..8a3440d7 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 9e8d6565e88260586911339d1b3d8f32a7a8e1ba +Subproject commit 8a3440d72d739763514a2620e61c6e50668648b9 diff --git a/tests/conftest.py b/tests/conftest.py index 17175c1b..50a3f527 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,14 +45,25 @@ def pytest_ignore_collect(collection_path, path, config): These somtimes fail to import and cause pytest to fail. """ if "tests/GalSim/tests" in str(collection_path): - if not any( - [t in str(collection_path) for t in test_config["enabled_tests"]["galsim"]] - ): + if ( + not any( + [ + t in str(collection_path) + for t in test_config["enabled_tests"]["galsim"] + ] + ) + ) and "*" not in test_config["enabled_tests"]["galsim"]: return True + if "tests/Coord/tests" in str(collection_path): - if not any( - [t in str(collection_path) for t in test_config["enabled_tests"]["coord"]] - ): + if ( + not any( + [ + t in str(collection_path) + for t in test_config["enabled_tests"]["coord"] + ] + ) + ) and "*" not in test_config["enabled_tests"]["coord"]: return True @@ -70,9 +81,20 @@ def pytest_collection_modifyitems(config, items): # if this is a galsim test we check if it is requested or not if ( - not any([t in item.nodeid for t in test_config["enabled_tests"]["galsim"]]) - ) and ( - not any([t in item.nodeid for t in test_config["enabled_tests"]["coord"]]) + ( + ( + not any( + [t in item.nodeid for t in test_config["enabled_tests"]["galsim"]] + ) + ) + and "*" not in test_config["enabled_tests"]["galsim"] + ) and ( + ( + not any( + [t in item.nodeid for t in test_config["enabled_tests"]["coord"]] + ) + ) and "*" not in test_config["enabled_tests"]["coord"] + ) ): item.add_marker(skip) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 88636ff2..9860c962 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -15,6 +15,12 @@ enabled_tests: - test_box.py - test_interpolatedimage.py - test_deltafunction.py + - test_draw.py + - test_random.py + - test_noise.py + - test_image.py + - test_phpton_array.py + - "*" coord: - test_angle.py - test_angleunit.py @@ -48,8 +54,32 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'rand_with_replacement'" - "module 'jax_galsim.utilities' has no attribute 'dol_to_lod'" - "module 'jax_galsim.utilities' has no attribute 'nCr'" + - "module 'jax_galsim' has no attribute 'LookupTable'" + - "module 'jax_galsim.bessel' has no attribute 'j0'" + - "module 'jax_galsim.bessel' has no attribute 'j1'" + - "module 'jax_galsim.bessel' has no attribute 'jn'" + - "module 'jax_galsim.bessel' has no attribute 'jv'" + - "module 'jax_galsim.bessel' has no attribute 'yn'" + - "module 'jax_galsim.bessel' has no attribute 'yv'" + - "module 'jax_galsim.bessel' has no attribute 'iv'" + - "module 'jax_galsim.bessel' has no attribute 'kn'" + - "module 'jax_galsim.bessel' has no attribute 'kv'" + - "module 'jax_galsim.bessel' has no attribute 'j0_root'" + - "module 'jax_galsim.bessel' has no attribute 'gammainc'" + - "module 'jax_galsim.bessel' has no attribute 'sinc'" + - "module 'jax_galsim.bessel' has no attribute 'ci'" + - "object has no attribute 'calculateHLR'" + - "object has no attribute 'calculateMomentRadius'" + - "object has no attribute 'calculateFWHM'" + - "module 'jax_galsim' has no attribute 'Catalog'" + - "module 'jax_galsim' has no attribute 'Dict'" + - "module 'jax_galsim' has no attribute 'OutputCatalog'" + - "module 'jax_galsim' has no attribute 'cdmodel'" + - "module 'jax_galsim' has no attribute 'ChromaticObject'" + - "module 'jax_galsim' has no attribute 'ChromaticAiry'" + - "module 'jax_galsim' has no attribute 'config'" + - "module 'jax_galsim' has no attribute 'RealGalaxyCatalog'" - "'Image' object has no attribute 'bin'" - - "module 'jax_galsim' has no attribute 'integ'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "'Image' object has no attribute 'FindAdaptiveMom'" @@ -58,10 +88,8 @@ allowed_failures: - "ValueError not raised by greatCirclePoint" - "TypeError not raised by __mul__" - "ValueError not raised by CelestialCoord" - - "'Image' object has no attribute 'FindAdaptiveMom'" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" - - " module 'jax_galsim' has no attribute 'fft'" - - "'Image' object has no attribute 'addNoise'" + - "module 'jax_galsim' has no attribute 'fft'" - "Transform does not support callable arguments." - "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." - "pad_image not implemented in jax_galsim." @@ -74,3 +102,43 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'TimeSampler'" - "object has no attribute 'noise'" - "module 'jax_galsim' has no attribute 'SED'" + - "module 'jax_galsim' has no attribute 'getCOSMOSNoise'" + - "GSParams.__init__() got an unexpected keyword argument 'allowed_flux_variation'" + - "module 'jax_galsim' has no attribute 'Atmosphere'" + - "module 'jax_galsim' has no attribute 'RandomWalk'" + - "module 'jax_galsim' has no attribute 'hsm'" + - "module 'jax_galsim' has no attribute 'des'" + - "'Image' object has no attribute 'applyNonlinearity'" + - "'Image' object has no attribute 'addReciprocityFailure'" + - "'Image' object has no attribute 'quantize'" + - "'Image' object has no attribute 'applyIPC'" + - "'Image' object has no attribute 'applyPersistence'" + - "module 'jax_galsim' has no attribute 'download_cosmos'" + - "module 'jax_galsim' has no attribute 'FourierSqrt'" + - "module 'jax_galsim' has no attribute 'COSMOSCatalog'" + - "module 'jax_galsim' has no attribute 'include_dir'" + - "module 'jax_galsim' has no attribute 'VonKarman'" + - "module 'jax_galsim' has no attribute 'InclinedExponential'" + - "module 'jax_galsim' has no attribute 'RandomKnots'" + - "module 'jax_galsim' has no attribute 'NFWHalo'" + - "module 'jax_galsim' has no attribute 'Cosmology'" + - "module 'jax_galsim' has no attribute 'PowerSpectrum'" + - "module 'jax_galsim' has no attribute 'lensing_ps'" + - "module 'jax_galsim' has no attribute 'main'" + - "module 'jax_galsim' has no attribute 'OpticalPSF'" + - "module 'jax_galsim' has no attribute 'Aperture'" + - "module 'jax_galsim' has no attribute 'AtmosphericScreen'" + - "module 'jax_galsim' has no attribute 'OpticalScreen'" + - "module 'jax_galsim' has no attribute 'ChromaticConvolution'" + - "module 'jax_galsim' has no attribute 'phase_screens'" + - "module 'jax_galsim' has no attribute 'DistDeviate'" + - "module 'jax_galsim' has no attribute 'roman'" + - "module 'jax_galsim' has no attribute 'meta_data'" + - "module 'jax_galsim' has no attribute 'SecondKick'" + - "module 'jax_galsim.utilities' has no attribute 'combine_wave_list'" + - "Sensor/photon_ops not yet implemented in drawImage for method != 'phot'" + - "module 'jax_galsim' has no attribute 'SiliconSensor'" + - "module 'jax_galsim' has no attribute 'set_omp_threads'" + - "module 'jax_galsim' has no attribute 'Spergel'" + - "module 'jax_galsim' has no attribute 'LookupTable2D'" + - "module 'jax_galsim' has no attribute 'zernike'" diff --git a/tests/jax/galsim/test_draw_jax.py b/tests/jax/galsim/test_draw_jax.py deleted file mode 100644 index fe4656e7..00000000 --- a/tests/jax/galsim/test_draw_jax.py +++ /dev/null @@ -1,1720 +0,0 @@ -# Copyright (c) 2012-2023 by the GalSim developers team on GitHub -# https://github.com/GalSim-developers -# -# This file is part of GalSim: The modular galaxy image simulation toolkit. -# https://github.com/GalSim-developers/GalSim -# -# GalSim is free software: redistribution and use in source and binary forms, -# with or without modification, are permitted provided that the following -# conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions, and the disclaimer given in the accompanying LICENSE -# file. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the disclaimer given in the documentation -# and/or other materials provided with the distribution. -# - -import numpy as np -import os -import sys - -import galsim -from galsim_test_helpers import * - - -# for flux normalization tests -test_flux = 1.8 - -# A helper function used by both test_draw and test_drawk to check that the drawn image -# is a radially symmetric exponential with the right scale. -def CalculateScale(im): - # We just determine the scale radius of the drawn exponential by calculating - # the second moments of the image. - # int r^2 exp(-r/s) 2pir dr = 12 s^4 pi - # int exp(-r/s) 2pir dr = 2 s^2 pi - x, y = np.meshgrid(np.arange(np.shape(im.array)[0]), np.arange(np.shape(im.array)[1])) - if np.iscomplexobj(im.array): - T = complex - else: - T = float - flux = im.array.astype(T).sum() - mx = (x * im.array.astype(T)).sum() / flux - my = (y * im.array.astype(T)).sum() / flux - mxx = (((x-mx)**2) * im.array.astype(T)).sum() / flux - myy = (((y-my)**2) * im.array.astype(T)).sum() / flux - mxy = ((x-mx) * (y-my) * im.array.astype(T)).sum() / flux - s2 = mxx+myy - print(flux,mx,my,mxx,myy,mxy) - np.testing.assert_almost_equal((mxx-myy)/s2, 0, 5, "Found e1 != 0 for Exponential draw") - # NOTE: decreased precision from 5 to 3. Not sure why this is needed. - np.testing.assert_almost_equal(2*mxy/s2, 0, 3, "Found e2 != 0 for Exponential draw") - return np.sqrt(s2/6) * im.scale - - -@timer -def test_drawImage(): - """Test the various optional parameters to the drawImage function. - In particular test the parameters image and dx in various combinations. - """ - # We use a simple Exponential for our object: - obj = galsim.Exponential(flux=test_flux, scale_radius=2) - - # First test drawImage() with method='no_pixel'. It should: - # - create a new image - # - return the new image - # - set the scale to obj.nyquist_scale - # - set the size large enough to contain 99.5% of the flux - im1 = obj.drawImage(method='no_pixel') - nyq_scale = obj.nyquist_scale - np.testing.assert_almost_equal(im1.scale, nyq_scale, 9, - "obj.drawImage() produced image with wrong scale") - np.testing.assert_equal(im1.bounds, galsim.BoundsI(1,56,1,56), - "obj.drawImage() produced image with wrong bounds") - np.testing.assert_almost_equal(CalculateScale(im1), 2, 1, - "Measured wrong scale after obj.drawImage()") - - # The flux is only really expected to come out right if the object has been - # convoled with a pixel: - obj2 = galsim.Convolve([ obj, galsim.Pixel(im1.scale) ]) - im2 = obj2.drawImage(method='no_pixel') - nyq_scale = obj2.nyquist_scale - np.testing.assert_almost_equal(im2.scale, nyq_scale, 9, - "obj2.drawImage() produced image with wrong scale") - np.testing.assert_almost_equal(im2.array.astype(float).sum(), test_flux, 2, - "obj2.drawImage() produced image with wrong flux") - np.testing.assert_equal(im2.bounds, galsim.BoundsI(1,56,1,56), - "obj2.drawImage() produced image with wrong bounds") - np.testing.assert_almost_equal(CalculateScale(im2), 2, 1, - "Measured wrong scale after obj2.drawImage()") - # This should be the same as obj with method='auto' - im2 = obj.drawImage() - np.testing.assert_almost_equal(im2.scale, nyq_scale, 9, - "obj2.drawImage() produced image with wrong scale") - np.testing.assert_almost_equal(im2.array.astype(float).sum(), test_flux, 2, - "obj2.drawImage() produced image with wrong flux") - np.testing.assert_equal(im2.bounds, galsim.BoundsI(1,56,1,56), - "obj2.drawImage() produced image with wrong bounds") - np.testing.assert_almost_equal(CalculateScale(im2), 2, 1, - "Measured wrong scale after obj2.drawImage()") - - # Test if we provide an image argument. It should: - # - write to the existing image - # - also return that image - # - set the scale to obj2.nyquist_scale - # - zero out any existing data - im3 = galsim.ImageD(56,56) - im4 = obj.drawImage(im3) - np.testing.assert_almost_equal(im3.scale, nyq_scale, 9, - "obj.drawImage(im3) produced image with wrong scale") - np.testing.assert_almost_equal(im3.array.sum(), test_flux, 2, - "obj.drawImage(im3) produced image with wrong flux") - np.testing.assert_almost_equal(im3.array.sum(), im2.array.astype(float).sum(), 6, - "obj.drawImage(im3) produced image with different flux than im2") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawImage(im3)") - np.testing.assert_array_equal(im3.array, im4.array, - "im4 = obj.drawImage(im3) produced im4 != im3") - # JAX cannot fill images by reference so we check object identity - assert im3 is im4 - assert im3.array is im4.array - # im3.fill(9.8) - # np.testing.assert_array_equal(im3.array, im4.array, - # "im4 = obj.drawImage(im3) produced im4 is not im3") - im4 = obj.drawImage(im3) - np.testing.assert_almost_equal(im3.array.sum(), im2.array.astype(float).sum(), 6, - "obj.drawImage(im3) doesn't zero out existing data") - - # Test if we provide an image with undefined bounds. It should: - # - resize the provided image - # - also return that image - # - set the scale to obj2.nyquist_scale - im5 = galsim.ImageD() - obj.drawImage(im5) - np.testing.assert_almost_equal(im5.scale, nyq_scale, 9, - "obj.drawImage(im5) produced image with wrong scale") - np.testing.assert_almost_equal(im5.array.sum(), test_flux, 2, - "obj.drawImage(im5) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im5), 2, 1, - "Measured wrong scale after obj.drawImage(im5)") - np.testing.assert_almost_equal( - im5.array.sum(), im2.array.astype(float).sum(), 6, - "obj.drawImage(im5) produced image with different flux than im2") - np.testing.assert_equal(im5.bounds, galsim.BoundsI(1,56,1,56), - "obj.drawImage(im5) produced image with wrong bounds") - - # Test if we provide a dx to use. It should: - # - create a new image using that dx for the scale - # - return the new image - # - set the size large enough to contain 99.5% of the flux - scale = 0.51 # Just something different from 1 or dx_nyq - im7 = obj.drawImage(scale=scale,method='no_pixel') - np.testing.assert_almost_equal(im7.scale, scale, 9, - "obj.drawImage(dx) produced image with wrong scale") - np.testing.assert_almost_equal(im7.array.astype(float).sum(), test_flux, 2, - "obj.drawImage(dx) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im7), 2, 1, - "Measured wrong scale after obj.drawImage(dx)") - np.testing.assert_equal(im7.bounds, galsim.BoundsI(1,68,1,68), - "obj.drawImage(dx) produced image with wrong bounds") - - # If also providing center, then same size, but centered near that center. - for center in [(3,3), (210.2, 511.9), (10.55, -23.8), (0.5,0.5)]: - im8 = obj.drawImage(scale=scale, center=center) - np.testing.assert_almost_equal(im8.scale, scale, 9) - # Note: it doesn't have to come out 68,68. If the offset is zero from the integer center, - # it drops down to (66, 66) - if center == (3,3): - np.testing.assert_equal(im8.array.shape, (66, 66)) - else: - np.testing.assert_equal(im8.array.shape, (68, 68)) - np.testing.assert_almost_equal(im8.array.astype(float).sum(), test_flux, 2) - print('center, true = ',center,im8.true_center) - assert abs(center[0] - im8.true_center.x) <= 0.5 - assert abs(center[1] - im8.true_center.y) <= 0.5 - - # Test if we provide an image with a defined scale. It should: - # - write to the existing image - # - use the image's scale - nx = 200 # Some randome size - im9 = galsim.ImageD(nx,nx, scale=scale) - obj.drawImage(im9) - np.testing.assert_almost_equal(im9.scale, scale, 9, - "obj.drawImage(im9) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9)") - - # Test if we provide an image with a defined scale <= 0. It should: - # - write to the existing image - # - set the scale to obj2.nyquist_scale - im9.scale = -scale - im9.setZero() - obj.drawImage(im9) - np.testing.assert_almost_equal(im9.scale, nyq_scale, 9, - "obj.drawImage(im9) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9)") - im9.scale = 0 - im9.setZero() - obj.drawImage(im9) - np.testing.assert_almost_equal(im9.scale, nyq_scale, 9, - "obj.drawImage(im9) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9)") - - - # Test if we provide an image and dx. It should: - # - write to the existing image - # - use the provided dx - # - write the new dx value to the image's scale - im9.scale = 0.73 - im9.setZero() - obj.drawImage(im9, scale=scale) - np.testing.assert_almost_equal(im9.scale, scale, 9, - "obj.drawImage(im9,dx) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9,dx) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9,dx)") - - # Test if we provide an image and dx <= 0. It should: - # - write to the existing image - # - set the scale to obj2.nyquist_scale - im9.scale = 0.73 - im9.setZero() - obj.drawImage(im9, scale=-scale) - np.testing.assert_almost_equal(im9.scale, nyq_scale, 9, - "obj.drawImage(im9,dx<0) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9,dx<0) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9,dx<0)") - im9.scale = 0.73 - im9.setZero() - obj.drawImage(im9, scale=0) - np.testing.assert_almost_equal(im9.scale, nyq_scale, 9, - "obj.drawImage(im9,scale=0) produced image with wrong scale") - np.testing.assert_almost_equal(im9.array.sum(), test_flux, 4, - "obj.drawImage(im9,scale=0) produced image with wrong flux") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 2, - "Measured wrong scale after obj.drawImage(im9,scale=0)") - - # Test if we provide nx, ny, and scale. It should: - # - create a new image with the right size - # - set the scale - ny = 100 # Make it non-square - im10 = obj.drawImage(nx=nx, ny=ny, scale=scale) - np.testing.assert_equal(im10.array.shape, (ny, nx), - "obj.drawImage(nx,ny,scale) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, scale, 9, - "obj.drawImage(nx,ny,scale) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(nx,ny,scale) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal( - mom['Mx'], (nx+1.)/2., 4, "obj.drawImage(nx,ny,scale) (even) did not center in x correctly") - np.testing.assert_almost_equal( - mom['My'], (ny+1.)/2., 4, "obj.drawImage(nx,ny,scale) (even) did not center in y correctly") - - # Repeat with odd nx,ny - im10 = obj.drawImage(nx=nx+1, ny=ny+1, scale=scale) - np.testing.assert_equal(im10.array.shape, (ny+1, nx+1), - "obj.drawImage(nx,ny,scale) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, scale, 9, - "obj.drawImage(nx,ny,scale) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(nx,ny,scale) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal( - mom['Mx'], (nx+1.+1.)/2., 4, - "obj.drawImage(nx,ny,scale) (odd) did not center in x correctly") - np.testing.assert_almost_equal( - mom['My'], (ny+1.+1.)/2., 4, - "obj.drawImage(nx,ny,scale) (odd) did not center in y correctly") - - # Test if we provide nx, ny, and no scale. It should: - # - create a new image with the right size - # - set the scale to obj2.nyquist_scale - im10 = obj.drawImage(nx=nx, ny=ny) - np.testing.assert_equal(im10.array.shape, (ny, nx), - "obj.drawImage(nx,ny) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, nyq_scale, 9, - "obj.drawImage(nx,ny) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(nx,ny) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal( - mom['Mx'], (nx+1.)/2., 4, "obj.drawImage(nx,ny) (even) did not center in x correctly") - np.testing.assert_almost_equal( - mom['My'], (ny+1.)/2., 4, "obj.drawImage(nx,ny) (even) did not center in y correctly") - - # Repeat with odd nx,ny - im10 = obj.drawImage(nx=nx+1, ny=ny+1) - np.testing.assert_equal(im10.array.shape, (ny+1, nx+1), - "obj.drawImage(nx,ny) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, nyq_scale, 9, - "obj.drawImage(nx,ny) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(nx,ny) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal( - mom['Mx'], (nx+1.+1.)/2., 4, "obj.drawImage(nx,ny) (odd) did not center in x correctly") - np.testing.assert_almost_equal( - mom['My'], (ny+1.+1.)/2., 4, "obj.drawImage(nx,ny) (odd) did not center in y correctly") - - # Test if we provide bounds and scale. It should: - # - create a new image with the right size - # - set the scale - bounds = galsim.BoundsI(1,nx,1,ny+1) - im10 = obj.drawImage(bounds=bounds, scale=scale) - np.testing.assert_equal(im10.array.shape, (ny+1, nx), - "obj.drawImage(bounds,scale) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, scale, 9, - "obj.drawImage(bounds,scale) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(bounds,scale) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal(mom['Mx'], (nx+1.)/2., 4, - "obj.drawImage(bounds,scale) did not center in x correctly") - np.testing.assert_almost_equal(mom['My'], (ny+1.+1.)/2., 4, - "obj.drawImage(bounds,scale) did not center in y correctly") - - # Test if we provide bounds and no scale. It should: - # - create a new image with the right size - # - set the scale to obj2.nyquist_scale - bounds = galsim.BoundsI(1,nx,1,ny+1) - im10 = obj.drawImage(bounds=bounds) - np.testing.assert_equal(im10.array.shape, (ny+1, nx), - "obj.drawImage(bounds) produced image with wrong size") - np.testing.assert_almost_equal(im10.scale, nyq_scale, 9, - "obj.drawImage(bounds) produced image with wrong scale") - np.testing.assert_almost_equal(im10.array.sum(), test_flux, 4, - "obj.drawImage(bounds) produced image with wrong flux") - mom = galsim.utilities.unweighted_moments(im10) - np.testing.assert_almost_equal(mom['Mx'], (nx+1.)/2., 4, - "obj.drawImage(bounds) did not center in x correctly") - np.testing.assert_almost_equal(mom['My'], (ny+1.+1.)/2., 4, - "obj.drawImage(bounds) did not center in y correctly") - - # Test if we provide nx, ny, scale, and center. It should: - # - create a new image with the right size - # - set the scale - # - set the center to be as close as possible to center - for center in [(3,3), (10.2, 11.9), (10.55, -23.8)]: - im11 = obj.drawImage(nx=nx, ny=ny, scale=scale, center=center) - np.testing.assert_equal(im11.array.shape, (ny, nx)) - np.testing.assert_almost_equal(im11.scale, scale, 9) - np.testing.assert_almost_equal(im11.array.sum(), test_flux, 4) - print('center, true = ',center,im8.true_center) - assert abs(center[0] - im11.true_center.x) <= 0.5 - assert abs(center[1] - im11.true_center.y) <= 0.5 - - # Repeat with odd nx,ny - im11 = obj.drawImage(nx=nx+1, ny=ny+1, scale=scale, center=center) - np.testing.assert_equal(im11.array.shape, (ny+1, nx+1)) - np.testing.assert_almost_equal(im11.scale, scale, 9) - np.testing.assert_almost_equal(im11.array.sum(), test_flux, 4) - assert abs(center[0] - im11.true_center.x) <= 0.5 - assert abs(center[1] - im11.true_center.y) <= 0.5 - - dr = os.path.join( - os.path.dirname(__file__), - "..", - "..", - "GalSim", - "tests", - ) - - # Combinations that raise errors: - assert_raises(TypeError, obj.drawImage, image=im10, bounds=bounds) - assert_raises(TypeError, obj.drawImage, image=im10, dtype=int) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, image=im10, scale=scale) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, image=im10) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, bounds=bounds) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, add_to_image=True) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, center=True) - assert_raises(TypeError, obj.drawImage, nx=3, ny=4, center=23) - assert_raises(TypeError, obj.drawImage, bounds=bounds, add_to_image=True) - assert_raises(TypeError, obj.drawImage, image=galsim.Image(), add_to_image=True) - assert_raises(TypeError, obj.drawImage, nx=3) - assert_raises(TypeError, obj.drawImage, ny=3) - assert_raises(TypeError, obj.drawImage, nx=3, ny=3, invalid=True) - assert_raises(TypeError, obj.drawImage, bounds=bounds, scale=scale, wcs=galsim.PixelScale(3)) - assert_raises(TypeError, obj.drawImage, bounds=bounds, wcs=scale) - assert_raises(TypeError, obj.drawImage, image=im10.array) - assert_raises(TypeError, obj.drawImage, wcs=galsim.FitsWCS(dr + '/fits_files/tpv.fits')) - - assert_raises(ValueError, obj.drawImage, bounds=galsim.BoundsI()) - # JAX galsim does not raise for these things - # assert_raises(ValueError, obj.drawImage, image=im10, gain=0.) - # assert_raises(ValueError, obj.drawImage, image=im10, gain=-1.) - # assert_raises(ValueError, obj.drawImage, image=im10, area=0.) - # assert_raises(ValueError, obj.drawImage, image=im10, area=-1.) - # assert_raises(ValueError, obj.drawImage, image=im10, exptime=0.) - # assert_raises(ValueError, obj.drawImage, image=im10, exptime=-1.) - # assert_raises(ValueError, obj.drawImage, image=im10, method='invalid') - - # These options are invalid unless metho=phot - # JAX galsim does not raise for these things - # assert_raises(TypeError, obj.drawImage, image=im10, n_photons=3) - # assert_raises(TypeError, obj.drawImage, rng=galsim.BaseDeviate(234)) - # assert_raises(TypeError, obj.drawImage, max_extra_noise=23) - # assert_raises(TypeError, obj.drawImage, poisson_flux=True) - # assert_raises(TypeError, obj.drawImage, maxN=10000) - # assert_raises(TypeError, obj.drawImage, save_photons=True) - - -@timer -def test_draw_methods(): - """Test the the different method options do the right thing. - """ - # We use a simple Exponential for our object: - obj = galsim.Exponential(flux=test_flux, scale_radius=1.09) - test_scale = 0.28 - pix = galsim.Pixel(scale=test_scale) - obj_pix = galsim.Convolve(obj, pix) - - N = 64 - im1 = galsim.ImageD(N, N, scale=test_scale) - - # auto and fft should be equivalent to drawing obj_pix with no_pixel - im1 = obj.drawImage(image=im1) - im2 = obj_pix.drawImage(image=im1.copy(), method='no_pixel') - print('im1 flux diff = ',abs(im1.array.sum() - test_flux)) - np.testing.assert_almost_equal( - im1.array.sum(), test_flux, 2, - "obj.drawImage() produced image with wrong flux") - print('im2 flux diff = ',abs(im2.array.sum() - test_flux)) - np.testing.assert_almost_equal( - im2.array.sum(), test_flux, 2, - "obj_pix.drawImage(no_pixel) produced image with wrong flux") - print('im1, im2 max diff = ',abs(im1.array - im2.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im2.array, 6, - "obj.drawImage() differs from obj_pix.drawImage(no_pixel)") - im3 = obj.drawImage(image=im1.copy(), method='fft') - print('im1, im3 max diff = ',abs(im1.array - im3.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im3.array, 6, - "obj.drawImage(fft) differs from obj.drawImage") - - # real_space should be similar, but not precisely equal. - im4 = obj.drawImage(image=im1.copy(), method='real_space') - print('im1, im4 max diff = ',abs(im1.array - im4.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im4.array, 4, - "obj.drawImage(real_space) differs from obj.drawImage") - - # sb should match xValue for pixel centers. And be scale**2 factor different from no_pixel. - im5 = obj.drawImage(image=im1.copy(), method='sb', use_true_center=False) - im5.setCenter(0,0) - print('im5(0,0) = ',im5(0,0)) - print('obj.xValue(0,0) = ',obj.xValue(0.,0.)) - np.testing.assert_almost_equal( - im5(0,0), obj.xValue(0.,0.), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - np.testing.assert_almost_equal( - im5(3,2), obj.xValue(3*test_scale, 2*test_scale), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - im5 = obj.drawImage(image=im5, method='sb') - print('im5(0,0) = ',im5(0,0)) - print('obj.xValue(dx/2,dx/2) = ',obj.xValue(test_scale/2., test_scale/2.)) - np.testing.assert_almost_equal( - im5(0,0), obj.xValue(0.5*test_scale, 0.5*test_scale), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - np.testing.assert_almost_equal( - im5(3,2), obj.xValue(3.5*test_scale, 2.5*test_scale), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - im6 = obj.drawImage(image=im1.copy(), method='no_pixel') - print('im6, im5*scale**2 max diff = ',abs(im6.array - im5.array*test_scale**2).max()) - np.testing.assert_array_almost_equal( - im5.array * test_scale**2, im6.array, 6, - "obj.drawImage(sb) * scale**2 differs from obj.drawImage(no_pixel)") - - # Drawing a truncated object, auto should be identical to real_space - obj = galsim.Sersic(flux=test_flux, n=3.7, half_light_radius=2, trunc=4) - obj_pix = galsim.Convolve(obj, pix) - - # auto and real_space should be equivalent to drawing obj_pix with no_pixel - im1 = obj.drawImage(image=im1) - im2 = obj_pix.drawImage(image=im1.copy(), method='no_pixel') - print('im1 flux diff = ',abs(im1.array.sum() - test_flux)) - np.testing.assert_almost_equal( - im1.array.sum(), test_flux, 2, - "obj.drawImage() produced image with wrong flux") - print('im2 flux diff = ',abs(im2.array.sum() - test_flux)) - np.testing.assert_almost_equal( - im2.array.sum(), test_flux, 2, - "obj_pix.drawImage(no_pixel) produced image with wrong flux") - print('im1, im2 max diff = ',abs(im1.array - im2.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im2.array, 6, - "obj.drawImage() differs from obj_pix.drawImage(no_pixel)") - im4 = obj.drawImage(image=im1.copy(), method='real_space') - print('im1, im4 max diff = ',abs(im1.array - im4.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im4.array, 6, - "obj.drawImage(real_space) differs from obj.drawImage") - - # fft should be similar, but not precisely equal. - with assert_warns(galsim.GalSimWarning): - # This emits a warning about convolving two things with hard edges. - im3 = obj.drawImage(image=im1.copy(), method='fft') - print('im1, im3 max diff = ',abs(im1.array - im3.array).max()) - np.testing.assert_array_almost_equal( - im1.array, im3.array, 3, # Should be close, but not exact. - "obj.drawImage(fft) differs from obj.drawImage") - - # sb should match xValue for pixel centers. And be scale**2 factor different from no_pixel. - im5 = obj.drawImage(image=im1.copy(), method='sb') - im5.setCenter(0,0) - print('im5(0,0) = ',im5(0,0)) - print('obj.xValue(dx/2,dx/2) = ',obj.xValue(test_scale/2., test_scale/2.)) - np.testing.assert_almost_equal( - im5(0,0), obj.xValue(0.5*test_scale, 0.5*test_scale), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - np.testing.assert_almost_equal( - im5(3,2), obj.xValue(3.5*test_scale, 2.5*test_scale), 6, - "obj.drawImage(sb) values do not match surface brightness given by xValue") - im6 = obj.drawImage(image=im1.copy(), method='no_pixel') - print('im6, im5*scale**2 max diff = ',abs(im6.array - im5.array*test_scale**2).max()) - np.testing.assert_array_almost_equal( - im5.array * test_scale**2, im6.array, 6, - "obj.drawImage(sb) * scale**2 differs from obj.drawImage(no_pixel)") - - -@timer -def test_drawKImage(): - """Test the various optional parameters to the drawKImage function. - In particular test the parameters image, and scale in various combinations. - """ - # We use a Moffat profile with beta = 1.5, since its real-space profile is - # flux / (2 pi rD^2) * (1 + (r/rD)^2)^3/2 - # and the 2-d Fourier transform of that is - # flux * exp(-rD k) - # So this should draw in Fourier space the same image as the Exponential drawn in - # test_drawImage(). - obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5) - obj = obj.withGSParams(maxk_threshold=1.e-4) - - # First test drawKImage() with no kwargs. It should: - # - create new images - # - return the new images - # - set the scale to 2pi/(N*obj.nyquist_scale) - im1 = obj.drawKImage() - N = 1174 - np.testing.assert_equal(im1.bounds, galsim.BoundsI(-N/2,N/2,-N/2,N/2), - "obj.drawKImage() produced image with wrong bounds") - stepk = obj.stepk - np.testing.assert_almost_equal(im1.scale, stepk, 9, - "obj.drawKImage() produced image with wrong scale") - np.testing.assert_almost_equal(CalculateScale(im1), 2, 1, - "Measured wrong scale after obj.drawKImage()") - - # The flux in Fourier space is just the value at k=0 - np.testing.assert_equal(im1.bounds.center, galsim.PositionI(0,0)) - np.testing.assert_almost_equal(im1(0,0), test_flux, 2, - "obj.drawKImage() produced image with wrong flux") - # Imaginary component should all be 0. - np.testing.assert_almost_equal(im1.imag.array.sum(), 0., 3, - "obj.drawKImage() produced non-zero imaginary image") - - # Test if we provide an image argument. It should: - # - write to the existing image - # - also return that image - # - set the scale to obj.stepk - # - zero out any existing data - im3 = galsim.ImageCD(1149,1149) - im4 = obj.drawKImage(im3) - np.testing.assert_almost_equal(im3.scale, stepk, 9, - "obj.drawKImage(im3) produced image with wrong scale") - np.testing.assert_almost_equal(im3(0,0), test_flux, 2, - "obj.drawKImage(im3) produced real image with wrong flux") - np.testing.assert_almost_equal(im3.imag.array.sum(), 0., 3, - "obj.drawKImage(im3) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawKImage(im3)") - np.testing.assert_array_equal(im3.array, im4.array, - "im4 = obj.drawKImage(im3) produced im4 != im3") - # JAX cannot fill images by reference so we check object identity - assert im3 is im4 - assert im3.array is im4.array - # im3.fill(9.8) - # np.testing.assert_array_equal(im3.array, im4.array, - # "im4 = obj.drawKImage(im3) produced im4 is not im3") - - # Test if we provide an image with undefined bounds. It should: - # - resize the provided image - # - also return that image - # - set the scale to obj.stepk - im5 = galsim.ImageCD() - obj.drawKImage(im5) - np.testing.assert_almost_equal(im5.scale, stepk, 9, - "obj.drawKImage(im5) produced image with wrong scale") - np.testing.assert_almost_equal(im5(0,0), test_flux, 2, - "obj.drawKImage(im5) produced image with wrong flux") - np.testing.assert_almost_equal(im5.imag.array.sum(), 0., 3, - "obj.drawKImage(im5) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im5), 2, 1, - "Measured wrong scale after obj.drawKImage(im5)") - np.testing.assert_equal(im5.bounds, galsim.BoundsI(-N/2,N/2,-N/2,N/2), - "obj.drawKImage(im5) produced image with wrong bounds") - - # Test if we provide a scale to use. It should: - # - create a new image using that scale for the scale - # - return the new image - # - set the size large enough to contain 99.5% of the flux - scale = 0.51 # Just something different from 1 or stepk - im7 = obj.drawKImage(scale=scale) - np.testing.assert_almost_equal(im7.scale, scale, 9, - "obj.drawKImage(dx) produced image with wrong scale") - np.testing.assert_almost_equal(im7(0,0), test_flux, 2, - "obj.drawKImage(dx) produced image with wrong flux") - np.testing.assert_almost_equal(im7.imag.array.astype(float).sum(), 0., 2, - "obj.drawKImage(dx) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im7), 2, 1, - "Measured wrong scale after obj.drawKImage(dx)") - # This image is smaller because not using nyquist scale for stepk - np.testing.assert_equal(im7.bounds, galsim.BoundsI(-37,37,-37,37), - "obj.drawKImage(dx) produced image with wrong bounds") - - # Test if we provide an image with a defined scale. It should: - # - write to the existing image - # - use the image's scale - nx = 401 - im9 = galsim.ImageCD(nx,nx, scale=scale) - obj.drawKImage(im9) - np.testing.assert_almost_equal(im9.scale, scale, 9, - "obj.drawKImage(im9) produced image with wrong scale") - np.testing.assert_almost_equal(im9(0,0), test_flux, 4, - "obj.drawKImage(im9) produced image with wrong flux") - np.testing.assert_almost_equal(im9.imag.array.sum(), 0., 5, - "obj.drawKImage(im9) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 1, - "Measured wrong scale after obj.drawKImage(im9)") - - # Test if we provide an image with a defined scale <= 0. It should: - # - write to the existing image - # - set the scale to obj.stepk - im3.scale = -scale - im3.setZero() - obj.drawKImage(im3) - np.testing.assert_almost_equal(im3.scale, stepk, 9, - "obj.drawKImage(im3) produced image with wrong scale") - np.testing.assert_almost_equal(im3(0,0), test_flux, 4, - "obj.drawKImage(im3) produced image with wrong flux") - np.testing.assert_almost_equal(im3.imag.array.sum(), 0., 5, - "obj.drawKImage(im3) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawKImage(im3)") - im3.scale = 0 - im3.setZero() - obj.drawKImage(im3) - np.testing.assert_almost_equal(im3.scale, stepk, 9, - "obj.drawKImage(im3) produced image with wrong scale") - np.testing.assert_almost_equal(im3(0,0), test_flux, 4, - "obj.drawKImage(im3) produced image with wrong flux") - np.testing.assert_almost_equal(im3.imag.array.sum(), 0., 5, - "obj.drawKImage(im3) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawKImage(im3)") - - # Test if we provide an image and dx. It should: - # - write to the existing image - # - use the provided dx - # - write the new dx value to the image's scale - im9.scale = scale + 0.3 # Just something other than scale - im9.setZero() - obj.drawKImage(im9, scale=scale) - np.testing.assert_almost_equal( - im9.scale, scale, 9, - "obj.drawKImage(im9,scale) produced image with wrong scale") - np.testing.assert_almost_equal( - im9(0,0), test_flux, 4, - "obj.drawKImage(im9,scale) produced image with wrong flux") - np.testing.assert_almost_equal( - im9.imag.array.sum(), 0., 5, - "obj.drawKImage(im9,scale) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im9), 2, 1, - "Measured wrong scale after obj.drawKImage(im9,scale)") - - # Test if we provide an image and scale <= 0. It should: - # - write to the existing image - # - set the scale to obj.stepk - im3.scale = scale + 0.3 - im3.setZero() - obj.drawKImage(im3, scale=-scale) - np.testing.assert_almost_equal( - im3.scale, stepk, 9, - "obj.drawKImage(im3,scale<0) produced image with wrong scale") - np.testing.assert_almost_equal( - im3(0,0), test_flux, 4, - "obj.drawKImage(im3,scale<0) produced image with wrong flux") - np.testing.assert_almost_equal( - im3.imag.array.sum(), 0., 5, - "obj.drawKImage(im3,scale<0) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawKImage(im3,scale<0)") - im3.scale = scale + 0.3 - im3.setZero() - obj.drawKImage(im3, scale=0) - np.testing.assert_almost_equal( - im3.scale, stepk, 9, - "obj.drawKImage(im3,scale=0) produced image with wrong scale") - np.testing.assert_almost_equal( - im3(0,0), test_flux, 4, - "obj.drawKImage(im3,scale=0) produced image with wrong flux") - np.testing.assert_almost_equal( - im3.imag.array.sum(), 0., 5, - "obj.drawKImage(im3,scale=0) produced non-zero imaginary image") - np.testing.assert_almost_equal(CalculateScale(im3), 2, 1, - "Measured wrong scale after obj.drawKImage(im3,scale=0)") - - # Test if we provide nx, ny, and scale. It should: - # - create a new image with the right size - # - set the scale - nx = 200 # Some randome non-square size - ny = 100 - im4 = obj.drawKImage(nx=nx, ny=ny, scale=scale) - np.testing.assert_almost_equal( - im4.scale, scale, 9, - "obj.drawKImage(nx,ny,scale) produced image with wrong scale") - np.testing.assert_equal( - im4.array.shape, (ny, nx), - "obj.drawKImage(nx,ny,scale) produced image with wrong shape") - - # Test if we provide nx, ny, and no scale. It should: - # - create a new image with the right size - # - set the scale to obj.stepk - im4 = obj.drawKImage(nx=nx, ny=ny) - np.testing.assert_almost_equal( - im4.scale, stepk, 9, - "obj.drawKImage(nx,ny) produced image with wrong scale") - np.testing.assert_equal( - im4.array.shape, (ny, nx), - "obj.drawKImage(nx,ny) produced image with wrong shape") - - # Test if we provide bounds and no scale. It should: - # - create a new image with the right size - # - set the scale to obj.stepk - bounds = galsim.BoundsI(1,nx,1,ny) - im4 = obj.drawKImage(bounds=bounds) - np.testing.assert_almost_equal( - im4.scale, stepk, 9, - "obj.drawKImage(bounds) produced image with wrong scale") - np.testing.assert_equal( - im4.array.shape, (ny, nx), - "obj.drawKImage(bounds) produced image with wrong shape") - - # Test if we provide bounds and scale. It should: - # - create a new image with the right size - # - set the scale - bounds = galsim.BoundsI(1,nx,1,ny) - im4 = obj.drawKImage(bounds=bounds, scale=scale) - np.testing.assert_almost_equal( - im4.scale, scale, 9, - "obj.drawKImage(bounds,scale) produced image with wrong scale") - np.testing.assert_equal( - im4.array.shape, (ny, nx), - "obj.drawKImage(bounds,scale) produced image with wrong shape") - - # Test recenter = False option - bounds6 = galsim.BoundsI(0, nx//3, 0, ny//4) - im6 = obj.drawKImage(bounds=bounds6, scale=scale, recenter=False) - np.testing.assert_equal( - im6.bounds, bounds6, - "obj.drawKImage(bounds,scale,recenter=False) produced image with wrong bounds") - np.testing.assert_almost_equal( - im6.scale, scale, 9, - "obj.drawKImage(bounds,scale,recenter=False) produced image with wrong scale") - np.testing.assert_equal( - im6.array.shape, (ny//4+1, nx//3+1), - "obj.drawKImage(bounds,scale,recenter=False) produced image with wrong shape") - np.testing.assert_array_almost_equal( - im6.array, im4[bounds6].array, 9, - "obj.drawKImage(recenter=False) produced different values than recenter=True") - - # Test recenter = False option - im6.setZero() - obj.drawKImage(im6, recenter=False) - np.testing.assert_almost_equal( - im6.scale, scale, 9, - "obj.drawKImage(image,recenter=False) produced image with wrong scale") - np.testing.assert_array_almost_equal( - im6.array, im4[bounds6].array, 9, - "obj.drawKImage(image,recenter=False) produced different values than recenter=True") - - # Can add to image if recenter is False - im6.setZero() - obj.drawKImage(im6, recenter=False, add_to_image=True) - np.testing.assert_almost_equal( - im6.scale, scale, 9, - "obj.drawKImage(image,add_to_image=True) produced image with wrong scale") - np.testing.assert_array_almost_equal( - im6.array, im4[bounds6].array, 9, - "obj.drawKImage(image,add_to_image=True) produced different values than recenter=True") - - # .. or if image is centered. - im7 = im4.copy() - im7.setZero() - im7.setCenter(0,0) - obj.drawKImage(im7, add_to_image=True) - np.testing.assert_almost_equal( - im7.scale, scale, 9, - "obj.drawKImage(image,add_to_image=True) produced image with wrong scale") - np.testing.assert_array_almost_equal( - im7.array, im4.array, 9, - "obj.drawKImage(image,add_to_image=True) produced different values than recenter=True") - - # .. but otherwise not. - with assert_raises(galsim.GalSimIncompatibleValuesError): - obj.drawKImage(image=im6, add_to_image=True) - - # Other error combinations: - assert_raises(TypeError, obj.drawKImage, image=im6, bounds=bounds) - assert_raises(TypeError, obj.drawKImage, image=im6, dtype=int) - assert_raises(TypeError, obj.drawKImage, nx=3, ny=4, image=im6, scale=scale) - assert_raises(TypeError, obj.drawKImage, nx=3, ny=4, image=im6) - assert_raises(TypeError, obj.drawKImage, nx=3, ny=4, add_to_image=True) - assert_raises(TypeError, obj.drawKImage, nx=3, ny=4, bounds=bounds) - assert_raises(TypeError, obj.drawKImage, bounds=bounds, add_to_image=True) - assert_raises(TypeError, obj.drawKImage, image=galsim.Image(dtype=complex), add_to_image=True) - assert_raises(TypeError, obj.drawKImage, nx=3) - assert_raises(TypeError, obj.drawKImage, ny=3) - assert_raises(TypeError, obj.drawKImage, nx=3, ny=3, invalid=True) - assert_raises(TypeError, obj.drawKImage, bounds=bounds, wcs=galsim.PixelScale(3)) - assert_raises(TypeError, obj.drawKImage, image=im6.array) - assert_raises(ValueError, obj.drawKImage, image=galsim.ImageF(3,4)) - assert_raises(ValueError, obj.drawKImage, bounds=galsim.BoundsI()) - - -@timer -def test_drawKImage_Gaussian(): - """Test the drawKImage function using known symmetries of the Gaussian Hankel transform. - - See http://en.wikipedia.org/wiki/Hankel_transform. - """ - test_flux = 2.3 # Choose a non-unity flux - test_sigma = 17. # ...likewise for sigma - test_imsize = 45 # Dimensions of comparison image, doesn't need to be large - - # Define a Gaussian GSObject - gal = galsim.Gaussian(sigma=test_sigma, flux=test_flux) - # Then define a related object which is in fact the opposite number in the Hankel transform pair - # For the Gaussian this is straightforward in our definition of the Fourier transform notation, - # and has sigma -> 1/sigma and flux -> flux * 2 pi / sigma**2 - gal_hankel = galsim.Gaussian(sigma=1./test_sigma, flux=test_flux*2.*np.pi/test_sigma**2) - - # Do a basic flux test: the total flux of the gal should equal gal_Hankel(k=(0, 0)) - np.testing.assert_almost_equal( - gal.flux, gal_hankel.xValue(galsim.PositionD(0., 0.)), decimal=12, - err_msg="Test object flux does not equal k=(0, 0) mode of its Hankel transform conjugate.") - - image_test = galsim.ImageD(test_imsize, test_imsize) - kimage_test = galsim.ImageCD(test_imsize, test_imsize) - - # Then compare these two objects at a couple of different scale (reasonably matched for size) - for scale_test in (0.03 / test_sigma, 0.4 / test_sigma): - gal.drawKImage(image=kimage_test, scale=scale_test) - gal_hankel.drawImage(image_test, scale=scale_test, use_true_center=False, method='sb') - np.testing.assert_array_almost_equal( - kimage_test.real.array, image_test.array, decimal=12, - err_msg="Test object drawKImage() and drawImage() from Hankel conjugate do not match " - "for grid spacing scale = "+str(scale_test)) - np.testing.assert_array_almost_equal( - kimage_test.imag.array, 0., decimal=12, - err_msg="Non-zero imaginary part for drawKImage from test object that is purely " - "centred on the origin.") - - -@timer -def test_drawKImage_Exponential_Moffat(): - """Test the drawKImage function using known symmetries of the Exponential Hankel transform - (which is a Moffat with beta=1.5). - - See http://mathworld.wolfram.com/HankelTransform.html. - """ - test_flux = 4.1 # Choose a non-unity flux - test_scale_radius = 13. # ...likewise for scale_radius - test_imsize = 45 # Dimensions of comparison image, doesn't need to be large - - # Define an Exponential GSObject - gal = galsim.Exponential(scale_radius=test_scale_radius, flux=test_flux) - # Then define a related object which is in fact the opposite number in the Hankel transform pair - # For the Exponential we need a Moffat, with scale_radius=1/scale_radius. The total flux under - # this Moffat with unit amplitude at r=0 is is pi * scale_radius**(-2) / (beta - 1) - # = 2. * pi * scale_radius**(-2) in this case, so it works analagously to the Gaussian above. - gal_hankel = galsim.Moffat(beta=1.5, scale_radius=1. / test_scale_radius, - flux=test_flux * 2. * np.pi / test_scale_radius**2) - - # Do a basic flux test: the total flux of the gal should equal gal_Hankel(k=(0, 0)) - np.testing.assert_almost_equal( - gal.flux, gal_hankel.xValue(galsim.PositionD(0., 0.)), decimal=12, - err_msg="Test object flux does not equal k=(0, 0) mode of its Hankel transform conjugate.") - - image_test = galsim.ImageD(test_imsize, test_imsize) - kimage_test = galsim.ImageCD(test_imsize, test_imsize) - - # Then compare these two objects at a couple of different scale (reasonably matched for size) - for scale_test in (0.15 / test_scale_radius, 0.6 / test_scale_radius): - gal.drawKImage(image=kimage_test, scale=scale_test) - gal_hankel.drawImage(image_test, scale=scale_test, use_true_center=False, method='sb') - np.testing.assert_array_almost_equal( - kimage_test.real.array, image_test.array, decimal=12, - err_msg="Test object drawKImageImage() and drawImage() from Hankel conjugate do not "+ - "match for grid spacing scale = "+str(scale_test)) - np.testing.assert_array_almost_equal( - kimage_test.imag.array, 0., decimal=12, - err_msg="Non-zero imaginary part for drawKImage from test object that is purely "+ - "centred on the origin.") - - -@timer -def test_offset(): - """Test the offset parameter to the drawImage function. - """ - scale = 0.23 - - # Use some more exact GSParams. We'll be comparing FFT images to real-space convolved values, - # so we don't want to suffer from our overall accuracy being only about 10^-3. - # Update: It turns out the only one I needed to reduce to obtain the accuracy I wanted - # below is maxk_threshold. Perhaps this is a sign that we ought to lower it in general? - params = galsim.GSParams(maxk_threshold=1.e-4) - - # We use a simple Exponential for our object: - gal = galsim.Exponential(flux=test_flux, scale_radius=0.5, gsparams=params) - pix = galsim.Pixel(scale, gsparams=params) - obj = galsim.Convolve([gal,pix], gsparams=params) - - # The shapes of the images we will build - # Make sure all combinations of odd/even are represented. - shape_list = [ (256,256), (256,243), (249,260), (255,241), (270,260) ] - - # Some reasonable (x,y) values at which to test the xValues (near the center) - xy_list = [ (128,128), (123,131), (126,124) ] - - # The offsets to test - offset_list = [ (1,-3), (0.3,-0.1), (-2.3,-1.2) ] - - # Make the images somewhat large so the moments are measured accurately. - for nx,ny in shape_list: - - # First check that the image agrees with our calculation of the center - cenx = (nx+1.)/2. - ceny = (ny+1.)/2. - im = galsim.ImageD(nx,ny, scale=scale) - true_center = im.bounds.true_center - np.testing.assert_almost_equal( - cenx, true_center.x, 6, - "im.bounds.true_center.x is wrong for (nx,ny) = %d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - ceny, true_center.y, 6, - "im.bounds.true_center.y is wrong for (nx,ny) = %d,%d"%(nx,ny)) - - # Check that the default draw command puts the centroid in the center of the image. - obj.drawImage(im, method='sb') - mom = galsim.utilities.unweighted_moments(im) - np.testing.assert_almost_equal( - mom['Mx'], cenx, 5, - "obj.drawImage(im) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - mom['My'], ceny, 5, - "obj.drawImage(im) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - - # Can also use center to explicitly say we want to use the true_center. - im3 = obj.drawImage(im.copy(), method='sb', center=im.true_center) - np.testing.assert_array_almost_equal(im3.array, im.array) - - # Test that a few pixel values match xValue. - # Note: we don't expect the FFT drawn image to match the xValues precisely, since the - # latter use real-space convolution, so they should just match to our overall accuracy - # requirement, which is something like 1.e-3 or so. But an image of just the galaxy - # should use real-space drawing, so should be pretty much exact. - im2 = galsim.ImageD(nx,ny, scale=scale) - gal.drawImage(im2, method='sb') - for x,y in xy_list: - u = (x-cenx) * scale - v = (y-ceny) * scale - np.testing.assert_almost_equal( - im(x,y), obj.xValue(galsim.PositionD(u,v)), 2, - "im(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - np.testing.assert_almost_equal( - im2(x,y), gal.xValue(galsim.PositionD(u,v)), 6, - "im2(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - - # Check that offset moves the centroid by the right amount. - for offx, offy in offset_list: - # For integer offsets, we expect the centroids to come out pretty much exact. - # (Only edge effects of the image should produce any error, and those are very small.) - # However, for non-integer effects, we don't actually expect the centroids to be - # right, even with perfect image rendering. To see why, imagine using a delta function - # for the galaxy. The centroid changes discretely, not continuously as the offset - # varies. The effect isn't as severe of course for our Exponential, but the effect - # is still there in part. Hence, only use 2 decimal places for non-integer offsets. - if offx == int(offx) and offy == int(offy): - decimal = 4 - else: - decimal = 2 - - offset = galsim.PositionD(offx,offy) - obj.drawImage(im, method='sb', offset=offset) - mom = galsim.utilities.unweighted_moments(im) - np.testing.assert_almost_equal( - mom['Mx'], cenx+offx, decimal, - "obj.drawImage(im,offset) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - mom['My'], ceny+offy, decimal, - "obj.drawImage(im,offset) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - # Test that a few pixel values match xValue - gal.drawImage(im2, method='sb', offset=offset) - for x,y in xy_list: - u = (x-cenx-offx) * scale - v = (y-ceny-offy) * scale - np.testing.assert_almost_equal( - im(x,y), obj.xValue(galsim.PositionD(u,v)), 2, - "im(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - np.testing.assert_almost_equal( - im2(x,y), gal.xValue(galsim.PositionD(u,v)), 6, - "im2(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - - # Check that shift also moves the centroid by the right amount. - shifted_obj = obj.shift(offset * scale) - shifted_obj.drawImage(im, method='sb') - mom = galsim.utilities.unweighted_moments(im) - np.testing.assert_almost_equal( - mom['Mx'], cenx+offx, decimal, - "shifted_obj.drawImage(im) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - mom['My'], ceny+offy, decimal, - "shifted_obj.drawImage(im) not centered correctly for (nx,ny) = %d,%d"%(nx,ny)) - # Test that a few pixel values match xValue - shifted_gal = gal.shift(offset * scale) - shifted_gal.drawImage(im2, method='sb') - for x,y in xy_list: - u = (x-cenx) * scale - v = (y-ceny) * scale - np.testing.assert_almost_equal( - im(x,y), shifted_obj.xValue(galsim.PositionD(u,v)), 2, - "im(%d,%d) does not match shifted xValue(%f,%f)"%(x,y,x-cenx,y-ceny)) - np.testing.assert_almost_equal( - im2(x,y), shifted_gal.xValue(galsim.PositionD(u,v)), 6, - "im2(%d,%d) does not match shifted xValue(%f,%f)"%(x,y,x-cenx,y-ceny)) - u = (x-cenx-offx) * scale - v = (y-ceny-offy) * scale - np.testing.assert_almost_equal( - im(x,y), obj.xValue(galsim.PositionD(u,v)), 2, - "im(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - np.testing.assert_almost_equal( - im2(x,y), gal.xValue(galsim.PositionD(u,v)), 6, - "im2(%d,%d) does not match xValue(%f,%f)"%(x,y,u,v)) - - # Test that the center parameter can be used to do the same thing. - center = galsim.PositionD(cenx + offx, ceny + offy) - im3 = obj.drawImage(im.copy(), method='sb', center=center) - np.testing.assert_almost_equal(im3.array, im.array) - assert im3.bounds == im.bounds - assert im3.wcs == im.wcs - - # Can also use both offset and center - im3 = obj.drawImage(im.copy(), method='sb', - center=(cenx-1, ceny+1), offset=(offx+1, offy-1)) - np.testing.assert_almost_equal(im3.array, im.array) - assert im3.bounds == im.bounds - assert im3.wcs == im.wcs - - # Check the image's definition of the nominal center - nom_cenx = (nx+2)//2 - nom_ceny = (ny+2)//2 - nominal_center = im.bounds.center - np.testing.assert_almost_equal( - nom_cenx, nominal_center.x, 6, - "im.bounds.center.x is wrong for (nx,ny) = %d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - nom_ceny, nominal_center.y, 6, - "im.bounds.center.y is wrong for (nx,ny) = %d,%d"%(nx,ny)) - - # Check that use_true_center = false is consistent with an offset by 0 or 0.5 pixels. - obj.drawImage(im, method='sb', use_true_center=False) - mom = galsim.utilities.unweighted_moments(im) - np.testing.assert_almost_equal( - mom['Mx'], nom_cenx, 4, - "obj.drawImage(im, use_true_center=False) not centered correctly for (nx,ny) = "+ - "%d,%d"%(nx,ny)) - np.testing.assert_almost_equal( - mom['My'], nom_ceny, 4, - "obj.drawImage(im, use_true_center=False) not centered correctly for (nx,ny) = "+ - "%d,%d"%(nx,ny)) - cen_offset = galsim.PositionD(nom_cenx - cenx, nom_ceny - ceny) - obj.drawImage(im2, method='sb', offset=cen_offset) - np.testing.assert_array_almost_equal( - im.array, im2.array, 6, - "obj.drawImage(im, offset=%f,%f) different from use_true_center=False") - - # Can also use center to explicitly say to use the integer center - im3 = obj.drawImage(im.copy(), method='sb', center=im.center) - np.testing.assert_almost_equal(im3.array, im.array) - -def test_shoot(): - """Test drawImage(..., method='phot') - - Most tests of the photon shooting method are done using the `do_shoot` function calls - in various places. Here we test other aspects of photon shooting that are not fully - covered by these other tests. - """ - # This test comes from a bug report by Jim Chiang on issue #866. There was a rounding - # problem when the number of photons to shoot came out to 100,000 + 1. It did the first - # 100,000 and then was left with 1, but rounding errors (since it is a double, not an int) - # was 1 - epsilon, and it ended up in a place where it shouldn't have been able to get to - # in exact arithmetic. We had an assert there which blew up in a not very nice way. - obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352) - obj = obj.withFlux(100001) - # JAX-Galsim adjusts the images to double here - image1 = galsim.ImageD(32,32, init_value=100) - rng = galsim.BaseDeviate(1234) - obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng, - maxN=100000) - - # The test here is really just that it doesn't crash. - # But let's do something to check correctness. - # JAX-Galsim adjusts the images to double here - image2 = galsim.ImageD(32,32) - rng = galsim.BaseDeviate(1234) - obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, - maxN=100000) - image2 += 100 - # with double, we get the same result to 10 decimal places - np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10) - - # Also check that you get the same answer with a smaller maxN. - image3 = galsim.ImageF(32,32, init_value=100) - rng = galsim.BaseDeviate(1234) - obj.drawImage(image3, method='phot', poisson_flux=False, add_to_image=True, rng=rng, maxN=1000) - # It's not exactly the same, since the rngs are realized in a different order. - np.testing.assert_allclose(image3.array, image1.array, rtol=0.25) - - # Test that shooting with 0.0 flux makes a zero-photons image. - image4 = (obj*0).drawImage(method='phot') - np.testing.assert_array_equal(image4.array, 0) - - # Warns if flux is 1 and n_photons not given. - # JAX-GalSim doesn't warn in this case - psf = galsim.Gaussian(sigma=3) - # with assert_warns(galsim.GalSimWarning): - # psf.drawImage(method='phot') - # with assert_warns(galsim.GalSimWarning): - # psf.drawPhot(image4) - # with assert_warns(galsim.GalSimWarning): - # psf.makePhot() - - # With n_photons=1, it's fine. - psf.drawImage(method='phot', n_photons=1) - psf.drawPhot(image4, n_photons=1) - psf.makePhot(n_photons=1) - - # Check negative flux shooting with poisson_flux=True - # The do_shoot test in galsim_test_helpers checks negative flux with a fixed number of photons. - # But we also want to check that the automatic number of photons is reaonable when the flux - # is negative. - obj = obj.withFlux(-1.e5) - image3 = galsim.ImageF(64,64) - obj.drawImage(image3, method='phot', poisson_flux=True, rng=rng) - print('image3.sum = ',image3.array.sum()) - # Only accurate to about sqrt(1.e5) from Poisson realization - np.testing.assert_allclose(image3.array.sum(), obj.flux, rtol=0.01) - - -@timer -def test_drawImage_area_exptime(): - """Test that area and exptime kwargs to drawImage() appropriately scale image.""" - exptime = 2 - area = 1.4 - - # We will be photon shooting, so use largish flux. - obj = galsim.Exponential(flux=1776., scale_radius=2) - - im1 = obj.drawImage(nx=24, ny=24, scale=0.3) - im2 = obj.drawImage(image=im1.copy(), exptime=exptime, area=area) - np.testing.assert_array_almost_equal(im1.array, im2.array/exptime/area, 5, - "obj.drawImage() did not respect area and exptime kwargs.") - - # Now check with drawShoot(). Scaling the gain should just scale the image proportionally. - # Scaling the area or exptime should actually produce a non-proportional image, though, since a - # different number of photons will be shot. - - rng = galsim.BaseDeviate(1234) - im1 = obj.drawImage(nx=24, ny=24, scale=0.3, method='phot', rng=rng.duplicate()) - im2 = obj.drawImage(image=im1.copy(), method='phot', rng=rng.duplicate()) - np.testing.assert_array_almost_equal(im1.array, im2.array, 5, - "obj.drawImage(method='phot', rng=rng.duplicate()) did not produce image " - "deterministically.") - im3 = obj.drawImage(image=im1.copy(), method='phot', rng=rng.duplicate(), gain=2) - np.testing.assert_array_almost_equal(im1.array, im3.array*2, 5, - "obj.drawImage(method='phot', rng=rng.duplicate(), gain=2) did not produce image " - "deterministically.") - - im4 = obj.drawImage(image=im1.copy(), method='phot', rng=rng.duplicate(), - area=area, exptime=exptime) - msg = ("obj.drawImage(method='phot') unexpectedly produced proportional images with different " - "area and exptime keywords.") - assert not np.allclose(im1.array, im4.array/area/exptime), msg - - im5 = obj.drawImage(image=im1.copy(), method='phot', area=area, exptime=exptime) - msg = "obj.drawImage(method='phot') unexpectedly produced equal images with different rng" - assert not np.allclose(im5.array, im4.array), msg - - # JAX-GalSim doesn't raise for these things - # # Shooting with flux=1 raises a warning. - # obj1 = obj.withFlux(1) - # with assert_warns(galsim.GalSimWarning): - # obj1.drawImage(method='phot') - # # But not if we explicitly tell it to shoot 1 photon - # with assert_raises(AssertionError): - # assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) - # # Likewise for makePhot - # with assert_warns(galsim.GalSimWarning): - # obj1.makePhot() - # with assert_raises(AssertionError): - # assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) - # # And drawPhot - # with assert_warns(galsim.GalSimWarning): - # obj1.drawPhot(im1) - # with assert_raises(AssertionError): - # assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) - - -@timer -def test_fft(): - """Test the routines for calculating the fft of an image. - """ - - # Start with a really simple test of the round trip fft and then inverse_fft. - # And run it for all input types to make sure they all work. - types = [np.int16, np.int32, np.float32, np.float64, int, float] - for dt in types: - xim = galsim.Image([ [0,2,4,2], - [2,4,6,4], - [4,6,8,4], - [2,4,6,6] ], - xmin=-2, ymin=-2, dtype=dt, scale=0.1) - kim = xim.calculate_fft() - xim2 = kim.calculate_inverse_fft() - np.testing.assert_array_almost_equal(xim.array, xim2.array) - - # Now the other way, starting with a (real) k-space image. - kim = galsim.Image([ [4,2,0], - [6,4,2], - [8,6,4], - [6,4,2] ], - xmin=0, ymin=-2, dtype=dt, scale=0.1) - xim = kim.calculate_inverse_fft() - kim2 = xim.calculate_fft() - np.testing.assert_array_almost_equal(kim.array, kim2.array) - - # Test starting with a larger image that gets wrapped. - kim3 = galsim.Image([ [0,1,2,1,0], - [1,4,6,4,1], - [2,6,8,6,2], - [1,4,6,4,1], - [0,1,2,1,0] ], - xmin=-2, ymin=-2, dtype=dt, scale=0.1) - xim = kim3.calculate_inverse_fft() - kim2 = xim.calculate_fft() - np.testing.assert_array_almost_equal(kim.array, kim2.array) - - # Test padding X Image with zeros - xim = galsim.Image([ [0,0,0,0], - [2,4,6,0], - [4,6,8,0], - [0,0,0,0] ], - xmin=-2, ymin=-2, dtype=dt, scale=0.1) - xim2 = galsim.Image([ [2,4,6], - [4,6,8] ], - xmin=-2, ymin=-1, dtype=dt, scale=0.1) - kim = xim.calculate_fft() - kim2 = xim2.calculate_fft() - np.testing.assert_array_almost_equal(kim.array, kim2.array) - - # Test padding K Image with zeros - kim = galsim.Image([ [4,2,0], - [6,4,0], - [8,6,0], - [6,4,0] ], - xmin=0, ymin=-2, dtype=dt, scale=0.1) - kim2 = galsim.Image([ [6,4], - [8,6], - [6,4], - [4,2] ], - xmin=0, ymin=-1, dtype=dt, scale=0.1) - xim = kim.calculate_inverse_fft() - xim2 = kim2.calculate_inverse_fft() - np.testing.assert_array_almost_equal(xim.array, xim2.array) - - # Now use drawKImage (as above in test_drawKImage) to get a more realistic k-space image - obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5) - obj = obj.withGSParams(maxk_threshold=1.e-4) - im1 = obj.drawKImage() - N = 1174 # NB. It is useful to have this come out not a multiple of 4, since some of the - # calculation needs to be different when N/2 is odd. - np.testing.assert_equal(im1.bounds, galsim.BoundsI(-N/2,N/2,-N/2,N/2), - "obj.drawKImage() produced image with wrong bounds") - nyq_scale = obj.nyquist_scale - - # If we inverse_fft the above automatic image, it should match the automatic real image - # for method = 'sb' and use_true_center=False. - im1_real = im1.calculate_inverse_fft() - # Convolve by a delta function to force FFT drawing. - obj2 = galsim.Convolve(obj, galsim.Gaussian(sigma=1.e-10)) - im1_alt_real = obj2.drawImage(method='sb', use_true_center=False) - im1_alt_real.setCenter(0,0) # This isn't done automatically. - np.testing.assert_equal( - im1_real.bounds, im1_alt_real.bounds, - "inverse_fft did not produce the same bounds as obj2.drawImage(method='sb')") - # The scale and array are only approximately equal, because drawImage rounds the size up to - # an even number and uses Nyquist scale for dx. - np.testing.assert_almost_equal( - im1_real.scale, im1_alt_real.scale, 3, - "inverse_fft produce a different scale than obj2.drawImage(method='sb')") - np.testing.assert_array_almost_equal( - im1_real.array, im1_alt_real.array, 3, - "inverse_fft produce a different array than obj2.drawImage(method='sb')") - - # If we give both a good size to use and match up the scales, then they should produce the - # same thing. - N = galsim.Image.good_fft_size(N) - assert N == 1536 == 3 * 2**9 - kscale = 2.*np.pi / (N * nyq_scale) - im2 = obj.drawKImage(nx=N+1, ny=N+1, scale=kscale) - im2_real = im2.calculate_inverse_fft() - im2_alt_real = obj2.drawImage(nx=N, ny=N, method='sb', use_true_center=False, dtype=float) - im2_alt_real.setCenter(0,0) - np.testing.assert_equal( - im2_real.bounds, im2_alt_real.bounds, - "inverse_fft did not produce the same bounds as obj2.drawImage(nx,ny,method='sb')") - np.testing.assert_almost_equal( - im2_real.scale, im2_alt_real.scale, 9, - "inverse_fft produce a different scale than obj2.drawImage(nx,ny,method='sb')") - np.testing.assert_array_almost_equal( - im2_real.array, im2_alt_real.array, 9, - "inverse_fft produce a different array than obj2.drawImage(nx,ny,method='sb')") - - # wcs must be a PixelScale - xim.wcs = galsim.JacobianWCS(1.1,0.1,0.1,1) - with assert_raises(galsim.GalSimError): - xim.calculate_fft() - with assert_raises(galsim.GalSimError): - xim.calculate_inverse_fft() - xim.wcs = None - with assert_raises(galsim.GalSimError): - xim.calculate_fft() - with assert_raises(galsim.GalSimError): - xim.calculate_inverse_fft() - - # inverse needs image with 0,0 - xim.scale=1 - xim.setOrigin(1,1) - with assert_raises(galsim.GalSimBoundsError): - xim.calculate_inverse_fft() - - -@timer -def test_np_fft(): - """Test the equivalence between np.fft functions and the galsim versions - """ - input_list = [] - input_list.append( np.array([ [0,1,2,1], - [1,2,3,2], - [2,3,4,3], - [1,2,3,2] ], dtype=int )) - input_list.append( np.array([ [0,1], - [1,2], - [2,3], - [1,2] ], dtype=int )) - noise = galsim.GaussianNoise(sigma=5, rng=galsim.BaseDeviate(1234)) - for N in [2,4,8,10]: - xim = galsim.ImageD(N,N) - xim.addNoise(noise) - input_list.append(xim.array) - - for Nx,Ny in [ (2,4), (4,2), (10,6), (6,10) ]: - xim = galsim.ImageD(Nx,Ny) - xim.addNoise(noise) - input_list.append(xim.array) - - for N in [2,4,8,10]: - xim = galsim.ImageCD(N,N) - xim.real.addNoise(noise) - xim.imag.addNoise(noise) - input_list.append(xim.array) - - for Nx,Ny in [ (2,4), (4,2), (10,6), (6,10) ]: - xim = galsim.ImageCD(Nx,Ny) - xim.real.addNoise(noise) - xim.imag.addNoise(noise) - input_list.append(xim.array) - - for xar in input_list: - Ny,Nx = xar.shape - print('Nx,Ny = ',Nx,Ny) - if Nx + Ny < 10: - print('xar = ',xar) - kar1 = np.fft.fft2(xar) - #print('numpy kar = ',kar1) - kar2 = galsim.fft.fft2(xar) - if Nx + Ny < 10: - print('kar = ',kar2) - np.testing.assert_almost_equal(kar1, kar2, 9, "fft2 not equivalent to np.fft.fft2") - - # Check that kar is Hermitian in the way that we describe in the doc for ifft2 - if not np.iscomplexobj(xar): - for kx in range(Nx//2,Nx): - np.testing.assert_almost_equal(kar2[0,kx], kar2[0,Nx-kx].conjugate()) - for ky in range(1,Ny): - np.testing.assert_almost_equal(kar2[ky,kx], kar2[Ny-ky,Nx-kx].conjugate()) - - # Check shift_in - kar3 = np.fft.fft2(np.fft.fftshift(xar)) - kar4 = galsim.fft.fft2(xar, shift_in=True) - np.testing.assert_almost_equal(kar3, kar4, 9, "fft2(shift_in) failed") - - # Check shift_out - kar5 = np.fft.fftshift(np.fft.fft2(xar)) - kar6 = galsim.fft.fft2(xar, shift_out=True) - np.testing.assert_almost_equal(kar5, kar6, 9, "fft2(shift_out) failed") - - # Check both - kar7 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(xar))) - kar8 = galsim.fft.fft2(xar, shift_in=True, shift_out=True) - np.testing.assert_almost_equal(kar7, kar8, 9, "fft2(shift_in,shift_out) failed") - - # ifft2 - #print('ifft2') - xar1 = np.fft.ifft2(kar2) - xar2 = galsim.fft.ifft2(kar2) - if Nx + Ny < 10: - print('xar2 = ',xar2) - np.testing.assert_almost_equal(xar1, xar2, 9, "ifft2 not equivalent to np.fft.ifft2") - np.testing.assert_almost_equal(xar2, xar, 9, "ifft2(fft2(a)) != a") - - xar3 = np.fft.ifft2(np.fft.fftshift(kar6)) - xar4 = galsim.fft.ifft2(kar6, shift_in=True) - np.testing.assert_almost_equal(xar3, xar4, 9, "ifft2(shift_in) failed") - np.testing.assert_almost_equal(xar4, xar, 9, "ifft2(fft2(a)) != a with shift_in/out") - - xar5 = np.fft.fftshift(np.fft.ifft2(kar4)) - xar6 = galsim.fft.ifft2(kar4, shift_out=True) - np.testing.assert_almost_equal(xar5, xar6, 9, "ifft2(shift_out) failed") - np.testing.assert_almost_equal(xar6, xar, 9, "ifft2(fft2(a)) != a with shift_out/in") - - xar7 = np.fft.fftshift(np.fft.ifft2(np.fft.fftshift(kar8))) - xar8 = galsim.fft.ifft2(kar8, shift_in=True, shift_out=True) - np.testing.assert_almost_equal(xar7, xar8, 9, "ifft2(shift_in,shift_out) failed") - np.testing.assert_almost_equal(xar8, xar, 9, "ifft2(fft2(a)) != a with all shifts") - - if np.iscomplexobj(xar): continue - - # rfft2 - #print('rfft2') - rkar1 = np.fft.rfft2(xar) - rkar2 = galsim.fft.rfft2(xar) - np.testing.assert_almost_equal(rkar1, rkar2, 9, "rfft2 not equivalent to np.fft.rfft2") - - rkar3 = np.fft.rfft2(np.fft.fftshift(xar)) - rkar4 = galsim.fft.rfft2(xar, shift_in=True) - np.testing.assert_almost_equal(rkar3, rkar4, 9, "rfft2(shift_in) failed") - - rkar5 = np.fft.fftshift(np.fft.rfft2(xar),axes=(0,)) - rkar6 = galsim.fft.rfft2(xar, shift_out=True) - np.testing.assert_almost_equal(rkar5, rkar6, 9, "rfft2(shift_out) failed") - - rkar7 = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(xar)),axes=(0,)) - rkar8 = galsim.fft.rfft2(xar, shift_in=True, shift_out=True) - np.testing.assert_almost_equal(rkar7, rkar8, 9, "rfft2(shift_in,shift_out) failed") - - # irfft2 - #print('irfft2') - xar1 = np.fft.irfft2(rkar1) - xar2 = galsim.fft.irfft2(rkar1) - np.testing.assert_almost_equal(xar1, xar2, 9, "irfft2 not equivalent to np.fft.irfft2") - np.testing.assert_almost_equal(xar2, xar, 9, "irfft2(rfft2(a)) != a") - - xar3 = np.fft.irfft2(np.fft.fftshift(rkar6,axes=(0,))) - xar4 = galsim.fft.irfft2(rkar6, shift_in=True) - np.testing.assert_almost_equal(xar3, xar4, 9, "irfft2(shift_in) failed") - np.testing.assert_almost_equal(xar4, xar, 9, "irfft2(rfft2(a)) != a with shift_in/out") - - xar5 = np.fft.fftshift(np.fft.irfft2(rkar4)) - xar6 = galsim.fft.irfft2(rkar4, shift_out=True) - np.testing.assert_almost_equal(xar5, xar6, 9, "irfft2(shift_out) failed") - np.testing.assert_almost_equal(xar6, xar, 9, "irfft2(rfft2(a)) != a with shift_out/in") - - xar7 = np.fft.fftshift(np.fft.irfft2(np.fft.fftshift(rkar8,axes=(0,)))) - xar8 = galsim.fft.irfft2(rkar8, shift_in=True, shift_out=True) - np.testing.assert_almost_equal(xar7, xar8, 9, "irfft2(shift_in,shift_out) failed") - np.testing.assert_almost_equal(xar8, xar, 9, "irfft2(rfft2(a)) != a with all shifts") - - # ifft can also accept real arrays - xar9 = galsim.fft.fft2(galsim.fft.ifft2(xar)) - np.testing.assert_almost_equal(xar9, xar, 9, "fft2(ifft2(a)) != a") - xar10 = galsim.fft.fft2(galsim.fft.ifft2(xar,shift_in=True),shift_out=True) - np.testing.assert_almost_equal(xar10, xar, 9, "fft2(ifft2(a)) != a with shift_in/out") - xar11 = galsim.fft.fft2(galsim.fft.ifft2(xar,shift_out=True),shift_in=True) - np.testing.assert_almost_equal(xar11, xar, 9, "fft2(ifft2(a)) != a with shift_out/in") - xar12 = galsim.fft.fft2(galsim.fft.ifft2(xar,shift_in=True,shift_out=True), - shift_in=True,shift_out=True) - np.testing.assert_almost_equal(xar12, xar, 9, "fft2(ifft2(a)) != a with all shifts") - - # Check for invalid inputs - # Must be 2-d arrays - xar_1d = input_list[0].ravel() - xar_3d = input_list[0].reshape(2,2,4) - xar_4d = input_list[0].reshape(2,2,2,2) - assert_raises(ValueError, galsim.fft.fft2, xar_1d) - assert_raises(ValueError, galsim.fft.fft2, xar_3d) - assert_raises(ValueError, galsim.fft.fft2, xar_4d) - assert_raises(ValueError, galsim.fft.ifft2, xar_1d) - assert_raises(ValueError, galsim.fft.ifft2, xar_3d) - assert_raises(ValueError, galsim.fft.ifft2, xar_4d) - assert_raises(ValueError, galsim.fft.rfft2, xar_1d) - assert_raises(ValueError, galsim.fft.rfft2, xar_3d) - assert_raises(ValueError, galsim.fft.rfft2, xar_4d) - assert_raises(ValueError, galsim.fft.irfft2, xar_1d) - assert_raises(ValueError, galsim.fft.irfft2, xar_3d) - assert_raises(ValueError, galsim.fft.irfft2, xar_4d) - - # Must have even sizes - xar_oo = input_list[0][:3,:3] - xar_oe = input_list[0][:3,:] - xar_eo = input_list[0][:,:3] - assert_raises(ValueError, galsim.fft.fft2, xar_oo) - assert_raises(ValueError, galsim.fft.fft2, xar_oe) - assert_raises(ValueError, galsim.fft.fft2, xar_eo) - assert_raises(ValueError, galsim.fft.ifft2, xar_oo) - assert_raises(ValueError, galsim.fft.ifft2, xar_oe) - assert_raises(ValueError, galsim.fft.ifft2, xar_eo) - assert_raises(ValueError, galsim.fft.rfft2, xar_oo) - assert_raises(ValueError, galsim.fft.rfft2, xar_oe) - assert_raises(ValueError, galsim.fft.rfft2, xar_eo) - assert_raises(ValueError, galsim.fft.irfft2, xar_oo) - assert_raises(ValueError, galsim.fft.irfft2, xar_oe) - # eo is ok, since the second dimension is actually N/2+1 - -def round_cast(array, dt): - # array.astype(dt) doesn't round to the nearest for integer types. - # This rounds first if dt is integer and then casts. - # NOTE JAX doesn't round to the nearest int when drawing - # if dt(0.5) != 0.5: - # array = np.around(array) - return array.astype(dt) - -@timer -def test_types(): - """Test drawing onto image types other than float32, float64. - """ - - # Methods test drawReal, drawFFT, drawPhot respectively - for method in ['no_pixel', 'fft', 'phot']: - if method == 'phot': - rng = galsim.BaseDeviate(1234) - else: - rng = None - obj = galsim.Exponential(flux=177, scale_radius=2) - ref_im = obj.drawImage(method=method, dtype=float, rng=rng) - - for dt in [ np.float32, np.float64, np.int16, np.int32, np.uint16, np.uint32, - np.complex128, np.complex64 ]: - if method == 'phot': rng.reset(1234) - print('Checking',method,'with dt =', dt) - im = obj.drawImage(method=method, dtype=dt, rng=rng) - np.testing.assert_equal(im.scale, ref_im.scale, - "wrong scale when drawing onto dt=%s"%dt) - np.testing.assert_equal(im.bounds, ref_im.bounds, - "wrong bounds when drawing onto dt=%s"%dt) - np.testing.assert_array_almost_equal(im.array, round_cast(ref_im.array, dt), 6, - "wrong array when drawing onto dt=%s"%dt) - - if method == 'phot': - rng.reset(1234) - obj.drawImage(im, method=method, add_to_image=True, rng=rng) - np.testing.assert_array_almost_equal(im.array, round_cast(ref_im.array, dt) * 2, 6, - "wrong array when adding to image with dt=%s"%dt) - -@timer -def test_direct_scale(): - """Test the explicit functions with scale != 1 - - The default behavior is to change the profile to image coordinates, and draw that onto an - image with scale=1. But the three direct functions allow the image to have a non-unit - pixel scale. (Not more complicated wcs though.) - - This test checks that the results are equivalent between the two calls. - """ - - scale = 0.35 - rng = galsim.BaseDeviate(1234) - obj = galsim.Exponential(flux=177, scale_radius=2) - obj_with_pixel = galsim.Convolve(obj, galsim.Pixel(scale)) - obj_sb = obj / scale**2 - - # Make these odd, so we don't have to deal with the centering offset stuff. - im1 = galsim.ImageD(65, 65, scale=scale) - im2 = galsim.ImageD(65, 65, scale=scale) - im2.setCenter(0,0) - - # One possibe use of the specific functions is to not automatically recenter on 0,0. - # So make sure they work properly if 0,0 is not the center - im3 = galsim.ImageD(32, 32, scale=scale) # origin is (1,1) - im4 = galsim.ImageD(32, 32, scale=scale) - im5 = galsim.ImageD(32, 32, scale=scale) - - obj.drawImage(im1, method='no_pixel') - obj.drawReal(im2) - obj.drawReal(im3) - # Note that cases 4 and 5 have objects that are logically identical (because obj is circularly - # symmetric), but the code follows different paths in the SBProfile.draw function due to the - # different jacobians in each case. - obj.dilate(1.0).drawReal(im4) - obj.rotate(0.3*galsim.radians).drawReal(im5) - print('no_pixel: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_array_almost_equal(im1.array, im2.array, 15, - "drawReal made different image than method='no_pixel'") - np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, - "drawReal made different image when off-center") - np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, - "drawReal made different image when jac is not None") - np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 15, - "drawReal made different image when jac is not diagonal") - - obj.drawImage(im1, method='sb') - obj_sb.drawReal(im2) - obj_sb.drawReal(im3) - obj_sb.dilate(1.0).drawReal(im4) - obj_sb.rotate(0.3*galsim.radians).drawReal(im5) - print('sb: max diff = ',np.max(np.abs(im1.array - im2.array))) - # JAX - turned this down to 14 here - np.testing.assert_array_almost_equal(im1.array, im2.array, 14, - "drawReal made different image than method='sb'") - np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, - "drawReal made different image when off-center") - np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, - "drawReal made different image when jac is not None") - np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 14, - "drawReal made different image when jac is not diagonal") - - obj.drawImage(im1, method='fft') - obj_with_pixel.drawFFT(im2) - obj_with_pixel.drawFFT(im3) - obj_with_pixel.dilate(1.0).drawFFT(im4) - obj_with_pixel.rotate(90 * galsim.degrees).drawFFT(im5) - print('fft: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_array_almost_equal(im1.array, im2.array, 15, - "drawFFT made different image than method='fft'") - np.testing.assert_array_almost_equal(im3.array, im2[im3.bounds].array, 15, - "drawFFT made different image when off-center") - np.testing.assert_array_almost_equal(im4.array, im2[im3.bounds].array, 15, - "drawFFT made different image when jac is not None") - np.testing.assert_array_almost_equal(im5.array, im2[im3.bounds].array, 14, - "drawFFT made different image when jac is not diagonal") - - obj.drawImage(im1, method='real_space') - obj_with_pixel.drawReal(im2) - obj_with_pixel.drawReal(im3) - obj_with_pixel.dilate(1.0).drawReal(im4) - obj_with_pixel.rotate(90 * galsim.degrees).drawReal(im5) - print('real_space: max diff = ',np.max(np.abs(im1.array - im2.array))) - # I'm not sure why this one comes out a bit less precisely equal. But 12 digits is still - # plenty accurate enough. - np.testing.assert_almost_equal(im1.array, im2.array, 12, - "drawReal made different image than method='real_space'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 14, - "drawReal made different image when off-center") - np.testing.assert_almost_equal(im4.array, im2[im3.bounds].array, 14, - "drawReal made different image when jac is not None") - np.testing.assert_almost_equal(im5.array, im2[im3.bounds].array, 14, - "drawReal made different image when jac is not diagonal") - - obj.drawImage(im1, method='phot', rng=rng.duplicate()) - _, phot1 = obj.drawPhot(im2, rng=rng.duplicate()) - _, phot2 = obj.drawPhot(im3, rng=rng.duplicate()) - phot3 = obj.makePhot(rng=rng.duplicate()) - phot3.scaleXY(1./scale) - phot4 = im3.wcs.toImage(obj).makePhot(rng=rng.duplicate()) - print('phot: max diff = ',np.max(np.abs(im1.array - im2.array))) - np.testing.assert_almost_equal(im1.array, im2.array, 15, - "drawPhot made different image than method='phot'") - np.testing.assert_almost_equal(im3.array, im2[im3.bounds].array, 15, - "drawPhot made different image when off-center") - assert phot2 == phot1, "drawPhot made different photons than method='phot'" - assert phot3 == phot1, "makePhot made different photons than method='phot'" - # phot4 has a different order of operations for the math, so it doesn't come out exact. - np.testing.assert_almost_equal(phot4.x, phot3.x, 15, - "two ways to have makePhot apply scale have different x") - np.testing.assert_almost_equal(phot4.y, phot3.y, 15, - "two ways to have makePhot apply scale have different y") - np.testing.assert_almost_equal(phot4.flux, phot3.flux, 15, - "two ways to have makePhot apply scale have different flux") - - # Check images with invalid wcs raise ValueError - im4 = galsim.ImageD(65, 65) - im5 = galsim.ImageD(65, 65, wcs=galsim.JacobianWCS(0.4,0.1,-0.1,0.5)) - assert_raises(ValueError, obj.drawReal, im4) - assert_raises(ValueError, obj.drawReal, im5) - assert_raises(ValueError, obj.drawFFT, im4) - assert_raises(ValueError, obj.drawFFT, im5) - assert_raises(ValueError, obj.drawPhot, im4) - assert_raises(ValueError, obj.drawPhot, im5) - # Also some other errors from drawPhot - assert_raises(ValueError, obj.drawPhot, im2, n_photons=-20) - assert_raises(TypeError, obj.drawPhot, im2, sensor=5) - assert_raises(ValueError, obj.makePhot, n_photons=-20) - -if __name__ == "__main__": - testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)] - for testfn in testfns: - testfn() diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py deleted file mode 100644 index 1c229c32..00000000 --- a/tests/jax/galsim/test_image_jax.py +++ /dev/null @@ -1,4850 +0,0 @@ -# Copyright (c) 2012-2021 by the GalSim developers team on GitHub -# https://github.com/GalSim-developers -# -# This file is part of GalSim: The modular galaxy image simulation toolkit. -# https://github.com/GalSim-developers/GalSim -# -# GalSim is free software: redistribution and use in source and binary forms, -# with or without modification, are permitted provided that the following -# conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions, and the disclaimer given in the accompanying LICENSE -# file. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the disclaimer given in the documentation -# and/or other materials provided with the distribution. -# - -"""Unit tests for the Image class. - -These tests use six externally generated (IDL + astrolib FITS writing tools) reference images for -the Image unit tests. These are in tests/data/. - -Each image is 5x7 pixels^2 and if each pixel is labelled (x, y) then each pixel value is 10*x + y. -The array thus has values: - -15 25 35 45 55 65 75 -14 24 34 44 54 64 74 -13 23 33 43 53 63 73 ^ -12 22 32 42 52 62 72 | -11 21 31 41 51 61 71 y - -x -> - -With array directions as indicated. This hopefully will make it easy enough to perform sub-image -checks, etc. - -Images are in US, UI, S, I, F, D, CF, and CD flavors. - -There are also four FITS cubes, and four FITS multi-extension files for testing. Each is 12 -images deep, with the first image being the reference above and each subsequent being the same -incremented by one. - -""" - -from __future__ import print_function - -import os -import sys -from unicodedata import decimal - -sys.path.insert( - 1, os.path.abspath(os.path.join(os.path.dirname(__file__), "../GalSim/tests")) -) -import galsim -import numpy as np -from galsim._pyfits import pyfits -from galsim_test_helpers import * - -# Setup info for tests, not likely to change -ntypes = 8 # Note: Most tests below only run through the first 8 types. -# test_Image_basic tests all 11 types including the aliases. -types = [ - np.int16, - np.int32, - np.uint16, - np.uint32, - np.float32, - np.float64, - np.complex64, - np.complex128, - int, - float, - complex, -] -simple_types = [int, int, int, int, float, float, complex, complex, int, float, complex] -np_types = [ - np.int16, - np.int32, - np.uint16, - np.uint32, - np.float32, - np.float64, - np.complex64, - np.complex128, - np.int32, - np.float64, - np.complex128, -] -tchar = ["S", "I", "US", "UI", "F", "D", "CF", "CD", "I", "D", "CD"] -int_ntypes = ( - 4 # The first four are the integer types for which we need to test &, |, ^. -) - -ncol = 7 -nrow = 5 -test_shape = (ncol, nrow) # shape of image arrays for all tests -ref_array = np.array( - [ - [11, 21, 31, 41, 51, 61, 71], - [12, 22, 32, 42, 52, 62, 72], - [13, 23, 33, 43, 53, 63, 73], - [14, 24, 34, 44, 54, 64, 74], - [15, 25, 35, 45, 55, 65, 75], - ] -).astype(np.int16) -large_array = np.zeros((ref_array.shape[0] * 3, ref_array.shape[1] * 2), dtype=np.int16) -large_array[::3, ::2] = ref_array - -# Depth of FITS datacubes and multi-extension FITS files -if __name__ == "__main__": - nimages = 12 -else: - # There really are 12, but testing the first 3 should be plenty as a unit test, and - # it helps speed things up. - nimages = 3 - -datadir = os.path.join("tests/GalSim/tests", "Image_comparison_images") - - -@timer -def test_Image_basic(): - """Test that all supported types perform basic Image operations correctly""" - # Do all 10 types here, rather than just the 7 numpy types. i.e. Test the aliases. - for i in range(len(types)): - # Check basic constructor from ncol, nrow - array_type = types[i] - np_array_type = np_types[i] - print("array_type = ", array_type, np_array_type) - - # Check basic constructor from ncol, nrow - im1 = galsim.Image(ncol, nrow, dtype=array_type) - - # Check basic features of array built by Image - np.testing.assert_array_equal(im1.array, 0.0) - assert im1.array.shape == (nrow, ncol) - assert im1.array.dtype.type == np_array_type - # JAX specific modification - # ------------------------- - # These two lines are disabled because jnp.ndarrays do not have flags - # assert im1.array.flags.writeable == True - # assert im1.array.flags.c_contiguous == True - assert im1.dtype == np_array_type - - im1.fill(23) - np.testing.assert_array_equal(im1.array, 23.0) - - bounds = galsim.BoundsI(1, ncol, 1, nrow) - assert im1.xmin == 1 - assert im1.xmax == ncol - assert im1.ymin == 1 - assert im1.ymax == nrow - assert im1.bounds == bounds - assert im1.outer_bounds == galsim.BoundsD(0.5, ncol + 0.5, 0.5, nrow + 0.5) - - # Same thing if ncol,nrow are kwargs. Also can give init_value - im1b = galsim.Image(ncol=ncol, nrow=nrow, dtype=array_type, init_value=23) - np.testing.assert_array_equal(im1b.array, 23.0) - assert im1 == im1b - - # Adding on xmin, ymin allows you to set an origin other than (1,1) - im1a = galsim.Image(ncol, nrow, dtype=array_type, xmin=4, ymin=7) - im1b = galsim.Image(ncol=ncol, nrow=nrow, dtype=array_type, xmin=0, ymin=0) - assert im1a.xmin == 4 - assert im1a.xmax == ncol + 3 - assert im1a.ymin == 7 - assert im1a.ymax == nrow + 6 - assert im1a.bounds == galsim.BoundsI(4, ncol + 3, 7, nrow + 6) - assert im1a.outer_bounds == galsim.BoundsD(3.5, ncol + 3.5, 6.5, nrow + 6.5) - assert im1b.xmin == 0 - assert im1b.xmax == ncol - 1 - assert im1b.ymin == 0 - assert im1b.ymax == nrow - 1 - assert im1b.bounds == galsim.BoundsI(0, ncol - 1, 0, nrow - 1) - assert im1b.outer_bounds == galsim.BoundsD(-0.5, ncol - 0.5, -0.5, nrow - 0.5) - - # Also test alternate name of image type: ImageD, ImageF, etc. - image_type = eval( - "galsim.Image" + tchar[i] - ) # Use handy eval() mimics use of ImageSIFD - im2 = image_type(bounds, init_value=23) - im2_view = im2.view() - im2_cview = im2.view(make_const=True) - im2_conj = im2.conjugate - - assert im2_view.xmin == 1 - assert im2_view.xmax == ncol - assert im2_view.ymin == 1 - assert im2_view.ymax == nrow - assert im2_view.bounds == bounds - assert im2_view.array.dtype.type == np_array_type - assert im2_view.dtype == np_array_type - - assert im2_cview.xmin == 1 - assert im2_cview.xmax == ncol - assert im2_cview.ymin == 1 - assert im2_cview.ymax == nrow - assert im2_cview.bounds == bounds - assert im2_cview.array.dtype.type == np_array_type - assert im2_cview.dtype == np_array_type - - assert im1.real.bounds == bounds - assert im1.imag.bounds == bounds - assert im2.real.bounds == bounds - assert im2.imag.bounds == bounds - assert im2_view.real.bounds == bounds - assert im2_view.imag.bounds == bounds - assert im2_cview.real.bounds == bounds - assert im2_cview.imag.bounds == bounds - if tchar[i] == "CF": - assert im1.real.dtype == np.float32 - assert im1.imag.dtype == np.float32 - elif tchar[i] == "CD": - assert im1.real.dtype == np.float64 - assert im1.imag.dtype == np.float64 - else: - assert im1.real.dtype == np_array_type - assert im1.imag.dtype == np_array_type - - # Check various ways to set and get values - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - im1.setValue(x, y, 100 + 10 * x + y) - im1a.setValue(x + 3, y + 6, 100 + 10 * x + y) - im1b.setValue(x=x - 1, y=y - 1, value=100 + 10 * x + y) - # JAX specific modification - # ------------------------- - # im2_view and im2 do not share the same underlying - # numpy array, so we also update the value of im2 - # im2_view._setValue(x, y, 100 + 10*x) - # im2_view._addValue(x, y, y) - im2._setValue(x, y, 100 + 10 * x) - im2._addValue(x, y, y) - # And we recreate the view and conjugates from the modified - # array - im2_view = im2.view() - im2_cview = im2.view(make_const=True) - im2_conj = im2.conjugate - - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - value = 100 + 10 * x + y - assert im1(x, y) == value - assert im1(galsim.PositionI(x, y)) == value - assert im1a(x + 3, y + 6) == value - assert im1b(x - 1, y - 1) == value - assert im1.view()(x, y) == value - assert im1.view()(galsim.PositionI(x, y)) == value - assert im1.view(make_const=True)(x, y) == value - assert im2(x, y) == value - assert im2_view(x, y) == value - assert im2_cview(x, y) == value - assert im1.conjugate(x, y) == value - # JAX specific modification - # ------------------------- - # We have actually redefined im2_conj, so we can use this line - # if tchar[i][0] == 'C': - # # complex conjugate is not a view into the original. - # assert im2_conj(x,y) == 23 - # assert im2.conjugate(x,y) == value - # else: - # assert im2_conj(x,y) == value - assert im2_conj(x, y) == value - - value2 = 53 + 12 * x - 19 * y - if tchar[i] in ["US", "UI"]: - value2 = abs(value2) - im1[x, y] = value2 - im2_view[galsim.PositionI(x, y)] = value2 - # JAX specific modification - # ------------------------- - # Also updating the value of im2 and the cview - im2[galsim.PositionI(x, y)] = value2 - im2_cview[galsim.PositionI(x, y)] = value2 - assert im1.getValue(x, y) == value2 - assert im1.view().getValue(x=x, y=y) == value2 - assert im1.view(make_const=True).getValue(x, y) == value2 - assert im2.getValue(x=x, y=y) == value2 - assert im2_view.getValue(x, y) == value2 - assert im2_cview._getValue(x, y) == value2 - - assert im1.real(x, y) == value2 - assert im1.view().real(x, y) == value2 - assert im1.view(make_const=True).real(x, y) == value2.real - assert im2.real(x, y) == value2.real - assert im2_view.real(x, y) == value2.real - assert im2_cview.real(x, y) == value2.real - assert im1.imag(x, y) == 0 - assert im1.view().imag(x, y) == 0 - assert im1.view(make_const=True).imag(x, y) == 0 - assert im2.imag(x, y) == 0 - assert im2_view.imag(x, y) == 0 - assert im2_cview.imag(x, y) == 0 - - value3 = 10 * x + y - im1.addValue(x, y, value3 - value2) - im2_view[x, y] += value3 - value2 - # JAX specific modification - # ------------------------- - # Also updating the value of im2 and the cview - im2[galsim.PositionI(x, y)] += value3 - value2 - im2_cview[galsim.PositionI(x, y)] += value3 - value2 - assert im1[galsim.PositionI(x, y)] == value3 - assert im1.view()[x, y] == value3 - assert im1.view(make_const=True)[galsim.PositionI(x, y)] == value3 - assert im2[x, y] == value3 - assert im2_view[galsim.PositionI(x, y)] == value3 - assert im2_cview[x, y] == value3 - - # Setting or getting the value outside the bounds should throw an exception. - assert_raises(galsim.GalSimBoundsError, im1.setValue, 0, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.addValue, 0, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.__call__, 0, 0) - assert_raises(galsim.GalSimBoundsError, im1.__getitem__, 0, 0) - assert_raises(galsim.GalSimBoundsError, im1.__setitem__, 0, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.view().setValue, 0, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.view().__call__, 0, 0) - assert_raises(galsim.GalSimBoundsError, im1.view().__getitem__, 0, 0) - assert_raises(galsim.GalSimBoundsError, im1.view().__setitem__, 0, 0, 1) - - assert_raises(galsim.GalSimBoundsError, im1.setValue, ncol + 1, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.addValue, ncol + 1, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.__call__, ncol + 1, 0) - assert_raises(galsim.GalSimBoundsError, im1.view().setValue, ncol + 1, 0, 1) - assert_raises(galsim.GalSimBoundsError, im1.view().__call__, ncol + 1, 0) - - assert_raises(galsim.GalSimBoundsError, im1.setValue, 0, nrow + 1, 1) - assert_raises(galsim.GalSimBoundsError, im1.addValue, 0, nrow + 1, 1) - assert_raises(galsim.GalSimBoundsError, im1.__call__, 0, nrow + 1) - assert_raises(galsim.GalSimBoundsError, im1.view().setValue, 0, nrow + 1, 1) - assert_raises(galsim.GalSimBoundsError, im1.view().__call__, 0, nrow + 1) - - assert_raises(galsim.GalSimBoundsError, im1.setValue, ncol + 1, nrow + 1, 1) - assert_raises(galsim.GalSimBoundsError, im1.addValue, ncol + 1, nrow + 1, 1) - assert_raises(galsim.GalSimBoundsError, im1.__call__, ncol + 1, nrow + 1) - assert_raises( - galsim.GalSimBoundsError, im1.view().setValue, ncol + 1, nrow + 1, 1 - ) - assert_raises(galsim.GalSimBoundsError, im1.view().__call__, ncol + 1, nrow + 1) - - assert_raises( - galsim.GalSimBoundsError, im1.__getitem__, galsim.BoundsI(0, ncol, 1, nrow) - ) - assert_raises( - galsim.GalSimBoundsError, im1.__getitem__, galsim.BoundsI(1, ncol, 0, nrow) - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__getitem__, - galsim.BoundsI(1, ncol + 1, 1, nrow), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__getitem__, - galsim.BoundsI(1, ncol, 1, nrow + 1), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__getitem__, - galsim.BoundsI(0, ncol + 1, 0, nrow + 1), - ) - assert_raises( - galsim.GalSimBoundsError, im1.subImage, galsim.BoundsI(0, ncol, 1, nrow) - ) - assert_raises( - galsim.GalSimBoundsError, im1.subImage, galsim.BoundsI(1, ncol, 0, nrow) - ) - assert_raises( - galsim.GalSimBoundsError, im1.subImage, galsim.BoundsI(1, ncol + 1, 1, nrow) - ) - assert_raises( - galsim.GalSimBoundsError, im1.subImage, galsim.BoundsI(1, ncol, 1, nrow + 1) - ) - assert_raises( - galsim.GalSimBoundsError, - im1.subImage, - galsim.BoundsI(0, ncol + 1, 0, nrow + 1), - ) - - assert_raises( - galsim.GalSimBoundsError, - im1.setSubImage, - galsim.BoundsI(0, ncol, 1, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.setSubImage, - galsim.BoundsI(1, ncol, 0, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.setSubImage, - galsim.BoundsI(1, ncol + 1, 1, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.setSubImage, - galsim.BoundsI(1, ncol, 1, nrow + 1), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.setSubImage, - galsim.BoundsI(0, ncol + 1, 0, nrow + 1), - galsim.Image(ncol + 2, nrow + 2, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__setitem__, - galsim.BoundsI(0, ncol, 1, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__setitem__, - galsim.BoundsI(1, ncol, 0, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__setitem__, - galsim.BoundsI(1, ncol + 1, 1, nrow), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__setitem__, - galsim.BoundsI(1, ncol, 1, nrow + 1), - galsim.Image(ncol + 1, nrow, init_value=10), - ) - assert_raises( - galsim.GalSimBoundsError, - im1.__setitem__, - galsim.BoundsI(0, ncol + 1, 0, nrow + 1), - galsim.Image(ncol + 2, nrow + 2, init_value=10), - ) - - # JAX specific modification - # ------------------------- - # This doesn't raise an error because all Images are similarly immutable - # Also, setting values in something that should be const - # assert_raises(galsim.GalSimImmutableError,im1.view(make_const=True).setValue,1,1,1) - # assert_raises(galsim.GalSimImmutableError,im1.view(make_const=True).real.setValue,1,1,1) - # assert_raises(galsim.GalSimImmutableError,im1.view(make_const=True).imag.setValue,1,1,1) - # if tchar[i][0] != 'C': - # assert_raises(galsim.GalSimImmutableError,im1.imag.setValue,1,1,1) - - # Finally check for the wrong number of arguments in get/setitem - assert_raises(TypeError, im1.__getitem__, 1) - assert_raises(TypeError, im1.__setitem__, 1, 1) - assert_raises(TypeError, im1.__getitem__, 1, 2, 3) - assert_raises(TypeError, im1.__setitem__, 1, 2, 3, 4) - - # Check view of given data - im3_view = galsim.Image(ref_array.astype(np_array_type)) - slice_array = large_array.astype(np_array_type)[::3, ::2] - im4_view = galsim.Image(slice_array) - im5_view = galsim.Image( - ref_array.astype(np_array_type).tolist(), dtype=array_type - ) - im6_view = galsim.Image(ref_array.astype(np_array_type), xmin=4, ymin=7) - im7_view = galsim.Image(ref_array.astype(np_array_type), xmin=0, ymin=0) - im8_view = galsim.Image(ref_array).view(dtype=np_array_type) - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - value3 = 10 * x + y - assert im3_view(x, y) == value3 - assert im4_view(x, y) == value3 - assert im5_view(x, y) == value3 - assert im6_view(x + 3, y + 6) == value3 - assert im7_view(x - 1, y - 1) == value3 - assert im8_view(x, y) == value3 - - # Check shift ops - im1_view = im1.view() # View with old bounds - dx = 31 - dy = 16 - im1.shift(dx, dy) - im2_view.setOrigin(1 + dx, 1 + dy) - im3_view.setCenter((ncol + 1) / 2 + dx, (nrow + 1) / 2 + dy) - shifted_bounds = galsim.BoundsI(1 + dx, ncol + dx, 1 + dy, nrow + dy) - - assert im1.bounds == shifted_bounds - assert im2_view.bounds == shifted_bounds - assert im3_view.bounds == shifted_bounds - # Others shouldn't have changed - assert im1_view.bounds == bounds - assert im2.bounds == bounds - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - value3 = 10 * x + y - assert im1(x + dx, y + dy) == value3 - assert im1_view(x, y) == value3 - assert im2(x, y) == value3 - assert im2_view(x + dx, y + dy) == value3 - assert im3_view(x + dx, y + dy) == value3 - - assert_raises(TypeError, im1.shift, dx) - assert_raises(TypeError, im1.shift, dx=dx) - assert_raises(TypeError, im1.shift, x=dx, y=dy) - assert_raises(TypeError, im1.shift, dx, dy=dy) - assert_raises(TypeError, im1.shift, dx, dy, dy) - assert_raises(TypeError, im1.shift, dx, dy, invalid=True) - - # JAX specific modification - # ------------------------- - # We will not be doing pickles - # Check picklability - # check_pickle(im1) - # check_pickle(im1_view) - # check_pickle(im2) - # check_pickle(im2_view) - # check_pickle(im2_cview) - # check_pickle(im3_view) - # check_pickle(im4_view) - - # JAX specific modification - # ------------------------- - # We will not be doing pickles - # Also check picklability of Bounds, Position here. - # check_pickle(galsim.PositionI(2,3)) - # check_pickle(galsim.PositionD(2.2,3.3)) - # check_pickle(galsim.BoundsI(2,3,7,8)) - # check_pickle(galsim.BoundsD(2.1, 4.3, 6.5, 9.1)) - - -@timer -def test_undefined_image(): - """Test various ways to construct an image with undefined bounds""" - for i in range(len(types)): - im1 = galsim.Image(dtype=types[i]) - assert not im1.bounds.isDefined() - assert im1.array.shape == (1, 1) - assert im1 == im1 - - im2 = galsim.Image() - assert not im2.bounds.isDefined() - assert im2.array.shape == (1, 1) - assert im2 == im2 - if types[i] == np.float32: - assert im2 == im1 - - im3 = galsim.Image(array=np.array([[]], dtype=types[i])) - assert not im3.bounds.isDefined() - assert im3.array.shape == (1, 1) - assert im3 == im1 - - im4 = galsim.Image(array=np.array([[]]), dtype=types[i]) - assert not im4.bounds.isDefined() - assert im4.array.shape == (1, 1) - assert im4 == im1 - - im5 = galsim.Image( - array=np.array([[1]]), dtype=types[i], bounds=galsim.BoundsI() - ) - assert not im5.bounds.isDefined() - assert im5.array.shape == (1, 1) - assert im5 == im1 - - im6 = galsim.Image( - array=np.array([[1]], dtype=types[i]), bounds=galsim.BoundsI() - ) - assert not im6.bounds.isDefined() - assert im6.array.shape == (1, 1) - assert im6 == im1 - - im7 = 1.0 * im1 - assert not im7.bounds.isDefined() - assert im7.array.shape == (1, 1) - if types[i] == np.float64: - assert im7 == im1 - - im8 = im1 + 1j * im3 - assert not im8.bounds.isDefined() - assert im8.array.shape == (1, 1) - if types[i] == np.complex128: - assert im8 == im1 - - # JAX specific modification - # ------------------------- - # We don't handle empty images - # im9 = galsim.Image(0, 0) - # assert im9.array.shape == (1,1) - # assert im9 == im1 - # - # im10 = galsim.Image(10, 0) - # assert im10.array.shape == (1,1) - # assert im10 == im1 - # - # im11 = galsim.Image(0, 19) - # assert im11.array.shape == (1,1) - # assert im11 == im1 - - assert_raises(galsim.GalSimUndefinedBoundsError, im1.setValue, 0, 0, 1) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.__call__, 0, 0) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.view().setValue, 0, 0, 1) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.view().__call__, 0, 0) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.view().addValue, 0, 0, 1) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.fill, 3) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.view().fill, 3) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.invertSelf) - - assert_raises( - galsim.GalSimUndefinedBoundsError, - im1.__getitem__, - galsim.BoundsI(1, 2, 1, 2), - ) - assert_raises( - galsim.GalSimUndefinedBoundsError, im1.subImage, galsim.BoundsI(1, 2, 1, 2) - ) - - assert_raises( - galsim.GalSimUndefinedBoundsError, - im1.setSubImage, - galsim.BoundsI(1, 2, 1, 2), - galsim.Image(2, 2, init_value=10), - ) - assert_raises( - galsim.GalSimUndefinedBoundsError, - im1.__setitem__, - galsim.BoundsI(1, 2, 1, 2), - galsim.Image(2, 2, init_value=10), - ) - - im1.scale = 1.0 - assert_raises(galsim.GalSimUndefinedBoundsError, im1.calculate_fft) - assert_raises(galsim.GalSimUndefinedBoundsError, im1.calculate_inverse_fft) - - # JAX specific modification - # ------------------------- - # We will not be doing pickles - # check_pickle(im1.bounds) - # check_pickle(im1) - # check_pickle(im1.view()) - # check_pickle(im1.view(make_const=True)) - - -@timer -def test_Image_FITS_IO(): - """Test that all six FITS reference images are correctly read in by both PyFITS and our Image - wrappers. - """ - for i in range(ntypes): - array_type = types[i] - - if tchar[i][0] == "C": - # Cannot write complex Images to fits. Check for an exception and continue. - ref_image = galsim.Image(ref_array.astype(array_type)) - test_file = os.path.join(datadir, "test" + tchar[i] + ".fits") - with assert_raises(ValueError): - ref_image.write(test_file) - continue - - # - # Test input from a single external FITS image - # - - # Read the reference image to from an externally-generated fits file - test_file = os.path.join(datadir, "test" + tchar[i] + ".fits") - # Check pyfits read for sanity - with pyfits.open(test_file) as fits: - test_array = fits[0].data - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_array, - err_msg="PyFITS failing to read reference image.", - ) - - # Then use galsim fits.read function - # First version: use pyfits HDUList - with pyfits.open(test_file) as hdu: - test_image = galsim.fits.read(hdu_list=hdu) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed reading from filename input.", - ) - - # - # Test full I/O on a single internally-generated FITS image - # - - # Write the reference image to a fits file - ref_image = galsim.Image(ref_array.astype(array_type)) - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits") - ref_image.write(test_file) - - # Check pyfits read for sanity - with pyfits.open(test_file) as fits: - test_array = fits[0].data - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_array, - err_msg="Image" + tchar[i] + " write failed.", - ) - - # Then use galsim fits.read function - # First version: use pyfits HDUList - with pyfits.open(test_file) as hdu: - test_image = galsim.fits.read(hdu_list=hdu) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed reading from filename input.", - ) - - assert_raises(ValueError, galsim.fits.read, test_file, compression="invalid") - assert_raises(ValueError, ref_image.write, test_file, compression="invalid") - assert_raises(OSError, galsim.fits.read, test_file, compression="rice") - assert_raises(OSError, galsim.fits.read, "invalid.fits") - assert_raises(OSError, galsim.fits.read, "config_input/catalog.fits", hdu=1) - - assert_raises(TypeError, galsim.fits.read) - assert_raises(TypeError, galsim.fits.read, test_file, hdu_list=hdu) - assert_raises(TypeError, ref_image.write) - assert_raises(TypeError, ref_image.write, file_name=test_file, hdu_list=hdu) - - # If clobbert = False, then trying to overwrite will raise an OSError - assert_raises(OSError, ref_image.write, test_file, clobber=False) - - # - # Test various compression schemes - # - - # These tests are a bit slow, so we only bother to run them for the first dtype - # when doing the regular unit tests. When running python test_image.py, all of them - # will run, so when working on the code, it is a good idea to run the tests that way. - if i > 0 and __name__ != "__main__": - continue - - test_file0 = test_file # Save the name of the uncompressed file. - - # Test full-file gzip - test_file = os.path.join(datadir, "test" + tchar[i] + ".fits.gz") - test_image = galsim.fits.read(test_file, compression="gzip") - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for explicit full-file gzip", - ) - - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for auto full-file gzip", - ) - - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.gz") - ref_image.write(test_file, compression="gzip") - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for explicit full-file gzip", - ) - - ref_image.write(test_file) - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for auto full-file gzip", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image = galsim.fits.read(test_file, compression=None) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for auto full-file gzip", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="gzip") - - # Test full-file bzip2 - test_file = os.path.join(datadir, "test" + tchar[i] + ".fits.bz2") - test_image = galsim.fits.read(test_file, compression="bzip2") - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for explicit full-file bzip2", - ) - - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for auto full-file bzip2", - ) - - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.bz2") - ref_image.write(test_file, compression="bzip2") - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for explicit full-file bzip2", - ) - - ref_image.write(test_file) - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for auto full-file bzip2", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image = galsim.fits.read(test_file, compression=None) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for auto full-file gzip", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="bzip2") - - # Test rice - test_file = os.path.join(datadir, "test" + tchar[i] + ".fits.fz") - test_image = galsim.fits.read(test_file, compression="rice") - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for explicit rice", - ) - - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " read failed for auto rice", - ) - - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.fz") - ref_image.write(test_file, compression="rice") - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for explicit rice", - ) - - ref_image.write(test_file) - test_image = galsim.fits.read(test_file) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for auto rice", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="rice") - assert_raises(OSError, galsim.fits.read, test_file, compression="none") - - # Test gzip_tile - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.gzt") - ref_image.write(test_file, compression="gzip_tile") - test_image = galsim.fits.read(test_file, compression="gzip_tile") - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for gzip_tile", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="gzip_tile") - assert_raises(OSError, galsim.fits.read, test_file, compression="none") - - # Test hcompress - # Note: hcompress is a lossy algorithm, and starting with astropy 2.0.5, - # the fidelity of the reconstruction is really quite poor, so only test with - # rtol=0.1. I'm not sure if this is a bug in astropy, or just the nature - # of the hcompress algorithm. But I'm ignoring it for now, since I don't - # think too many people use hcompress anyway. - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.hc") - ref_image.write(test_file, compression="hcompress") - test_image = galsim.fits.read(test_file, compression="hcompress") - np.testing.assert_allclose( - ref_array.astype(types[i]), - test_image.array, - rtol=0.1, - err_msg="Image" + tchar[i] + " write failed for hcompress", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="hcompress") - assert_raises(OSError, galsim.fits.read, test_file, compression="none") - - # Test plio (only valid on positive integer values) - if tchar[i] in ["S", "I"]: - test_file = os.path.join(datadir, "test" + tchar[i] + "_internal.fits.plio") - ref_image.write(test_file, compression="plio") - test_image = galsim.fits.read(test_file, compression="plio") - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_image.array, - err_msg="Image" + tchar[i] + " write failed for plio", - ) - - assert_raises(OSError, galsim.fits.read, test_file0, compression="plio") - assert_raises(OSError, galsim.fits.read, test_file, compression="none") - - # Check a file with no WCS information - nowcs_file = os.path.join( - os.path.dirname(__file__), "../../GalSim/tests", "fits_files", "blankimg.fits" - ) - im = galsim.fits.read(nowcs_file) - assert im.wcs == galsim.PixelScale(1.0) - - # If desired, can get a warning about this - with assert_warns(galsim.GalSimWarning): - im = galsim.fits.read(nowcs_file, suppress_warning=False) - assert im.wcs == galsim.PixelScale(1.0) - - -@timer -def test_Image_MultiFITS_IO(): - """Test that all six FITS reference images are correctly read in by both PyFITS and our Image - wrappers. - """ - for i in range(ntypes): - array_type = types[i] - - if tchar[i][0] == "C": - # Cannot write complex Images to fits. Check for an exception and continue. - ref_image = galsim.Image(ref_array.astype(array_type)) - image_list = [] - for k in range(nimages): - image_list.append(ref_image + k) - test_multi_file = os.path.join(datadir, "test_multi" + tchar[i] + ".fits") - with assert_raises(ValueError): - galsim.fits.writeMulti(image_list, test_multi_file) - continue - - # - # Test input from an external multi-extension fits file - # - - test_multi_file = os.path.join(datadir, "test_multi" + tchar[i] + ".fits") - # Check pyfits read for sanity - with pyfits.open(test_multi_file) as fits: - test_array = fits[0].data - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_array, - err_msg="PyFITS failing to read multi file.", - ) - - # Then use galsim fits.readMulti function - # First version: use pyfits HDUList - with pyfits.open(test_multi_file) as hdu: - test_image_list = galsim.fits.readMulti(hdu_list=hdu) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed reading from filename input.", - ) - - # - # Test full I/O for an internally-generated multi-extension fits file - # - - # Build a list of images with different values - ref_image = galsim.Image(ref_array.astype(array_type)) - image_list = [] - for k in range(nimages): - image_list.append(ref_image + k) - - # Write the list to a multi-extension fits file - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits" - ) - galsim.fits.writeMulti(image_list, test_multi_file) - - # Check pyfits read for sanity - with pyfits.open(test_multi_file) as fits: - test_array = fits[0].data - np.testing.assert_array_equal( - ref_array.astype(types[i]), - test_array, - err_msg="PyFITS failing to read multi file.", - ) - - # Then use galsim fits.readMulti function - # First version: use pyfits HDUList - with pyfits.open(test_multi_file) as hdu: - test_image_list = galsim.fits.readMulti(hdu_list=hdu) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed reading from filename input.", - ) - - # - # Test writing to hdu_list directly and then writing to file. - # - - # Start with empty hdu_list - hdu_list = pyfits.HDUList() - - # Add each image one at a time - for k in range(nimages): - image = ref_image + k - galsim.fits.write(image, hdu_list=hdu_list) - - # Write it out with writeFile - galsim.fits.writeFile(test_multi_file, hdu_list) - - # Check that reading it back in gives the same values - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readMulti failed after using writeFile", - ) - - # Can also use writeMulti to write directly to an hdu list - hdu_list = pyfits.HDUList() - galsim.fits.writeMulti(image_list, hdu_list=hdu_list) - galsim.fits.writeFile(test_multi_file, hdu_list) - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readMulti failed after using writeFile", - ) - - assert_raises( - ValueError, galsim.fits.readMulti, test_multi_file, compression="invalid" - ) - assert_raises( - ValueError, - galsim.fits.writeMulti, - image_list, - test_multi_file, - compression="invalid", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - image_list, - test_multi_file, - compression="invalid", - ) - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file, compression="rice" - ) - assert_raises( - OSError, galsim.fits.readFile, test_multi_file, compression="rice" - ) - assert_raises(OSError, galsim.fits.readMulti, hdu_list=pyfits.HDUList()) - assert_raises( - OSError, - galsim.fits.readMulti, - hdu_list=pyfits.HDUList(), - compression="rice", - ) - assert_raises(OSError, galsim.fits.readMulti, "invalid.fits") - assert_raises(OSError, galsim.fits.readFile, "invalid.fits") - - assert_raises(TypeError, galsim.fits.readMulti) - assert_raises(TypeError, galsim.fits.readMulti, test_multi_file, hdu_list=hdu) - assert_raises(TypeError, galsim.fits.readMulti, hdu_list=test_multi_file) - assert_raises(TypeError, galsim.fits.writeMulti) - assert_raises(TypeError, galsim.fits.writeMulti, image_list) - assert_raises( - TypeError, - galsim.fits.writeMulti, - image_list, - file_name=test_multi_file, - hdu_list=hdu, - ) - - assert_raises( - OSError, galsim.fits.writeMulti, image_list, test_multi_file, clobber=False - ) - - assert_raises(TypeError, galsim.fits.writeFile) - assert_raises(TypeError, galsim.fits.writeFile, image_list) - assert_raises( - ValueError, - galsim.fits.writeFile, - test_multi_file, - image_list, - compression="invalid", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - test_multi_file, - image_list, - compression="rice", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - test_multi_file, - image_list, - compression="gzip_tile", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - test_multi_file, - image_list, - compression="hcompress", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - test_multi_file, - image_list, - compression="plio", - ) - - galsim.fits.writeFile(test_multi_file, hdu_list) - assert_raises( - OSError, galsim.fits.writeFile, test_multi_file, image_list, clobber=False - ) - - # - # Test various compression schemes - # - - # These tests are a bit slow, so we only bother to run them for the first dtype - # when doing the regular unit tests. When running python test_image.py, all of them - # will run, so when working on the code, it is a good idea to run the tests that way. - if i > 0 and __name__ != "__main__": - continue - - test_multi_file0 = test_multi_file - - # Test full-file gzip - test_multi_file = os.path.join(datadir, "test_multi" + tchar[i] + ".fits.gz") - test_image_list = galsim.fits.readMulti(test_multi_file, compression="gzip") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed for explicit full-file gzip", - ) - - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed for auto full-file gzip", - ) - - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.gz" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="gzip") - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for explicit full-file gzip", - ) - - galsim.fits.writeMulti(image_list, test_multi_file) - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for auto full-file gzip", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image_list = galsim.fits.readMulti(test_multi_file, compression=None) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for auto full-file gzip", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="gzip" - ) - - # Test full-file bzip2 - test_multi_file = os.path.join(datadir, "test_multi" + tchar[i] + ".fits.bz2") - test_image_list = galsim.fits.readMulti(test_multi_file, compression="bzip2") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed for explicit full-file bzip2", - ) - - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readMulti failed for auto full-file bzip2", - ) - - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.bz2" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="bzip2") - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for explicit full-file bzip2", - ) - - galsim.fits.writeMulti(image_list, test_multi_file) - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for auto full-file bzip2", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image_list = galsim.fits.readMulti(test_multi_file, compression=None) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeMulti failed for auto full-file gzip", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="bzip2" - ) - - # Test rice - test_multi_file = os.path.join(datadir, "test_multi" + tchar[i] + ".fits.fz") - test_image_list = galsim.fits.readMulti(test_multi_file, compression="rice") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readMulti failed for explicit rice", - ) - - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readMulti failed for auto rice", - ) - - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.fz" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="rice") - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeMulti failed for explicit rice", - ) - - galsim.fits.writeMulti(image_list, test_multi_file) - test_image_list = galsim.fits.readMulti(test_multi_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeMulti failed for auto rice", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="rice" - ) - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file, compression="none" - ) - - # Test gzip_tile - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.gzt" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="gzip_tile") - test_image_list = galsim.fits.readMulti( - test_multi_file, compression="gzip_tile" - ) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeMulti failed for gzip_tile", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="gzip_tile" - ) - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file, compression="none" - ) - - # Test hcompress - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.hc" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="hcompress") - test_image_list = galsim.fits.readMulti( - test_multi_file, compression="hcompress" - ) - for k in range(nimages): - np.testing.assert_allclose( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - rtol=0.1, - err_msg="Image" + tchar[i] + " writeMulti failed for hcompress", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="hcompress" - ) - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file, compression="none" - ) - - # Test plio (only valid on positive integer values) - if tchar[i] in ["S", "I"]: - test_multi_file = os.path.join( - datadir, "test_multi" + tchar[i] + "_internal.fits.plio" - ) - galsim.fits.writeMulti(image_list, test_multi_file, compression="plio") - test_image_list = galsim.fits.readMulti(test_multi_file, compression="plio") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeMulti failed for plio", - ) - - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file0, compression="plio" - ) - assert_raises( - OSError, galsim.fits.readMulti, test_multi_file, compression="none" - ) - - # Check a file with no WCS information - nowcs_file = os.path.join( - os.path.dirname(__file__), "../../GalSim/tests", "fits_files", "blankimg.fits" - ) - ims = galsim.fits.readMulti(nowcs_file) - assert ims[0].wcs == galsim.PixelScale(1.0) - - # If desired, can get a warning about this - with assert_warns(galsim.GalSimWarning): - ims = galsim.fits.readMulti(nowcs_file, suppress_warning=False) - assert ims[0].wcs == galsim.PixelScale(1.0) - - -@timer -def test_Image_CubeFITS_IO(): - """Test that all six FITS reference images are correctly read in by both PyFITS and our Image - wrappers. - """ - for i in range(ntypes): - array_type = types[i] - - if tchar[i][0] == "C": - # Cannot write complex Images to fits. Check for an exception and continue. - ref_image = galsim.Image(ref_array.astype(array_type)) - image_list = [] - for k in range(nimages): - image_list.append(ref_image + k) - test_cube_file = os.path.join(datadir, "test_cube" + tchar[i] + ".fits") - with assert_raises(ValueError): - galsim.fits.writeCube(image_list, test_cube_file) - array_list = [im.array for im in image_list] - with assert_raises(ValueError): - galsim.fits.writeCube(array_list, test_cube_file) - one_array = np.asarray(array_list) - with assert_raises(ValueError): - galsim.fits.writeCube(one_array, test_cube_file) - continue - - # - # Test input from an external fits data cube - # - - test_cube_file = os.path.join(datadir, "test_cube" + tchar[i] + ".fits") - # Check pyfits read for sanity - with pyfits.open(test_cube_file) as fits: - test_array = fits[0].data - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_array[k, :, :], - err_msg="PyFITS failing to read cube file.", - ) - - # Then use galsim fits.readCube function - # First version: use pyfits HDUList - with pyfits.open(test_cube_file) as hdu: - test_image_list = galsim.fits.readCube(hdu_list=hdu) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readCube failed reading from filename input.", - ) - - # - # Test full I/O for an internally-generated fits data cube - # - - # Build a list of images with different values - ref_image = galsim.Image(ref_array.astype(array_type)) - image_list = [] - for k in range(nimages): - image_list.append(ref_image + k) - - # Write the list to a fits data cube - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits" - ) - galsim.fits.writeCube(image_list, test_cube_file) - - # Check pyfits read for sanity - with pyfits.open(test_cube_file) as fits: - test_array = fits[0].data - - wrong_type_error_msg = "%s != %s" % (test_array.dtype.type, types[i]) - assert test_array.dtype.type == types[i], wrong_type_error_msg - - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_array[k, :, :], - err_msg="PyFITS failing to read cube file.", - ) - - # Then use galsim fits.readCube function - # First version: use pyfits HDUList - with pyfits.open(test_cube_file) as hdu: - test_image_list = galsim.fits.readCube(hdu_list=hdu) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Failed reading from PyFITS PrimaryHDU input.", - ) - - # Second version: use file name - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readCube failed reading from filename input.", - ) - - # - # Test writeCube with arrays, rather than images. - # - - array_list = [im.array for im in image_list] - galsim.fits.writeCube(array_list, test_cube_file) - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " write/readCube failed with list of numpy arrays.", - ) - - one_array = np.asarray(array_list) - galsim.fits.writeCube(one_array, test_cube_file) - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " write/readCube failed with single 3D numpy array.", - ) - - # - # Test writing to hdu_list directly and then writing to file. - # - - # Start with empty hdu_list - hdu_list = pyfits.HDUList() - - # Write the images to the hdu_list - galsim.fits.writeCube(image_list, hdu_list=hdu_list) - - # Write it out with writeFile - galsim.fits.writeFile(test_cube_file, hdu_list) - - # Check that reading it back in gives the same values - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readCube failed after using writeFile", - ) - - assert_raises( - ValueError, galsim.fits.readCube, test_cube_file, compression="invalid" - ) - assert_raises( - ValueError, - galsim.fits.writeCube, - image_list, - test_cube_file, - compression="invalid", - ) - assert_raises( - ValueError, - galsim.fits.writeFile, - image_list, - test_cube_file, - compression="invalid", - ) - assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression="rice") - assert_raises(OSError, galsim.fits.readCube, "invalid.fits") - - assert_raises(TypeError, galsim.fits.readCube) - assert_raises(TypeError, galsim.fits.readCube, test_cube_file, hdu_list=hdu) - assert_raises(TypeError, galsim.fits.readCube, hdu_list=test_cube_file) - assert_raises(TypeError, galsim.fits.writeCube) - assert_raises(TypeError, galsim.fits.writeCube, image_list) - assert_raises( - TypeError, - galsim.fits.writeCube, - image_list, - file_name=test_cube_file, - hdu_list=hdu_list, - ) - - assert_raises( - OSError, galsim.fits.writeCube, image_list, test_cube_file, clobber=False - ) - - assert_raises(ValueError, galsim.fits.writeCube, image_list[:0], test_cube_file) - assert_raises( - ValueError, - galsim.fits.writeCube, - [image_list[0], image_list[1].subImage(galsim.BoundsI(1, 4, 1, 4))], - test_cube_file, - ) - - # - # Test various compression schemes - # - - # These tests are a bit slow, so we only bother to run them for the first dtype - # when doing the regular unit tests. When running python test_image.py, all of them - # will run, so when working on the code, it is a good idea to run the tests that way. - if i > 0 and __name__ != "__main__": - continue - - test_cube_file0 = test_cube_file - - # Test full-file gzip - test_cube_file = os.path.join(datadir, "test_cube" + tchar[i] + ".fits.gz") - test_image_list = galsim.fits.readCube(test_cube_file, compression="gzip") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readCube failed for explicit full-file gzip", - ) - - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readCube failed for auto full-file gzip", - ) - - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.gz" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="gzip") - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for explicit full-file gzip", - ) - - galsim.fits.writeCube(image_list, test_cube_file) - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for auto full-file gzip", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image_list = galsim.fits.readCube(test_cube_file, compression=None) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for auto full-file gzip", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="gzip" - ) - - # Test full-file bzip2 - test_cube_file = os.path.join(datadir, "test_cube" + tchar[i] + ".fits.bz2") - test_image_list = galsim.fits.readCube(test_cube_file, compression="bzip2") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readCube failed for explicit full-file bzip2", - ) - - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " readCube failed for auto full-file bzip2", - ) - - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.bz2" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="bzip2") - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for explicit full-file bzip2", - ) - - galsim.fits.writeCube(image_list, test_cube_file) - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for auto full-file bzip2", - ) - - # With compression = None or 'none', astropy automatically figures it out anyway. - test_image_list = galsim.fits.readCube(test_cube_file, compression=None) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" - + tchar[i] - + " writeCube failed for auto full-file gzip", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="bzip2" - ) - - # Test rice - test_cube_file = os.path.join(datadir, "test_cube" + tchar[i] + ".fits.fz") - test_image_list = galsim.fits.readCube(test_cube_file, compression="rice") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readCube failed for explicit rice", - ) - - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " readCube failed for auto rice", - ) - - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.fz" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="rice") - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeCube failed for explicit rice", - ) - - galsim.fits.writeCube(image_list, test_cube_file) - test_image_list = galsim.fits.readCube(test_cube_file) - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeCube failed for auto rice", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="rice" - ) - assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression="none") - - # Test gzip_tile - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.gzt" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="gzip_tile") - test_image_list = galsim.fits.readCube(test_cube_file, compression="gzip_tile") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeCube failed for gzip_tile", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="gzip_tile" - ) - assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression="none") - - # Test hcompress - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.hc" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="hcompress") - test_image_list = galsim.fits.readCube(test_cube_file, compression="hcompress") - for k in range(nimages): - np.testing.assert_allclose( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - rtol=0.1, - err_msg="Image" + tchar[i] + " writeCube failed for hcompress", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="hcompress" - ) - assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression="none") - - # Test plio (only valid on positive integer values) - if tchar[i] in ["S", "I"]: - test_cube_file = os.path.join( - datadir, "test_cube" + tchar[i] + "_internal.fits.plio" - ) - galsim.fits.writeCube(image_list, test_cube_file, compression="plio") - test_image_list = galsim.fits.readCube(test_cube_file, compression="plio") - for k in range(nimages): - np.testing.assert_array_equal( - (ref_array + k).astype(types[i]), - test_image_list[k].array, - err_msg="Image" + tchar[i] + " writeCube failed for plio", - ) - - assert_raises( - OSError, galsim.fits.readCube, test_cube_file0, compression="plio" - ) - assert_raises(OSError, galsim.fits.readCube, test_cube_file, compression="none") - - # Check a file with no WCS information - nowcs_file = os.path.join( - os.path.dirname(__file__), "../../GalSim/tests", "fits_files", "blankimg.fits" - ) - ims = galsim.fits.readCube(nowcs_file) - assert ims[0].wcs == galsim.PixelScale(1.0) - - # If desired, can get a warning about this - with assert_warns(galsim.GalSimWarning): - ims = galsim.fits.readCube(nowcs_file, suppress_warning=False) - assert ims[0].wcs == galsim.PixelScale(1.0) - - -@timer -def test_Image_array_view(): - """Test that all six types of supported Images correctly provide a view on an input array.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image = galsim.Image(ref_array.astype(types[i])) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image.array, - err_msg="Array look into Image class does not match input for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - image = image_init_func(ref_array.astype(types[i])) - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image.array, - err_msg="Array look into Image class does not match input for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_binary_add(): - """Test that all six types of supported Images add correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image3 = image1 + image2 - np.testing.assert_array_equal( - (3 * ref_array).astype(types[i]), - image3.array, - err_msg="Binary add in Image class does not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image3 = image1 + image2 - np.testing.assert_array_equal( - (3 * ref_array).astype(types[i]), - image3.array, - err_msg="Binary add in Image class does not match reference for dtype = " - + str(types[i]), - ) - - for j in range(ntypes): - image2_init_func = eval("galsim.Image" + tchar[j]) - image2 = image2_init_func((2 * ref_array).astype(types[j])) - image3 = image1 + image2 - type3 = image3.array.dtype.type - np.testing.assert_array_equal( - (3 * ref_array).astype(type3), - image3.array, - err_msg="Inplace add in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - # Check for exceptions if we try to do this operation for images without matching - # shape. Note that this test is only included here (not in the unit tests for all - # other operations) because all operations have the same error-checking code, so it should - # only be necessary to check once. - with assert_raises(ValueError): - image1 + image1.subImage(galsim.BoundsI(1, 3, 1, 3)) - - -@timer -def test_Image_binary_subtract(): - """Test that all six types of supported Images subtract correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image3 = image2 - image1 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image3.array, - err_msg="Binary subtract in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image3 = image2 - image1 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image3.array, - err_msg="Binary subtract in Image class does not match reference for dtype = " - + str(types[i]), - ) - - for j in range(ntypes): - image2_init_func = eval("galsim.Image" + tchar[j]) - image2 = image2_init_func((2 * ref_array).astype(types[j])) - image3 = image2 - image1 - type3 = image3.array.dtype.type - np.testing.assert_array_equal( - ref_array.astype(type3), - image3.array, - err_msg="Inplace add in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 - image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_binary_multiply(): - """Test that all six types of supported Images multiply correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image3 = image1 * image2 - np.testing.assert_array_equal( - (2 * ref_array**2).astype(types[i]), - image3.array, - err_msg="Binary multiply in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image3 = image1 * image2 - np.testing.assert_array_equal( - (2 * ref_array**2).astype(types[i]), - image3.array, - err_msg="Binary multiply in Image class does not match reference for dtype = " - + str(types[i]), - ) - - # Check unary - - image1 = galsim.Image(ref_array.astype(types[i])) - image3 = -image1 - np.testing.assert_array_equal( - image3.array, - (-1 * ref_array).astype(types[i]), - err_msg="Unary - in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - for j in range(ntypes): - image2_init_func = eval("galsim.Image" + tchar[j]) - image2 = image2_init_func((2 * ref_array).astype(types[j])) - image3 = image2 * image1 - type3 = image3.array.dtype.type - np.testing.assert_array_equal( - (2 * ref_array**2).astype(type3), - image3.array, - err_msg="Inplace add in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 * image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_binary_divide(): - """Test that all six types of supported Images divide correctly.""" - # Note: tests here are not precisely equal, since division can have rounding errors for - # some elements. In particular when dividing by complex, where there is a bit more to the - # generic calculation (even though the imaginary parts are zero here). - # So check that they are *almost* equal to 12 digits of precision (or 4 for complex64). - for i in range(ntypes): - # JAX specific modification - # ------------------------- - # Decimals adjusted for float32 because computation on gpu is different than cpu - decimal = 4 if (types[i] == np.complex64 or types[i] == np.float32) else 12 - # First try using the dictionary-type Image init - # Note that I am using refarray + 1 to avoid divide-by-zero. - image1 = galsim.Image((ref_array + 1).astype(types[i])) - image2 = galsim.Image((3 * (ref_array + 1) ** 2).astype(types[i])) - image3 = image2 / image1 - np.testing.assert_almost_equal( - (3 * (ref_array + 1)).astype(types[i]), - image3.array, - decimal=decimal, - err_msg="Binary divide in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = (large_array + 1).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((3 * (ref_array + 1) ** 2).astype(types[i])) - image3 = image2 / image1 - np.testing.assert_almost_equal( - (3 * (ref_array + 1)).astype(types[i]), - image3.array, - decimal=decimal, - err_msg="Binary divide in Image class does not match reference for dtype = " - + str(types[i]), - ) - - for j in range(ntypes): - # JAX specific modification - # ------------------------- - # Decimals adjusted for float32 because computation on gpu is different than cpu - decimal = ( - 4 - if ( - types[i] == np.complex64 - or types[i] == np.float32 - or types[j] == np.complex64 - or types[j] == np.float32 - ) - else 12 - ) - image2_init_func = eval("galsim.Image" + tchar[j]) - image2 = image2_init_func((3 * (ref_array + 1) ** 2).astype(types[j])) - image3 = image2 / image1 - type3 = image3.array.dtype.type - np.testing.assert_almost_equal( - (3 * (ref_array + 1)).astype(type3), - image3.array, - decimal=decimal, - err_msg="Inplace divide in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 / image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_binary_scalar_add(): - """Test that all six types of supported Images add scalars correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = image1 + 3 - np.testing.assert_array_equal( - (ref_array + 3).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 + image1 - np.testing.assert_array_equal( - (ref_array + 3).astype(types[i]), - image2.array, - err_msg="Binary radd scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image1 + 3 - np.testing.assert_array_equal( - (ref_array + 3).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class does not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 + image1 - np.testing.assert_array_equal( - (ref_array + 3).astype(types[i]), - image2.array, - err_msg="Binary radd scalar in Image class does not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_binary_scalar_subtract(): - """Test that all six types of supported Images binary scalar subtract correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = image1 - 3 - np.testing.assert_array_equal( - (ref_array - 3).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 - image1 - np.testing.assert_array_equal( - (3 - ref_array).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image1 - 3 - np.testing.assert_array_equal( - (ref_array - 3).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class does not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 - image1 - np.testing.assert_array_equal( - (3 - ref_array).astype(types[i]), - image2.array, - err_msg="Binary add scalar in Image class does not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_binary_scalar_multiply(): - """Test that all six types of supported Images binary scalar multiply correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = image1 * 3 - np.testing.assert_array_equal( - (ref_array * 3).astype(types[i]), - image2.array, - err_msg="Binary multiply scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 * image1 - np.testing.assert_array_equal( - (ref_array * 3).astype(types[i]), - image2.array, - err_msg="Binary rmultiply scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image1 * 3 - np.testing.assert_array_equal( - (ref_array * 3).astype(types[i]), - image2.array, - err_msg="Binary multiply scalar in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - image2 = 3 * image1 - np.testing.assert_array_equal( - (ref_array * 3).astype(types[i]), - image2.array, - err_msg="Binary rmultiply scalar in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_binary_scalar_divide(): - """Test that all six types of supported Images binary scalar divide correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image((3 * ref_array).astype(types[i])) - image2 = image1 / 3 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image2.array, - err_msg="Binary divide scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = (3 * large_array).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image1 / 3 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image2.array, - err_msg="Binary divide scalar in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_binary_scalar_pow(): - """Test that all six types of supported Images can be raised to a power (scalar) correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((ref_array**2).astype(types[i])) - image3 = image1**2 - # Note: unlike for the tests above with multiplication, the test fails if I use - # assert_array_equal. I have to use assert_array_almost_equal to avoid failure due to - # small numerical issues. - np.testing.assert_array_almost_equal( - image3.array, - image2.array, - decimal=4, - err_msg="Binary pow scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func(ref_array.astype(types[i])) - image2 **= 2 - image3 = image1**2 - np.testing.assert_array_equal( - image3.array, - image2.array, - err_msg="Binary pow scalar in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # float types can also be taken to a float power - if types[i] in [np.float32, np.float64]: - image2 = image_init_func((ref_array ** (1 / 1.3)).astype(types[i])) - image3 = image2**1.3 - # Note: unlike for the tests above with multiplication/division, the test fails if I use - # assert_array_equal. I have to use assert_array_almost_equal to avoid failure due to - # small numerical issues. - np.testing.assert_array_almost_equal( - ref_array.astype(types[i]), - image3.array, - decimal=4, - err_msg="Binary pow scalar in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - with assert_raises(TypeError): - image1**image2 - - -@timer -def test_Image_inplace_add(): - """Test that all six types of supported Images inplace add correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image1 += image2 - np.testing.assert_array_equal( - (3 * ref_array).astype(types[i]), - image1.array, - err_msg="Inplace add in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image1 += image2 - np.testing.assert_array_equal( - (3 * ref_array).astype(types[i]), - image1.array, - err_msg="Inplace add in Image class does not match reference for dtype = " - + str(types[i]), - ) - - for j in range(i): # Only add simpler types to this one. - image2_init_func = eval("galsim.Image" + tchar[j]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image2_init_func((2 * ref_array).astype(types[j])) - image1 += image2 - np.testing.assert_array_equal( - (3 * ref_array).astype(types[i]), - image1.array, - err_msg="Inplace add in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 += image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_inplace_subtract(): - """Test that all six types of supported Images inplace subtract correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image((2 * ref_array).astype(types[i])) - image2 = galsim.Image(ref_array.astype(types[i])) - image1 -= image2 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image1.array, - err_msg="Inplace subtract in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = (2 * large_array).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func(ref_array.astype(types[i])) - image1 -= image2 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image1.array, - err_msg="Inplace subtract in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - for j in range(i): # Only subtract simpler types from this one. - image2_init_func = eval("galsim.Image" + tchar[j]) - slice_array = (2 * large_array).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image2_init_func(ref_array.astype(types[j])) - image1 -= image2 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image1.array, - err_msg="Inplace subtract in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 -= image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_inplace_multiply(): - """Test that all six types of supported Images inplace multiply correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image1 *= image2 - np.testing.assert_array_equal( - (2 * ref_array**2).astype(types[i]), - image1.array, - err_msg="Inplace multiply in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image1 *= image2 - np.testing.assert_array_equal( - (2 * ref_array**2).astype(types[i]), - image1.array, - err_msg="Inplace multiply in Image class does not match reference for dtype = " - + str(types[i]), - ) - - for j in range(i): # Only multiply simpler types to this one. - image2_init_func = eval("galsim.Image" + tchar[j]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image2_init_func((2 * ref_array).astype(types[j])) - image1 *= image2 - np.testing.assert_array_equal( - (2 * ref_array**2).astype(types[i]), - image1.array, - err_msg="Inplace multiply in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 *= image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_inplace_divide(): - """Test that all six types of supported Images inplace divide correctly.""" - for i in range(ntypes): - # JAX specific modification - # ------------------------- - # Decimals adjusted for float32 because computation on gpu is different than cpu - decimal = 5 if (types[i] == np.complex64 or types[i] == np.float32) else 12 - # First try using the dictionary-type Image init - image1 = galsim.Image((2 * (ref_array + 1) ** 2).astype(types[i])) - image2 = galsim.Image((ref_array + 1).astype(types[i])) - image1 /= image2 - np.testing.assert_almost_equal( - (2 * (ref_array + 1)).astype(types[i]), - image1.array, - decimal=decimal, - err_msg="Inplace divide in Image class (dictionary call) does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = (2 * (large_array + 1) ** 2).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((ref_array + 1).astype(types[i])) - image1 /= image2 - np.testing.assert_almost_equal( - (2 * (ref_array + 1)).astype(types[i]), - image1.array, - decimal=decimal, - err_msg="Inplace divide in Image class does not match reference for dtype = " - + str(types[i]), - ) - - # Test image.invertSelf() - # Intentionally make some elements zero, so we test that 1/0 -> 0. - image1 = galsim.Image((ref_array // 11 - 3).astype(types[i])) - image2 = image1.copy() - mask1 = image1.array == 0 - mask2 = image1.array != 0 - image2.invertSelf() - np.testing.assert_array_equal( - image2.array[mask1], 0, err_msg="invertSelf did not do 1/0 -> 0." - ) - np.testing.assert_array_equal( - image2.array[mask2], - (1.0 / image1.array[mask2]).astype(types[i]), - err_msg="invertSelf gave wrong answer for non-zero elements", - ) - - for j in range(i): # Only divide simpler types into this one. - # JAX specific modification - # ------------------------- - # Decimals adjusted for float32 because computation on gpu is different than cpu - decimal = ( - 5 - if ( - types[i] == np.complex64 - or types[j] == np.complex64 - or types[i] == np.float32 - or types[j] == np.float32 - ) - else 12 - ) - image2_init_func = eval("galsim.Image" + tchar[j]) - slice_array = (2 * (large_array + 1) ** 2).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image2_init_func((ref_array + 1).astype(types[j])) - image1 /= image2 - np.testing.assert_almost_equal( - (2 * (ref_array + 1)).astype(types[i]), - image1.array, - decimal=decimal, - err_msg="Inplace divide in Image class does not match reference for dtypes = " - + str(types[i]) - + " and " - + str(types[j]), - ) - - with assert_raises(ValueError): - image1 /= image1.subImage(galsim.BoundsI(0, 4, 0, 4)) - - -@timer -def test_Image_inplace_scalar_add(): - """Test that all six types of supported Images inplace scalar add correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image1 += 1 - np.testing.assert_array_equal( - (ref_array + 1).astype(types[i]), - image1.array, - err_msg="Inplace scalar add in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image1 += 1 - np.testing.assert_array_equal( - (ref_array + 1).astype(types[i]), - image1.array, - err_msg="Inplace scalar add in Image class does not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_inplace_scalar_subtract(): - """Test that all six types of supported Images inplace scalar subtract correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image1 -= 1 - np.testing.assert_array_equal( - (ref_array - 1).astype(types[i]), - image1.array, - err_msg="Inplace scalar subtract in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image1 -= 1 - np.testing.assert_array_equal( - (ref_array - 1).astype(types[i]), - image1.array, - err_msg="Inplace scalar subtract in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_inplace_scalar_multiply(): - """Test that all six types of supported Images inplace scalar multiply correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image1 *= 2 - np.testing.assert_array_equal( - image1.array, - image2.array, - err_msg="Inplace scalar multiply in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((2 * ref_array).astype(types[i])) - image1 *= 2 - np.testing.assert_array_equal( - image1.array, - image2.array, - err_msg="Inplace scalar multiply in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_inplace_scalar_divide(): - """Test that all six types of supported Images inplace scalar divide correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((2 * ref_array).astype(types[i])) - image2 /= 2 - np.testing.assert_array_equal( - image1.array, - image2.array, - err_msg="Inplace scalar divide in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = (2 * large_array).astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image1 /= 2 - np.testing.assert_array_equal( - ref_array.astype(types[i]), - image1.array, - err_msg="Inplace scalar divide in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - -@timer -def test_Image_inplace_scalar_pow(): - """Test that all six types of supported Images can be raised (in-place) to a scalar correctly.""" - for i in range(ntypes): - # First try using the dictionary-type Image init - image1 = galsim.Image((ref_array**2).astype(types[i])) - image2 = galsim.Image(ref_array.astype(types[i])) - image2 **= 2 - np.testing.assert_array_almost_equal( - image1.array, - image2.array, - decimal=4, - err_msg="Inplace scalar pow in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - - # Then try using the eval command to mimic use via ImageD, ImageF etc. - image_init_func = eval("galsim.Image" + tchar[i]) - slice_array = large_array.copy().astype(types[i])[::3, ::2] - image1 = image_init_func(slice_array) - image2 = image_init_func((ref_array.astype(types[i])) ** 2) - image1 **= 2 - np.testing.assert_array_equal( - image1.array, - image2.array, - err_msg="Inplace scalar pow in Image class does" - + " not match reference for dtype = " - + str(types[i]), - ) - - # float types can also be taken to a float power - if types[i] in [np.float32, np.float64]: - # First try using the dictionary-type Image init - image1 = galsim.Image(ref_array.astype(types[i])) - image2 = galsim.Image((ref_array ** (1.0 / 1.3)).astype(types[i])) - image2 **= 1.3 - np.testing.assert_array_almost_equal( - image1.array, - image2.array, - decimal=4, - err_msg="Inplace scalar pow in Image class (dictionary " - + "call) does not match reference for dtype = " - + str(types[i]), - ) - - with assert_raises(TypeError): - image1 **= image2 - - -@timer -def test_Image_subImage(): - """Test that subImages are accessed and written correctly.""" - for i in range(ntypes): - image = galsim.Image(ref_array.astype(types[i])) - bounds = galsim.BoundsI(3, 4, 2, 3) - sub_array = np.array([[32, 42], [33, 43]]).astype(types[i]) - np.testing.assert_array_equal( - image.subImage(bounds).array, - sub_array, - err_msg="image.subImage(bounds) does not match reference for dtype = " - + str(types[i]), - ) - np.testing.assert_array_equal( - image[bounds].array, - sub_array, - err_msg="image[bounds] does not match reference for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array + 100) - np.testing.assert_array_equal( - image[bounds].array, - (sub_array + 100), - err_msg="image[bounds] = im2 does not set correctly for dtype = " - + str(types[i]), - ) - for xpos in range(1, test_shape[0] + 1): - for ypos in range(1, test_shape[1] + 1): - if ( - xpos >= bounds.xmin - and xpos <= bounds.xmax - and ypos >= bounds.ymin - and ypos <= bounds.ymax - ): - value = ref_array[ypos - 1, xpos - 1] + 100 - else: - value = ref_array[ypos - 1, xpos - 1] - assert ( - image(xpos, ypos) == value - ), "image[bounds] = im2 set wrong locations for dtype = " + str( - types[i] - ) - - image = galsim.Image(ref_array.astype(types[i])) - image[bounds] += 100 - np.testing.assert_array_equal( - image[bounds].array, - (sub_array + 100), - err_msg="image[bounds] += 100 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array) - np.testing.assert_array_equal( - image.array, - ref_array, - err_msg="image[bounds] += 100 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image(ref_array.astype(types[i])) - image[bounds] -= 100 - np.testing.assert_array_equal( - image[bounds].array, - (sub_array - 100), - err_msg="image[bounds] -= 100 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array) - np.testing.assert_array_equal( - image.array, - ref_array, - err_msg="image[bounds] -= 100 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image(ref_array.astype(types[i])) - image[bounds] *= 100 - np.testing.assert_array_equal( - image[bounds].array, - (sub_array * 100), - err_msg="image[bounds] *= 100 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array) - np.testing.assert_array_equal( - image.array, - ref_array, - err_msg="image[bounds] *= 100 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image((100 * ref_array).astype(types[i])) - image[bounds] /= 100 - np.testing.assert_array_equal( - image[bounds].array, - (sub_array), - err_msg="image[bounds] /= 100 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image((100 * sub_array).astype(types[i])) - np.testing.assert_array_equal( - image.array, - (100 * ref_array), - err_msg="image[bounds] /= 100 set wrong locations for dtype = " - + str(types[i]), - ) - - im2 = galsim.Image(sub_array) - image = galsim.Image(ref_array.astype(types[i])) - image[bounds] += im2 - np.testing.assert_array_equal( - image[bounds].array, - (2 * sub_array), - err_msg="image[bounds] += im2 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array) - np.testing.assert_array_equal( - image.array, - ref_array, - err_msg="image[bounds] += im2 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image(2 * ref_array.astype(types[i])) - image[bounds] -= im2 - np.testing.assert_array_equal( - image[bounds].array, - sub_array, - err_msg="image[bounds] -= im2 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image((2 * sub_array).astype(types[i])) - np.testing.assert_array_equal( - image.array, - (2 * ref_array), - err_msg="image[bounds] -= im2 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image(ref_array.astype(types[i])) - image[bounds] *= im2 - np.testing.assert_array_equal( - image[bounds].array, - (sub_array**2), - err_msg="image[bounds] *= im2 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image(sub_array) - np.testing.assert_array_equal( - image.array, - ref_array, - err_msg="image[bounds] *= im2 set wrong locations for dtype = " - + str(types[i]), - ) - - image = galsim.Image((2 * ref_array**2).astype(types[i])) - image[bounds] /= im2 - np.testing.assert_array_equal( - image[bounds].array, - (2 * sub_array), - err_msg="image[bounds] /= im2 does not set correctly for dtype = " - + str(types[i]), - ) - image[bounds] = galsim.Image((2 * sub_array**2).astype(types[i])) - np.testing.assert_array_equal( - image.array, - (2 * ref_array**2), - err_msg="image[bounds] /= im2 set wrong locations for dtype = " - + str(types[i]), - ) - - # JAX specific modification - # ------------------------- - # We won't do any pickling - # check_pickle(image) - - assert_raises(TypeError, image.subImage, bounds=None) - assert_raises(TypeError, image.subImage, bounds=galsim.BoundsD(0, 4, 0, 4)) - - -def make_subImage(file_name, bounds): - """Helper function for test_subImage_persistence""" - full_im = galsim.fits.read(file_name) - stamp = full_im.subImage(bounds) - return stamp - - -@timer -def test_subImage_persistence(): - """Test that a subimage is properly accessible even if the original image has gone out - of scope. - """ - file_name = os.path.join( - os.path.dirname(__file__), "../../GalSim/tests", "fits_files", "tpv.fits" - ) - bounds = galsim.BoundsI(123, 133, 45, 55) # Something random - - # In this case, the original image has gone out of scope. At least on some systems, - # this used to caus a seg fault when accessing stamp1.array. (BAD!) - stamp1 = make_subImage(file_name, bounds) - print("stamp1 = ", stamp1.array) - - full_im = galsim.fits.read(file_name) - stamp2 = full_im.subImage(bounds) - print("stamp2 = ", stamp2.array) - - np.testing.assert_array_equal(stamp1.array, stamp2.array) - - -@timer -def test_Image_resize(): - """Test that the Image resize function works correctly.""" - # Use a random number generator for some values here. - ud = galsim.UniformDeviate(515324) - - for i in range(ntypes): - # Resize to a bunch of different shapes (larger and smaller) to test reallocations - for shape in [(10, 10), (3, 20), (21, 8), (1, 3), (13, 30)]: - # im1 starts with basic constructor with a given size - array_type = types[i] - im1 = galsim.Image(5, 5, dtype=array_type, scale=0.1) - - # im2 stars with null constructor - im2 = galsim.Image(dtype=array_type, scale=0.2) - - # im3 is a view into a larger image - im3_full = galsim.Image(10, 10, dtype=array_type, init_value=23, scale=0.3) - im3 = im3_full.subImage(galsim.BoundsI(1, 6, 1, 6)) - - # Make sure at least one of the _arrays is instantiated. This isn't required, - # but we used to have bugs if the array was instantiated before resizing. - # So test im1 and im3 being instantiated and im2 not instantiated. - np.testing.assert_array_equal(im1.array, 0, "im1 is not initially all 0.") - np.testing.assert_array_equal(im3.array, 23, "im3 is not initially all 23.") - - # Have the xmin, ymin value be random numbers from -100..100: - xmin = int(ud() * 200) - 100 - ymin = int(ud() * 200) - 100 - xmax = xmin + shape[1] - 1 - ymax = ymin + shape[0] - 1 - b = galsim.BoundsI(xmin, xmax, ymin, ymax) - im1.resize(b) - im2.resize(b) - im3.resize(b, wcs=galsim.PixelScale(0.33)) - - np.testing.assert_equal( - b, im1.bounds, err_msg="im1 has wrong bounds after resize to b = %s" % b - ) - np.testing.assert_equal( - b, im2.bounds, err_msg="im2 has wrong bounds after resize to b = %s" % b - ) - np.testing.assert_equal( - b, im3.bounds, err_msg="im3 has wrong bounds after resize to b = %s" % b - ) - np.testing.assert_array_equal( - im1.array.shape, shape, err_msg="im1.array.shape wrong after resize" - ) - np.testing.assert_array_equal( - im2.array.shape, shape, err_msg="im2.array.shape wrong after resize" - ) - np.testing.assert_array_equal( - im3.array.shape, shape, err_msg="im3.array.shape wrong after resize" - ) - np.testing.assert_equal( - im1.scale, 0.1, err_msg="im1 has wrong scale after resize to b = %s" % b - ) - np.testing.assert_equal( - im2.scale, 0.2, err_msg="im2 has wrong scale after resize to b = %s" % b - ) - np.testing.assert_equal( - im3.scale, - 0.33, - err_msg="im3 has wrong scale after resize to b = %s" % b, - ) - - # Fill the images with random numbers - for x in range(xmin, xmax + 1): - for y in range(ymin, ymax + 1): - val = simple_types[i](ud() * 500) - im1.setValue(x, y, val) - im2._setValue(x, y, val) - im3.setValue(x, y, val) - - # They should be equal now. This doesn't completely guarantee that nothing is - # wrong, but hopefully if we are misallocating memory here, something will be - # clobbered or we will get a seg fault. - np.testing.assert_array_equal( - im1.array, im2.array, err_msg="im1 != im2 after resize to b = %s" % b - ) - np.testing.assert_array_equal( - im1.array, im3.array, err_msg="im1 != im3 after resize to b = %s" % b - ) - np.testing.assert_array_equal( - im2.array, im3.array, err_msg="im2 != im3 after resize to b = %s" % b - ) - - # Also, since the view was resized, it should no longer be coupled to the original. - np.testing.assert_array_equal( - im3_full.array, 23, err_msg="im3_full changed" - ) - - check_pickle(im1) - check_pickle(im2) - check_pickle(im3) - - assert_raises(TypeError, im1.resize, bounds=None) - assert_raises(TypeError, im1.resize, bounds=galsim.BoundsD(0, 5, 0, 5)) - - -# JAX specific modification -# ------------------------- -# We do not have the concept of contanst images in JAX, so we skip this test. -# @timer -# def test_ConstImage_array_constness(): -# """Test that Image instances with make_const=True cannot be modified via their .array -# attributes, and that if this is attempted a GalSimImmutableError is raised. -# """ -# for i in range(ntypes): -# image = galsim.Image(ref_array.astype(types[i]), make_const=True) -# # Apparently older numpy versions might raise a RuntimeError, a ValueError, or a TypeError -# # when trying to write to arrays that have writeable=False. -# # From the numpy 1.7.0 release notes: -# # Attempting to write to a read-only array (one with -# # ``arr.flags.writeable`` set to ``False``) used to raise either a -# # RuntimeError, ValueError, or TypeError inconsistently, depending on -# # which code path was taken. It now consistently raises a ValueError. -# with assert_raises((RuntimeError, ValueError, TypeError)): -# image.array[1, 2] = 666 - -# # Native image operations that are invalid just raise GalSimImmutableError -# with assert_raises(galsim.GalSimImmutableError): -# image[1, 2] = 666 - -# with assert_raises(galsim.GalSimImmutableError): -# image.setValue(1,2,666) - -# with assert_raises(galsim.GalSimImmutableError): -# image[image.bounds] = image - -# # The rest are functions, so just use assert_raises. -# assert_raises(galsim.GalSimImmutableError, image.setValue, 1, 2, 666) -# assert_raises(galsim.GalSimImmutableError, image.setSubImage, image.bounds, image) -# assert_raises(galsim.GalSimImmutableError, image.addValue, 1, 2, 666) -# assert_raises(galsim.GalSimImmutableError, image.copyFrom, image) -# assert_raises(galsim.GalSimImmutableError, image.resize, image.bounds) -# assert_raises(galsim.GalSimImmutableError, image.fill, 666) -# assert_raises(galsim.GalSimImmutableError, image.setZero) -# assert_raises(galsim.GalSimImmutableError, image.invertSelf) - -# check_pickle(image) - - -@timer -def test_BoundsI_init_with_non_pure_ints(): - """Test that BoundsI converts its input args to int variables on input.""" - ref_bound_vals = (5, 35, 1, 48) - xmin_test, xmax_test, ymin_test, ymax_test = ref_bound_vals - ref_bounds = galsim.BoundsI(xmin_test, xmax_test, ymin_test, ymax_test) - bound_arr_int = np.asarray(ref_bound_vals, dtype=int) - bound_arr_flt = np.asarray(ref_bound_vals, dtype=float) - bound_arr_flt_nonint = bound_arr_flt + 0.3 - - # Check kwarg initialization: - assert ref_bounds == galsim.BoundsI( - xmin=bound_arr_int[0], - xmax=bound_arr_int[1], - ymin=bound_arr_int[2], - ymax=bound_arr_int[3], - ), "Cannot initialize a BoundI with int array elements" - assert ref_bounds == galsim.BoundsI( - xmin=bound_arr_flt[0], - xmax=bound_arr_flt[1], - ymin=bound_arr_flt[2], - ymax=bound_arr_flt[3], - ), "Cannot initialize a BoundI with float array elements" - - # Check arg initialization: - assert ref_bounds == galsim.BoundsI( - *bound_arr_int - ), "Cannot initialize a BoundI with int array elements" - assert ref_bounds == galsim.BoundsI( - *bound_arr_flt - ), "Cannot initialize a BoundI with float array elements" - - # JAX specific modification - # ------------------------- - # No type checking for JAX bounds inputs - # assert_raises(TypeError, galsim.BoundsI, *bound_arr_flt_nonint) - # assert_raises(TypeError, galsim.BoundsI, - # xmin=bound_arr_flt_nonint[0], xmax=bound_arr_flt_nonint[1], - # ymin=bound_arr_flt_nonint[2], ymax=bound_arr_flt_nonint[3]) - - -@timer -def test_Image_constructor(): - """Check that the Image constructor that takes NumPy array does not mangle input.""" - # Loop over types. - for i in range(ntypes): - array_dtype = np.dtype(types[i]) - - # Make a NumPy array directly, with non-trivially interesting values. - test_arr = np.ones((3, 4), dtype=types[i]) - test_arr[1, 3] = -5 - test_arr[2, 2] = 7 - # Initialize the Image from it. - test_im = galsim.Image(test_arr) - # Check that the image.array attribute matches the original. - np.testing.assert_array_equal( - test_arr, - test_im.array, - err_msg="Image constructor mangled input NumPy array.", - ) - - # Now make an opposite-endian Numpy array, to initialize the Image. - new_type = array_dtype.newbyteorder("S") - test_arr = np.ones((3, 4), dtype=new_type) - test_arr[1, 3] = -5 - test_arr[2, 2] = 7 - # Initialize the Image from it. - test_im = galsim.Image(test_arr) - # Check that the image.array attribute matches the original. - np.testing.assert_array_equal( - test_arr, - test_im.array, - err_msg="Image constructor mangled input NumPy array (endian issues).", - ) - - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(test_im) - - # Check that some invalid sets of construction args raise the appropriate errors - # Invalid args - assert_raises(TypeError, galsim.Image, 1, 2, 3) - assert_raises(TypeError, galsim.Image, 128) - assert_raises(TypeError, galsim.Image, 1.8) - assert_raises(TypeError, galsim.Image, 1.3, 2.7) - # Invalid array kwarg - assert_raises(TypeError, galsim.Image, array=5) - assert_raises(TypeError, galsim.Image, array=test_im) - # Invalid image kwarg - assert_raises(TypeError, galsim.Image, image=5) - assert_raises(TypeError, galsim.Image, image=test_arr) - # Invalid bounds - assert_raises(TypeError, galsim.Image, bounds=(1, 4, 1, 3)) - assert_raises(TypeError, galsim.Image, bounds=galsim.BoundsD(1, 4, 1, 3)) - assert_raises(TypeError, galsim.Image, array=test_arr, bounds=(1, 4, 1, 3)) - assert_raises( - ValueError, galsim.Image, array=test_arr, bounds=galsim.BoundsI(1, 3, 1, 4) - ) - assert_raises( - ValueError, galsim.Image, array=test_arr, bounds=galsim.BoundsI(1, 4, 1, 1) - ) - # Invalid ncol, nrow - assert_raises(TypeError, galsim.Image, ncol=1.2, nrow=3) - assert_raises(TypeError, galsim.Image, ncol=2, nrow=3.4) - assert_raises(ValueError, galsim.Image, ncol="four", nrow="three") - # Invalid dtype - assert_raises(ValueError, galsim.Image, array=test_arr, dtype=bool) - assert_raises(ValueError, galsim.Image, array=test_arr.astype(bool)) - # Invalid scale - assert_raises(ValueError, galsim.Image, 4, 3, scale="invalid") - # Invalid wcs - assert_raises(TypeError, galsim.Image, 4, 3, wcs="invalid") - # Disallowed combinations - assert_raises( - TypeError, galsim.Image, ncol=4, nrow=3, bounds=galsim.BoundsI(1, 4, 1, 3) - ) - assert_raises(TypeError, galsim.Image, ncol=4, nrow=3, array=test_arr) - assert_raises(TypeError, galsim.Image, ncol=4, nrow=3, image=test_im) - assert_raises(TypeError, galsim.Image, ncol=4) - assert_raises(TypeError, galsim.Image, nrow=3) - assert_raises( - ValueError, galsim.Image, test_arr, bounds=galsim.BoundsI(1, 2, 1, 3) - ) - assert_raises( - ValueError, galsim.Image, array=test_arr, bounds=galsim.BoundsI(1, 2, 1, 3) - ) - assert_raises( - ValueError, galsim.Image, [[1, 2]], bounds=galsim.BoundsI(1, 2, 1, 3) - ) - assert_raises(TypeError, galsim.Image, test_arr, init_value=3) - assert_raises(TypeError, galsim.Image, array=test_arr, init_value=3) - assert_raises(TypeError, galsim.Image, test_im, init_value=3) - assert_raises(TypeError, galsim.Image, image=test_im, init_value=3) - assert_raises(TypeError, galsim.Image, dtype=float, init_value=3) - assert_raises( - TypeError, galsim.Image, test_im, scale=3, wcs=galsim.PixelScale(3) - ) - # Extra kwargs - assert_raises(TypeError, galsim.Image, image=test_im, name="invalid") - - -@timer -def test_Image_view(): - """Test the functionality of image.view(...)""" - im = galsim.ImageI( - 25, - 25, - wcs=galsim.AffineTransform(0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13)), - ) - im._fill(17) - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert im.bounds == galsim.BoundsI(1, 25, 1, 25) - assert im(11, 19) == 17 # I'll keep editing this pixel to new values. - - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(im) - - # Test view with no arguments - imv = im.view() - assert imv.wcs == im.wcs - assert imv.bounds == im.bounds - imv.setValue(11, 19, 20) - # JAX specific modification - # ------------------------- - # the original image is left unmodified by the view - im.setValue(11, 19, 20) - assert imv(11, 19) == 20 - assert im(11, 19) == 20 - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(im) - # check_pickle(imv) - - # Test view with new origin - imv = im.view(origin=(0, 0)) - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert imv.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(12, 12) - ) - assert im.bounds == galsim.BoundsI(1, 25, 1, 25) - assert imv.bounds == galsim.BoundsI(0, 24, 0, 24) - imv.setValue(10, 18, 30) - assert imv(10, 18) == 30 - # JAX specific modification - # ------------------------- - # the original image is left unmodified by the view - # assert im(11,19) == 30 - imv2 = im.view() - imv2.setOrigin(0, 0) - assert imv.bounds == imv2.bounds - assert imv.wcs == imv2.wcs - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(imv) - # check_pickle(imv2) - - # Test view with new center - imv = im.view(center=(0, 0)) - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert imv.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(0, 0) - ) - assert im.bounds == galsim.BoundsI(1, 25, 1, 25) - assert imv.bounds == galsim.BoundsI(-12, 12, -12, 12) - imv.setValue(-2, 6, 40) - assert imv(-2, 6) == 40 - # JAX specific modification - # ------------------------- - # the original image is left unmodified by the view - # assert im(11,19) == 40 - imv2 = im.view() - imv2.setCenter(0, 0) - assert imv.bounds == imv2.bounds - assert imv.wcs == imv2.wcs - with assert_raises(galsim.GalSimError): - imv.scale # scale is invalid if wcs is not a PixelScale - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(imv) - # check_pickle(imv2) - - # Test view with new scale - imv = im.view(scale=0.17) - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert imv.wcs == galsim.PixelScale(0.17) - assert imv.bounds == im.bounds - imv.setValue(11, 19, 50) - assert imv(11, 19) == 50 - # JAX specific modification - # ------------------------- - # the original image is left unmodified by the view - # assert im(11,19) == 50 - imv2 = im.view() - with assert_raises(galsim.GalSimError): - imv2.scale = 0.17 # Invalid if wcs is not PixelScale - imv2.wcs = None - imv2.scale = 0.17 - assert imv.bounds == imv2.bounds - assert imv.wcs == imv2.wcs - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(imv) - # check_pickle(imv2) - - # Test view with new wcs - imv = im.view(wcs=galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0)) - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert imv.wcs == galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0) - assert imv.bounds == im.bounds - imv.setValue(11, 19, 60) - assert imv(11, 19) == 60 - # JAX specific modification - # ------------------------- - # the original image is left unmodified by the view - # assert im(11,19) == 60 - imv2 = im.view() - imv2.wcs = galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0) - assert imv.bounds == imv2.bounds - assert imv.wcs == imv2.wcs - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # check_pickle(imv) - # check_pickle(imv2) - - # Go back to original value for that pixel and make sure all are still equal to 17 - im.setValue(11, 19, 17) - assert im.array.min() == 17 - assert im.array.max() == 17 - - assert_raises(TypeError, im.view, origin=(0, 0), center=(0, 0)) - assert_raises( - TypeError, im.view, scale=0.3, wcs=galsim.JacobianWCS(1.1, 0.1, 0.1, 1.0) - ) - # JAX specific modification - # ------------------------- - # this test cannot work in JAX because during the tracing process - # a PixelScale object may be given to the WCS init function - # assert_raises(TypeError, im.view, scale=galsim.PixelScale(0.3)) - assert_raises(TypeError, im.view, wcs=0.3) - - -@timer -def test_Image_writeheader(): - """Test the functionality of image.write(...) for images that have header attributes""" - # First check: if we have an image.header attribute, it gets written to file. - im_test = galsim.Image(10, 10) - key_name = "test_key" - im_test.header = galsim.FitsHeader(header={key_name: "test_val"}) - test_file = os.path.join(datadir, "test_header.fits") - im_test.write(test_file) - new_header = galsim.FitsHeader(test_file) - assert key_name.upper() in new_header.keys() - - # Second check: if we have an image.header attribute that modifies some keywords used by the - # WCS, then make sure it doesn't overwrite the WCS. - im_test.wcs = galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0) - im_test.header = galsim.FitsHeader(header={"CD1_1": 10.0, key_name: "test_val"}) - im_test.write(test_file) - new_header = galsim.FitsHeader(test_file) - assert key_name.upper() in new_header.keys() - assert new_header["CD1_1"] == 0.0 - - # If clobbert = False, then trying to overwrite will raise an OSError - assert_raises(OSError, im_test.write, test_file, clobber=False) - - -@timer -def test_ne(): - """Check that inequality works as expected.""" - array1 = np.arange(32 * 32).reshape(32, 32).astype(float) - array2 = array1.copy() - array2[15, 15] += 2 - - objs = [ - galsim.ImageD(array1), - galsim.ImageD(array2), - # JAX specific modification - # ------------------------- - # make_const is not allowed - # galsim.ImageD(array2, make_const=True), - galsim.ImageD(array1, wcs=galsim.PixelScale(0.2)), - galsim.ImageD(array1, xmin=2), - ] - check_all_diff(objs) - - -@timer -def test_copy(): - """Test different ways to copy an Image.""" - wcs = galsim.AffineTransform(0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13)) - im = galsim.Image(25, 25, wcs=wcs) - gn = galsim.GaussianNoise(sigma=1.7) - im.addNoise(gn) - - assert im.wcs == galsim.AffineTransform( - 0.23, 0.01, -0.02, 0.22, galsim.PositionI(13, 13) - ) - assert im.bounds == galsim.BoundsI(1, 25, 1, 25) - - # Simplest way to copy is copy() - im2 = im.copy() - assert im2.wcs == im.wcs - assert im2.bounds == im.bounds - np.testing.assert_array_equal(im2.array, im.array) - - # Make sure it actually copied the array, not just made a view of it. - im2.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - # Can also use constructor to "copy" - im3 = galsim.Image(im) - assert im3.wcs == im.wcs - assert im3.bounds == im.bounds - np.testing.assert_array_equal(im3.array, im.array) - im3.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - # JAX always copies so remove this test - # # If copy=False is specified, then it shares the same array - # im3b = galsim.Image(im, copy=False) - # assert im3b.wcs == im.wcs - # assert im3b.bounds == im.bounds - # np.testing.assert_array_equal(im3b.array, im.array) - # im3b.setValue(2, 3, 2.0) - # assert im3b(2, 3) == 2.0 - # assert im(2, 3) == 2.0 - - # Constructor can change the wcs - im4 = galsim.Image(im, scale=0.6) - assert im4.wcs != im.wcs # wcs is not equal this time. - assert im4.bounds == im.bounds - np.testing.assert_array_equal(im4.array, im.array) - im4.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - im5 = galsim.Image(im, wcs=galsim.PixelScale(1.4)) - assert im5.wcs != im.wcs # wcs is not equal this time. - assert im5.bounds == im.bounds - np.testing.assert_array_equal(im5.array, im.array) - im5.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - im6 = galsim.Image(im, wcs=wcs) - assert im6.wcs == im.wcs # This is the same wcs now. - assert im6.bounds == im.bounds - np.testing.assert_array_equal(im6.array, im.array) - im6.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - # Can also change the dtype - im7 = galsim.Image(im, dtype=float) - assert im7.wcs == im.wcs - assert im7.bounds == im.bounds - np.testing.assert_array_equal(im7.array, im.array) - im7.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - im8 = galsim.Image(im, wcs=wcs, dtype=float) - assert im8.wcs == im.wcs # This is the same wcs now. - assert im8.bounds == im.bounds - np.testing.assert_array_equal(im8.array, im.array) - im8.setValue(3, 8, 11.0) - assert im(3, 8) != 11.0 - - # Check that a slice image copies correctly - slice_array = large_array.astype(complex)[::3, ::2] - im_slice = galsim.Image(slice_array, wcs=wcs) - im9 = im_slice.copy() - assert im9.wcs == im_slice.wcs - assert im9.bounds == im_slice.bounds - np.testing.assert_array_equal(im9.array, im_slice.array) - im9.setValue(2, 3, 11.0) - assert im9(2, 3) == 11.0 - assert im_slice(2, 3) != 11.0 - - # JAX always copies so remove this test - # # Can also copy by giving the array and specify copy=True - # im10 = galsim.Image(im.array, bounds=im.bounds, wcs=im.wcs, copy=False) - # assert im10.wcs == im.wcs - # assert im10.bounds == im.bounds - # np.testing.assert_array_equal(im10.array, im.array) - # im10[2, 3] = 17 - # assert im10(2, 3) == 17.0 - # assert im(2, 3) == 17.0 - - im10b = galsim.Image(im.array, bounds=im.bounds, wcs=im.wcs, copy=True) - assert im10b.wcs == im.wcs - assert im10b.bounds == im.bounds - np.testing.assert_array_equal(im10b.array, im.array) - im10b[2, 3] = 27 - assert im10b(2, 3) == 27.0 - assert im(2, 3) != 27.0 - - # copyFrom copies the data only. - im5.copyFrom(im8) - assert im5.wcs != im.wcs # im5 had a different wcs. Should still have it. - assert im5.bounds == im8.bounds - np.testing.assert_array_equal(im5.array, im8.array) - assert im5(3, 8) == 11.0 - im8[3, 8] = 15 - assert im5(3, 8) == 11.0 - - assert_raises(TypeError, im5.copyFrom, im8.array) - im9 = galsim.Image(5, 5, init_value=3) - assert_raises(ValueError, im5.copyFrom, im9) - - -@timer -def test_complex_image(): - """Additional tests that are relevant for complex Image types""" - - for dtype in [np.complex64, np.complex128]: - # Some complex modifications to tests in test_Image_basic - im1 = galsim.Image(ncol, nrow, dtype=dtype) - im1_view = im1.view() - im1_cview = im1.view(make_const=True) - im2 = galsim.Image(ncol, nrow, init_value=23, dtype=dtype) - im2_view = im2.view() - im2_cview = im2.view(make_const=True) - im2_conj = im2.conjugate - - # Check various ways to set and get values - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - im1.setValue(x, y, 100 + 10 * x + y + 13j * x + 23j * y) - im2_view.setValue(x, y, 100 + 10 * x + y + 13j * x + 23j * y) - # JAX specific modification - # ------------------------- - # the view does not modify the parent array - im2.setValue(x, y, 100 + 10 * x + y + 13j * x + 23j * y) - im2_cview.setValue(x, y, 100 + 10 * x + y + 13j * x + 23j * y) - - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - value = 100 + 10 * x + y + 13j * x + 23j * y - assert im1(x, y) == value - assert im1.view()(x, y) == value - assert im1.view(make_const=True)(x, y) == value - assert im2(x, y) == value - assert im2_view(x, y) == value - assert im2_cview(x, y) == value - assert im1.conjugate(x, y) == np.conjugate(value) - - # complex conjugate is not a view into the original. - assert im2_conj(x, y) == 23 - assert im2.conjugate(x, y) == np.conjugate(value) - - value2 = 10 * x + y + 20j * x + 2j * y - im1.setValue(x, y, value2) - im2_view.setValue(x=x, y=y, value=value2) - # JAX specific modification - # ------------------------- - # the view does not modify the parent array - im2.setValue(x=x, y=y, value=value2) - im2_cview.setValue(x=x, y=y, value=value2) - - assert im1(x, y) == value2 - assert im1.view()(x, y) == value2 - assert im1.view(make_const=True)(x, y) == value2 - assert im2(x, y) == value2 - assert im2_view(x, y) == value2 - assert im2_cview(x, y) == value2 - - assert im1.real(x, y) == value2.real - assert im1.view().real(x, y) == value2.real - assert im1.view(make_const=True).real(x, y) == value2.real - assert im2.real(x, y) == value2.real - assert im2_view.real(x, y) == value2.real - assert im2_cview.real(x, y) == value2.real - assert im1.imag(x, y) == value2.imag - assert im1.view().imag(x, y) == value2.imag - assert im1.view(make_const=True).imag(x, y) == value2.imag - assert im2.imag(x, y) == value2.imag - assert im2_view.imag(x, y) == value2.imag - assert im2_cview.imag(x, y) == value2.imag - assert im1.conjugate(x, y) == np.conjugate(value2) - assert im2.conjugate(x, y) == np.conjugate(value2) - - # JAX specific modification - # ------------------------- - # Assigning the value of real and imag would not work - # under our assumptions. - # rvalue3 = 12*x + y - # ivalue3 = x + 21*y - # value3 = rvalue3 + 1j * ivalue3 - # im1.real.setValue(x,y, rvalue3) - # im1.imag.setValue(x,y, ivalue3) - # im2_view.real.setValue(x,y, rvalue3) - # im2_view.imag.setValue(x,y, ivalue3) - # assert im1(x,y) == value3 - # assert im1.view()(x,y) == value3 - # assert im1.view(make_const=True)(x,y) == value3 - # assert im2(x,y) == value3 - # assert im2_view(x,y) == value3 - # assert im2_cview(x,y) == value3 - # assert im1.conjugate(x,y) == np.conjugate(value3) - # assert im2.conjugate(x,y) == np.conjugate(value3) - - # Check view of given data - im3_view = galsim.Image((1 + 2j) * ref_array.astype(complex)) - slice_array = (large_array * (1 + 2j)).astype(complex)[::3, ::2] - im4_view = galsim.Image(slice_array) - for y in range(1, nrow + 1): - for x in range(1, ncol + 1): - assert im3_view(x, y) == 10 * x + y + 20j * x + 2j * y - assert im4_view(x, y) == 10 * x + y + 20j * x + 2j * y - - # JAX specific modification - # ------------------------- - # No picklibng for JAX images - # Check picklability - # check_pickle(im1) - # check_pickle(im1_view) - # check_pickle(im1_cview) - # check_pickle(im2) - # check_pickle(im2_view) - # check_pickle(im3_view) - # check_pickle(im4_view) - - -@timer -def test_complex_image_arith(): - """Additional arithmetic tests that are relevant for complex Image types""" - image1 = galsim.ImageD(ref_array) - - # Binary ImageD op complex scalar - image2 = image1 + (2 + 5j) - np.testing.assert_array_equal( - image2.array, ref_array + (2 + 5j), err_msg="ImageD + complex is not correct" - ) - image2 = image1 - (2 + 5j) - np.testing.assert_array_equal( - image2.array, ref_array - (2 + 5j), err_msg="ImageD - complex is not correct" - ) - image2 = image1 * (2 + 5j) - np.testing.assert_array_equal( - image2.array, ref_array * (2 + 5j), err_msg="ImageD * complex is not correct" - ) - image2 = image1 / (2 + 5j) - np.testing.assert_almost_equal( - image2.array, - ref_array / (2 + 5j), - decimal=12, - err_msg="ImageD / complex is not correct", - ) - - # Binary complex scalar op ImageD - image2 = (2 + 5j) + image1 - np.testing.assert_array_equal( - image2.array, ref_array + (2 + 5j), err_msg="complex + ImageD is not correct" - ) - image2 = (2 + 5j) - image1 - np.testing.assert_array_equal( - image2.array, -ref_array + (2 + 5j), err_msg="complex - ImageD is not correct" - ) - image2 = (2 + 5j) * image1 - np.testing.assert_array_equal( - image2.array, ref_array * (2 + 5j), err_msg="complex * ImageD is not correct" - ) - image2 = (2 + 5j) / image1 - np.testing.assert_almost_equal( - image2.array, - (2 + 5j) / ref_array.astype(float), - decimal=12, - err_msg="complex / ImageD is not correct", - ) - - image2 = image1 * (3 + 1j) - - # Binary ImageCD op complex scalar - image3 = image2 + (2 + 5j) - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array + (2 + 5j), - err_msg="ImageCD + complex is not correct", - ) - image3 = image2 - (2 + 5j) - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array - (2 + 5j), - err_msg="ImageCD - complex is not correct", - ) - image3 = image2 * (2 + 5j) - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array * (2 + 5j), - err_msg="ImageCD * complex is not correct", - ) - image3 = image2 / (2 + 5j) - np.testing.assert_almost_equal( - image3.array, - (3 + 1j) * ref_array / (2 + 5j), - decimal=12, - err_msg="ImageCD / complex is not correct", - ) - - # Binary complex scalar op ImageCD - image3 = (2 + 5j) + image2 - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array + (2 + 5j), - err_msg="complex + ImageCD is not correct", - ) - image3 = (2 + 5j) - image2 - np.testing.assert_array_equal( - image3.array, - (-3 - 1j) * ref_array + (2 + 5j), - err_msg="complex - ImageCD is not correct", - ) - image3 = (2 + 5j) * image2 - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array * (2 + 5j), - err_msg="complex * ImageCD is not correct", - ) - image3 = (2 + 5j) / image2 - np.testing.assert_almost_equal( - image3.array, - (2 + 5j) / ((3 + 1j) * ref_array), - decimal=12, - err_msg="complex / ImageCD is not correct", - ) - - # Binary ImageD op ImageCD - image3 = image1 + image2 - np.testing.assert_array_equal( - image3.array, (4 + 1j) * ref_array, err_msg="ImageD + ImageCD is not correct" - ) - image3 = image1 - image2 - np.testing.assert_array_equal( - image3.array, (-2 - 1j) * ref_array, err_msg="ImageD - ImageCD is not correct" - ) - image3 = image1 * image2 - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array**2, - err_msg="ImageD * ImageCD is not correct", - ) - image3 = image1 / image2 - np.testing.assert_almost_equal( - image3.array, - 1.0 / (3 + 1j) * np.ones_like(ref_array), - decimal=12, - err_msg="ImageD / ImageCD is not correct", - ) - - # Binary ImageCD op ImageD - image3 = image2 + image1 - np.testing.assert_array_equal( - image3.array, (4 + 1j) * ref_array, err_msg="ImageD + ImageCD is not correct" - ) - image3 = image2 - image1 - np.testing.assert_array_equal( - image3.array, (2 + 1j) * ref_array, err_msg="ImageD - ImageCD is not correct" - ) - image3 = image2 * image1 - np.testing.assert_array_equal( - image3.array, - (3 + 1j) * ref_array**2, - err_msg="ImageD * ImageCD is not correct", - ) - image3 = image2 / image1 - np.testing.assert_almost_equal( - image3.array, - (3 + 1j) * np.ones_like(ref_array), - decimal=12, - err_msg="ImageD / ImageCD is not correct", - ) - - # Binary ImageCD op ImageCD - image3 = (4 - 3j) * image1 - image4 = image2 + image3 - np.testing.assert_array_equal( - image4.array, (7 - 2j) * ref_array, err_msg="ImageCD + ImageCD is not correct" - ) - image4 = image2 - image3 - np.testing.assert_array_equal( - image4.array, (-1 + 4j) * ref_array, err_msg="ImageCD - ImageCD is not correct" - ) - image4 = image2 * image3 - np.testing.assert_array_equal( - image4.array, - (15 - 5j) * ref_array**2, - err_msg="ImageCD * ImageCD is not correct", - ) - image4 = image2 / image3 - np.testing.assert_almost_equal( - image4.array, - (9 + 13j) / 25.0 * np.ones_like(ref_array), - decimal=12, - err_msg="ImageCD / ImageCD is not correct", - ) - - # In place ImageCD op complex scalar - image4 = image2.copy() - image4 += 2 + 5j - np.testing.assert_array_equal( - image4.array, - (3 + 1j) * ref_array + (2 + 5j), - err_msg="ImageCD + complex is not correct", - ) - image4 = image2.copy() - image4 -= 2 + 5j - np.testing.assert_array_equal( - image4.array, - (3 + 1j) * ref_array - (2 + 5j), - err_msg="ImageCD - complex is not correct", - ) - image4 = image2.copy() - image4 *= 2 + 5j - np.testing.assert_array_equal( - image4.array, - (3 + 1j) * ref_array * (2 + 5j), - err_msg="ImageCD * complex is not correct", - ) - image4 = image2.copy() - image4 /= 2 + 5j - np.testing.assert_almost_equal( - image4.array, - (3 + 1j) * ref_array / (2 + 5j), - decimal=12, - err_msg="ImageCD / complex is not correct", - ) - - # In place ImageCD op ImageD - image4 = image2.copy() - image4 += image1 - np.testing.assert_array_equal( - image4.array, (4 + 1j) * ref_array, err_msg="ImageD + ImageCD is not correct" - ) - image4 = image2.copy() - image4 -= image1 - np.testing.assert_array_equal( - image4.array, (2 + 1j) * ref_array, err_msg="ImageD - ImageCD is not correct" - ) - image4 = image2.copy() - image4 *= image1 - np.testing.assert_array_equal( - image4.array, - (3 + 1j) * ref_array**2, - err_msg="ImageD * ImageCD is not correct", - ) - image4 = image2.copy() - image4 /= image1 - np.testing.assert_almost_equal( - image4.array, - (3 + 1j) * np.ones_like(ref_array), - decimal=12, - err_msg="ImageD / ImageCD is not correct", - ) - - # In place ImageCD op ImageCD - image4 = image2.copy() - image4 += image3 - np.testing.assert_array_equal( - image4.array, (7 - 2j) * ref_array, err_msg="ImageCD + ImageCD is not correct" - ) - image4 = image2.copy() - image4 -= image3 - np.testing.assert_array_equal( - image4.array, (-1 + 4j) * ref_array, err_msg="ImageCD - ImageCD is not correct" - ) - image4 = image2.copy() - image4 *= image3 - np.testing.assert_array_equal( - image4.array, - (15 - 5j) * ref_array**2, - err_msg="ImageCD * ImageCD is not correct", - ) - image4 = image2.copy() - image4 /= image3 - np.testing.assert_almost_equal( - image4.array, - (9 + 13j) / 25.0 * np.ones_like(ref_array), - decimal=12, - err_msg="ImageCD / ImageCD is not correct", - ) - - -@timer -def test_int_image_arith(): - """Additional arithmetic tests that are relevant for integer Image types""" - for i in range(int_ntypes): - full = galsim.Image(ref_array.astype(types[i])) - hi = (full // 8) * 8 - lo = full % 8 - - # - # Tests of __and__ and __iand__ operators: - # - - # lo & hi = 0 - test = lo & hi - np.testing.assert_array_equal( - test.array, 0, err_msg="& failed for Images with dtype = %s." % types[i] - ) - - # full & lo = lo - test = full & lo - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="& failed for Images with dtype = %s." % types[i], - ) - - # fullo & 0 = 0 - test = full & 0 - np.testing.assert_array_equal( - test.array, 0, err_msg="& failed for Images with dtype = %s." % types[i] - ) - - # lo & 24 = 0 - test = lo & 24 - np.testing.assert_array_equal( - test.array, 0, err_msg="& failed for Images with dtype = %s." % types[i] - ) - - # 7 & hi = 0 - test = 7 & hi - np.testing.assert_array_equal( - test.array, 0, err_msg="& failed for Images with dtype = %s." % types[i] - ) - - # full & hi = hi - test = full & hi - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="& failed for Images with dtype = %s." % types[i], - ) - - # hi &= full => hi - test &= full - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="&= failed for Images with dtype = %s." % types[i], - ) - - # hi &= 8 => (hi & 8) - test &= 8 - np.testing.assert_array_equal( - test.array, - (hi.array & 8), - err_msg="&= failed for Images with dtype = %s." % types[i], - ) - - # (hi & 8) &= hi => (hi & 8) - test &= hi - np.testing.assert_array_equal( - test.array, - (hi.array & 8), - err_msg="&= failed for Images with dtype = %s." % types[i], - ) - - # - # Tests of __or__ and __ior__ operators: - # - - # lo | hi = full - test = lo | hi - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="| failed for Images with dtype = %s." % types[i], - ) - - # lo | lo = lo - test = lo | lo - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="| failed for Images with dtype = %s." % types[i], - ) - - # lo | 8 = lo + 8 - test = lo | 8 - np.testing.assert_array_equal( - test.array, - lo.array + 8, - err_msg="| failed for Images with dtype = %s." % types[i], - ) - - # 7 | hi = hi + 7 - test = 7 | hi - np.testing.assert_array_equal( - test.array, - hi.array + 7, - err_msg="| failed for Images with dtype = %s." % types[i], - ) - - # hi | 0 = hi - test = hi | 0 - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="| failed for Images with dtype = %s." % types[i], - ) - - # hi |= hi => hi - test |= hi - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="|= failed for Images with dtype = %s." % types[i], - ) - - # hi |= 3 => hi + 3 - test |= 3 - np.testing.assert_array_equal( - test.array, - hi.array + 3, - err_msg="|= failed for Images with dtype = %s." % types[i], - ) - - # - # Tests of __xor__ and __ixor__ operators: - # - - # lo ^ hi = full - test = lo ^ hi - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="^ failed for Images with dtype = %s." % types[i], - ) - - # lo ^ full = hi - test = lo ^ full - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="^ failed for Images with dtype = %s." % types[i], - ) - - # lo ^ 40 = lo + 40 - test = lo ^ 40 - np.testing.assert_array_equal( - test.array, - lo.array + 40, - err_msg="^ failed for Images with dtype = %s." % types[i], - ) - - # 5 ^ hi = hi + 5 - test = 5 ^ hi - np.testing.assert_array_equal( - test.array, - hi.array + 5, - err_msg="^ failed for Images with dtype = %s." % types[i], - ) - - # full ^ hi = lo - test = full ^ hi - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="^ failed for Images with dtype = %s." % types[i], - ) - - # lo ^= hi => full - test ^= hi - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="^= failed for Images with dtype = %s." % types[i], - ) - - # full ^= 111 (x2) => full - test ^= 111 - test ^= 111 - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="^= failed for Images with dtype = %s." % types[i], - ) - - # full ^= lo => hi - test ^= lo - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="^= failed for Images with dtype = %s." % types[i], - ) - - # - # Tests of __mod__ and __floordiv__ operators: - # - - # lo // hi = 0 - test = lo // hi - np.testing.assert_array_equal( - test.array, 0, err_msg="// failed for Images with dtype = %s." % types[i] - ) - - # lo // 8 = 0 - test = lo // 8 - np.testing.assert_array_equal( - test.array, 0, err_msg="// failed for Images with dtype = %s." % types[i] - ) - - # lo % 8 = lo - test = lo % 8 - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%% failed for Images with dtype = %s." % types[i], - ) - - # hi % 2 = hi & 1 - test = hi % 2 - np.testing.assert_array_equal( - test.array, - (hi & 1).array, - err_msg="%% failed for Images with dtype = %s." % types[i], - ) - - # lo % hi = lo - test = lo % hi - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%% failed for Images with dtype = %s." % types[i], - ) - - # lo %= hi => lo - test %= hi - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%%= failed for Images with dtype = %s." % types[i], - ) - - # lo %= 8 => lo - test %= 8 - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%%= failed for Images with dtype = %s." % types[i], - ) - - # lo //= hi => 0 - test //= hi - np.testing.assert_array_equal( - test.array, 0, err_msg="//= failed for Images with dtype = %s." % types[i] - ) - - # 7 // hi = 0 - test = 7 // hi - np.testing.assert_array_equal( - test.array, 0, err_msg="// failed for Images with dtype = %s." % types[i] - ) - - # 7 % hi = 7 - test = 7 % hi - np.testing.assert_array_equal( - test.array, 7, err_msg="%% failed for Images with dtype = %s." % types[i] - ) - - # 7 //= 2 => 3 - test //= 2 - np.testing.assert_array_equal( - test.array, 3, err_msg="%%= failed for Images with dtype = %s." % types[i] - ) - - # 3 //= hi => 0 - test //= hi - np.testing.assert_array_equal( - test.array, 0, err_msg="//= failed for Images with dtype = %s." % types[i] - ) - - # A subset of the above for cross-type checks. - for j in range(i): - full2 = galsim.Image(ref_array.astype(types[j])) - hi2 = (full2 // 8) * 8 - lo2 = full2 % 8 - - # full & hi = hi - test = full & hi2 - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="& failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - # hi &= full => hi - test &= full2 - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="&= failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - - # lo | lo = lo - test = lo | lo2 - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="| failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - - # lo |= hi => full - test |= hi2 - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="|= failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - - # lo ^ hi = full - test = lo ^ hi2 - np.testing.assert_array_equal( - test.array, - full.array, - err_msg="^ failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - - # full ^= lo => hi - test ^= lo2 - np.testing.assert_array_equal( - test.array, - hi.array, - err_msg="^= failed for Images with dtypes = %s, %s." - % (types[i], types[j]), - ) - - # lo // hi = 0 - test = lo // hi2 - np.testing.assert_array_equal( - test.array, - 0, - err_msg="// failed for Images with dtype = %s." % types[i], - ) - - # lo % hi = lo - test = lo % hi2 - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%% failed for Images with dtype = %s." % types[i], - ) - - # lo %= hi => lo - test %= hi2 - np.testing.assert_array_equal( - test.array, - lo.array, - err_msg="%%= failed for Images with dtype = %s." % types[i], - ) - - # lo //= hi => 0 - test //= hi2 - np.testing.assert_array_equal( - test.array, - 0, - err_msg="//= failed for Images with dtype = %s." % types[i], - ) - - with assert_raises(ValueError): - full & full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full | full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full ^ full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full // full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full % full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full &= full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full |= full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full ^= full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full //= full.subImage(galsim.BoundsI(0, 4, 0, 4)) - with assert_raises(ValueError): - full %= full.subImage(galsim.BoundsI(0, 4, 0, 4)) - - imd = galsim.ImageD(ref_array) - with assert_raises(ValueError): - imd & full - with assert_raises(ValueError): - imd | full - with assert_raises(ValueError): - imd ^ full - with assert_raises(ValueError): - imd // full - with assert_raises(ValueError): - imd % full - with assert_raises(ValueError): - imd &= full - with assert_raises(ValueError): - imd |= full - with assert_raises(ValueError): - imd ^= full - with assert_raises(ValueError): - imd //= full - with assert_raises(ValueError): - imd %= full - - with assert_raises(ValueError): - full & imd - with assert_raises(ValueError): - full | imd - with assert_raises(ValueError): - full ^ imd - with assert_raises(ValueError): - full // imd - with assert_raises(ValueError): - full % imd - with assert_raises(ValueError): - full &= imd - with assert_raises(ValueError): - full |= imd - with assert_raises(ValueError): - full ^= imd - with assert_raises(ValueError): - full //= imd - with assert_raises(ValueError): - full %= imd - - -@timer -def test_wrap(): - """Test the image.wrap() function.""" - # Start with a fairly simple test where the image is 4 copies of the same data: - im_orig = galsim.Image( - [ - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - ] - ) - im = im_orig.copy() - b = galsim.BoundsI(1, 4, 1, 4) - im_quad = im_orig[b] - im_wrap = im.wrap(b) - np.testing.assert_array_almost_equal( - im_wrap.array, - 4.0 * im_quad.array, - 12, - "image.wrap() into first quadrant did not match expectation", - ) - - # The same thing should work no matter where the lower left corner is: - for xmin, ymin in ((1, 5), (5, 1), (5, 5), (2, 3), (4, 1)): - b = galsim.BoundsI(xmin, xmin + 3, ymin, ymin + 3) - im_quad = im_orig[b] - im = im_orig.copy() - im_wrap = im.wrap(b) - np.testing.assert_array_almost_equal( - im_wrap.array, - 4.0 * im_quad.array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_almost_equal( - im_wrap.array, - im[b].array, - 12, - "image.wrap(%s) did not return the right subimage" % b, - ) - im[b].fill(0) - np.testing.assert_array_almost_equal( - im_wrap.array, - im[b].array, - 12, - "image.wrap(%s) did not return a view of the original" % b, - ) - - # Now test where the subimage is not a simple fraction of the original, and all the - # sizes are different. - im = galsim.ImageD(17, 23, xmin=0, ymin=0) - b = galsim.BoundsI(7, 9, 11, 18) - im_test = galsim.ImageD(b, init_value=0) - for i in range(17): - for j in range(23): - val = np.exp(i / 7.3) + (j / 12.9) ** 3 # Something randomly complicated... - im[i, j] = val - # Find the location in the sub-image for this point. - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - im_wrap = im.wrap(b) - np.testing.assert_array_almost_equal( - im_wrap.array, im_test.array, 12, "image.wrap(%s) did not match expectation" % b - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - # For complex images (in particular k-space images), we often want the image to be implicitly - # Hermitian, so we only need to keep around half of it. - M = 38 - N = 25 - K = 8 - L = 5 - im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian - im2 = galsim.ImageCD( - 2 * M + 1, N + 1, xmin=-M, ymin=0 - ) # Implicitly Hermitian across y axis - im3 = galsim.ImageCD( - M + 1, 2 * N + 1, xmin=0, ymin=-N - ) # Implicitly Hermitian across x axis - # print('im = ',im) - # print('im2 = ',im2) - # print('im3 = ',im3) - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - im_test = galsim.ImageCD(b, init_value=0) - for i in range(-M, M + 1): - for j in range(-N, N + 1): - # An arbitrary, complicated Hermitian function. - val = ( - np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) - + ((2 + 3j * j) / (1.9 * N)) ** 3 - ) - # val = 2*(i-j)**2 + 3j*(i+j) - - im[i, j] = val - if j >= 0: - im2[i, j] = val - if i >= 0: - im3[i, j] = val - - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - # print("im = ",im.array) - - # Confirm that the image is Hermitian. - for i in range(-M, M + 1): - for j in range(-N, N + 1): - assert im(i, j) == im(-i, -j).conjugate() - - im_wrap = im.wrap(b) - # print("im_wrap = ",im_wrap.array) - np.testing.assert_array_almost_equal( - im_wrap.array, im_test.array, 12, "image.wrap(%s) did not match expectation" % b - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_array_almost_equal( - im2_wrap.array, - im_test[b2].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) - - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_array_almost_equal( - im3_wrap.array, - im_test[b3].array, - 12, - "image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") - - -@timer -def test_FITS_bad_type(): - """Test that reading FITS files with an invalid data type succeeds by converting the - type to float64. - """ - # We check this by monkey patching the Image.valid_types list to not include int16 - # and see if it reads properly and raises the appropriate warning. - orig_dtypes = galsim.Image.valid_dtypes - setattr(galsim.Image, "valid_dtypes", (np.int32, np.float32, np.float64)) - - testS_file = os.path.join(datadir, "testS.fits") - testMultiS_file = os.path.join(datadir, "test_multiS.fits") - testCubeS_file = os.path.join(datadir, "test_cubeS.fits") - with assert_warns(galsim.GalSimWarning): - testS_image = galsim.fits.read(testS_file) - with assert_warns(galsim.GalSimWarning): - testMultiS_image_list = galsim.fits.readMulti(testMultiS_file) - with assert_warns(galsim.GalSimWarning): - testCubeS_image_list = galsim.fits.readCube(testCubeS_file) - - np.testing.assert_equal(np.float64, testS_image.array.dtype.type) - np.testing.assert_array_equal( - ref_array.astype(float), - testS_image.array, - err_msg="ImageS read failed reading when int16 not a valid image type", - ) - for k in range(nimages): - np.testing.assert_equal(np.float64, testMultiS_image_list[k].array.dtype.type) - np.testing.assert_equal(np.float64, testCubeS_image_list[k].array.dtype.type) - np.testing.assert_array_equal( - (ref_array + k).astype(float), - testMultiS_image_list[k].array, - err_msg="ImageS readMulti failed reading when int16 not a valid image type", - ) - np.testing.assert_array_equal( - (ref_array + k).astype(float), - testCubeS_image_list[k].array, - err_msg="ImageS readCube failed reading when int16 not a valid image type", - ) - - # Don't forget to set it back to the original. - setattr(galsim.Image, "valid_dtypes", orig_dtypes) - - -@timer -def test_bin(): - """Test the bin and subsample methods""" - - # Start with a relatively simple case of 2x2 binning with no partial bins to worry about. - obj = galsim.Gaussian(sigma=2, flux=17).shear(g1=0.1, g2=0.3) - im1 = obj.drawImage(nx=10, ny=14, scale=0.6, dtype=float) - - im2 = obj.drawImage(nx=20, ny=28, scale=0.3, dtype=float) - im3 = im2.bin(2, 2) - ar2 = im2.array - ar3b = ar2[0::2, 0::2] + ar2[0::2, 1::2] + ar2[1::2, 0::2] + ar2[1::2, 1::2] - - np.testing.assert_almost_equal( - ar3b.sum(), im2.array.sum(), 6, "direct binning didn't perserve total flux" - ) - np.testing.assert_almost_equal( - ar3b, im3.array, 6, "direct binning didn't match bin function." - ) - np.testing.assert_almost_equal( - im3.array.sum(), im2.array.sum(), 6, "bin didn't preserve the total flux" - ) - np.testing.assert_almost_equal( - im3.array, - im1.array, - 6, - "2x2 binned image doesn't match image with 2x2 larger pixels", - ) - np.testing.assert_almost_equal( - im3.scale, im1.scale, 6, "bin resulted in wrong scale" - ) - - im4 = im2.subsample(2, 2) - np.testing.assert_almost_equal( - im4.array.sum(), im2.array.sum(), 6, "subsample didn't preserve the total flux" - ) - np.testing.assert_almost_equal( - im4.scale, im2.scale / 2.0, 6, "subsample resulted in wrong scale" - ) - im5 = im4.bin(2, 2) - np.testing.assert_almost_equal( - im5.array, - im2.array, - 6, - "Round trip subsample then bin 2x2 doesn't match original", - ) - np.testing.assert_almost_equal( - im5.scale, im2.scale, 6, "round trip resulted in wrong scale" - ) - - # Next do nx != ny. And wcs = JacobianWCS - wcs1 = galsim.JacobianWCS(0.6, 0.14, 0.15, 0.7) - im1 = obj.drawImage(nx=11, ny=15, wcs=wcs1, dtype=float) - im1.wcs = im1.wcs.withOrigin(im1.true_center, galsim.PositionD(200, 300)) - center1 = im1.wcs.toWorld(im1.true_center) - print("center1 = ", center1) - - wcs2 = galsim.JacobianWCS(0.2, 0.07, 0.05, 0.35) - im2 = obj.drawImage(nx=33, ny=30, wcs=wcs2, dtype=float) - im2.wcs = im2.wcs.withOrigin(im2.true_center, galsim.PositionD(200, 300)) - center2 = im2.wcs.toWorld(im2.true_center) - print("center2 = ", center2) - im3 = im2.bin(3, 2) - center3 = im2.wcs.toWorld(im2.true_center) - print("center3 = ", center3) - ar2 = im2.array - ar3b = ( - ar2[0::2, 0::3] - + ar2[0::2, 1::3] - + ar2[0::2, 2::3] - + ar2[1::2, 0::3] - + ar2[1::2, 1::3] - + ar2[1::2, 2::3] - ) - - np.testing.assert_almost_equal( - ar3b.sum(), im2.array.sum(), 6, "direct binning didn't perserve total flux" - ) - np.testing.assert_almost_equal( - ar3b, im3.array, 6, "direct binning didn't match bin function." - ) - np.testing.assert_almost_equal( - im3.array.sum(), im2.array.sum(), 6, "bin didn't preserve the total flux" - ) - np.testing.assert_almost_equal( - im3.array, - im1.array, - 6, - "3x2 binned image doesn't match image with 3x2 larger pixels", - ) - np.testing.assert_almost_equal( - (center3.x, center3.y), - (center1.x, center1.y), - 6, - "3x2 binned image wcs is wrong", - ) - - im4 = im2.subsample(3, 2) - np.testing.assert_almost_equal( - im4.array.sum(), im2.array.sum(), 6, "subsample didn't preserve the total flux" - ) - center4 = im4.wcs.toWorld(im4.true_center) - print("center4 = ", center4) - np.testing.assert_almost_equal( - (center4.x, center4.y), - (center1.x, center1.y), - 6, - "3x2 subsampled image wcs is wrong", - ) - - im5 = im4.bin(3, 2) - np.testing.assert_almost_equal( - im5.array, - im2.array, - 6, - "Round trip subsample then bin 3x2 doesn't match original", - ) - center5 = im5.wcs.toWorld(im5.true_center) - print("center5 = ", center5) - np.testing.assert_almost_equal( - (center5.x, center5.y), - (center1.x, center1.y), - 6, - "Round trip 3x2 image wcs is wrong", - ) - - # If the initial wcs is None or not uniform, then the resulting wcs will be None. - im2.wcs = galsim.UVFunction("0.6 + np.sin(x*y)", "0.6 + np.cos(x+y)") - im3b = im2.bin(3, 2) - assert im3b.wcs is None - np.testing.assert_array_equal( - im3b.array, im3.array, "The wcs shouldn't affect what bin does to the array." - ) - im4b = im2.subsample(3, 2) - assert im4b.wcs is None - np.testing.assert_array_equal( - im4b.array, - im4.array, - "The wcs shouldn't affect what subsample does to the array.", - ) - - im2.wcs = None - im3c = im2.bin(3, 2) - assert im3c.wcs is None - np.testing.assert_array_equal( - im3c.array, im3.array, "The wcs shouldn't affect what bin does to the array." - ) - im4c = im2.subsample(3, 2) - assert im4c.wcs is None - np.testing.assert_array_equal( - im4c.array, - im4.array, - "The wcs shouldn't affect what subsample does to the array.", - ) - - # Finally, binning should still work, even if the number of pixels doesn't go evenly into - # the number of pixels/block specified. - im6 = im1.bin(8, 8) - ar6 = np.array( - [ - [im1.array[0:8, 0:8].sum(), im1.array[0:8, 8:].sum()], - [im1.array[8:, 0:8].sum(), im1.array[8:, 8:].sum()], - ] - ) - np.testing.assert_almost_equal( - im6.array, ar6, 6, "Binning past the edge of the image didn't work properly" - ) - # The center of this image doesn't correspond to the center of the original. - # But the lower left edge does. - origin1 = im1.wcs.toWorld(galsim.PositionD(im1.xmin - 0.5, im1.ymin - 0.5)) - origin6 = im6.wcs.toWorld(galsim.PositionD(im1.xmin - 0.5, im6.ymin - 0.5)) - print("origin1 = ", origin1) - print("origin6 = ", origin6) - np.testing.assert_almost_equal( - (origin6.x, origin6.y), - (origin1.x, origin1.y), - 6, - "Binning past the edge resulted in wrong wcs", - ) - - -@timer -def test_fpack(): - """Test the functionality that we advertise as being equivalent to fpack/funpack""" - from astropy.io import fits - - file_name0 = os.path.join( - "tests/GalSim/tests/des_data", "DECam_00158414_01.fits.fz" - ) - hdulist = fits.open(file_name0) - - # Remove a few invalid header keys in the DECam fits file - # The I/O works if we don't, but they complicate the later tests. - # Easier to just get rid of them. - hdulist.verify("silentfix") - for k in list(hdulist[1].header.keys()): - if k.startswith("G-") or k.startswith("time_recorded"): - print("remove header key ", k, hdulist[1].header[k]) - del hdulist[1].header[k] - - file_name1 = os.path.join("tests/output", "DECam_00158414_01_fix.fits.fz") - hdulist.writeto(file_name1, overwrite=True) - - file_name2 = os.path.join("tests/output", "DECam_00158414_01.fits") - file_name3 = os.path.join("tests/output", "DECam_00158414_01.fits.fz") - - # This line basically does funpack: - galsim.fits.writeMulti( - galsim.fits.readMulti(file_name1, read_headers=True), file_name2 - ) - - # This line basically does fpack: - galsim.fits.writeMulti( - galsim.fits.readMulti(file_name2, read_headers=True), file_name3 - ) - - # Check that the final file is essentially equivalent to the original. - imlist1 = galsim.fits.readMulti(file_name1, read_headers=True) - imlist3 = galsim.fits.readMulti(file_name3, read_headers=True) - - assert len(imlist1) == len(imlist3) - for im1, im3 in zip(imlist1, imlist3): - for key in im1.header.keys(): - if key in im3.header and key not in ["XTENSION", "COMMENT", "HISTORY", ""]: - if isinstance(im1.header[key], str): - assert im1.header[key] == im3.header[key] - else: - np.testing.assert_allclose(im1.header[key], im3.header[key]) - assert im1.bounds == im3.bounds - if isinstance(im1.wcs, galsim.GSFitsWCS): - assert im1.wcs.wcs_type == im3.wcs.wcs_type - np.testing.assert_array_equal(im1.wcs.crpix, im3.wcs.crpix) - np.testing.assert_array_equal(im1.wcs.cd, im3.wcs.cd) - np.testing.assert_array_equal(im1.wcs.center, im3.wcs.center) - # pv isn't identical through round trip, but vv close - np.testing.assert_allclose(im1.wcs.pv, im3.wcs.pv, rtol=1.0e-15) - else: - assert im1.wcs == im3.wcs - # array is not identical. Rice compression is lossy, so most pixel values change somewhat. - # Noise is ~100 ADU, so differences of < 1 ADU are small here. - np.testing.assert_allclose(im1.array, im3.array, atol=1) - - -if __name__ == "__main__": - testfns = [v for k, v in vars().items() if k[:5] == "test_" and callable(v)] - for testfn in testfns: - testfn() diff --git a/tests/jax/galsim/test_noise_jax.py b/tests/jax/galsim/test_noise_jax.py deleted file mode 100644 index c6eb3c97..00000000 --- a/tests/jax/galsim/test_noise_jax.py +++ /dev/null @@ -1,843 +0,0 @@ - -import numpy as np -import jax_galsim as galsim -from galsim_test_helpers import timer, assert_raises, check_pickle, drawNoise -import jax.numpy as jnp - -testseed = 1000 - -precision = 10 -# decimal point at which agreement is required for all double precision tests - -precisionD = precision -precisionF = 5 # precision=10 does not make sense at single precision -precisionS = 1 # "precision" also a silly concept for ints, but allows all 4 tests to run in one go -precisionI = 1 - - -@timer -def test_deviate_noise(): - """Test basic functionality of the DeviateNoise class - """ - u = galsim.UniformDeviate(testseed) - uResult = jnp.empty((10, 10)) - uResult = u.generate(uResult) - - noise = galsim.DeviateNoise(galsim.UniformDeviate(testseed)) - - # Test filling an image with random values - testimage = galsim.ImageD(10, 10) - testimage.addNoise(noise) - np.testing.assert_array_almost_equal( - testimage.array, uResult, precision, - err_msg='Wrong uniform random number sequence generated when applied to image.') - - # Test filling a single-precision image - noise.rng.seed(testseed) - testimage = galsim.ImageF(10, 10) - testimage.addNoise(noise) - np.testing.assert_array_almost_equal( - testimage.array, uResult, precisionF, - err_msg='Wrong uniform random number sequence generated when applied to ImageF.') - - # Test filling an image with Fortran ordering - noise.rng.seed(testseed) - testimage = galsim.ImageD(np.zeros((10, 10)).T) - testimage.addNoise(noise) - np.testing.assert_array_almost_equal( - testimage.array, uResult, precision, - err_msg="Wrong uniform randoms generated for Fortran-ordered Image") - - # Check picklability - check_pickle(noise, drawNoise) - check_pickle(noise) - - # Check copy, eq and ne - noise2 = galsim.DeviateNoise(noise.rng.duplicate()) # Separate but equivalent rng chain. - noise3 = noise.copy() # Always has exactly the same rng as noise. - noise4 = noise.copy(rng=galsim.BaseDeviate(11)) # Always has a different rng than noise - assert noise == noise2 - assert noise == noise3 - assert noise != noise4 - assert noise.rng() == noise2.rng() - assert noise == noise2 # Still equal because both chains incremented one place. - assert noise == noise3 - noise.rng() - assert noise2 != noise3 # This is no longer equal, since only noise.rng is incremented. - assert noise == noise3 - - assert_raises(TypeError, galsim.DeviateNoise, 53) - assert_raises(NotImplementedError, galsim.BaseNoise().getVariance) - assert_raises(NotImplementedError, galsim.BaseNoise().withVariance, 23) - assert_raises(NotImplementedError, galsim.BaseNoise().withScaledVariance, 23) - assert_raises(TypeError, noise.applyTo, 23) - assert_raises(NotImplementedError, galsim.BaseNoise().applyTo, testimage) - assert_raises(galsim.GalSimError, noise.getVariance) - assert_raises(galsim.GalSimError, noise.withVariance, 23) - assert_raises(galsim.GalSimError, noise.withScaledVariance, 23) - - -@timer -def test_gaussian_noise(): - """Test Gaussian random number generator - """ - gSigma = 17.23 - g = galsim.GaussianDeviate(testseed, sigma=gSigma) - gResult = np.empty((10, 10)) - gResult = g.generate(gResult) - noise = galsim.DeviateNoise(g) - - # Test filling an image - testimage = galsim.ImageD(10, 10) - noise.rng.seed(testseed) - testimage.addNoise(noise) - np.testing.assert_array_almost_equal( - testimage.array, gResult, precision, - err_msg='Wrong Gaussian random number sequence generated when applied to image.') - - # Test filling a single-precision image - noise.rng.seed(testseed) - testimage = galsim.ImageF(10, 10) - testimage.addNoise(noise) - np.testing.assert_array_almost_equal( - testimage.array, gResult, precisionF, - err_msg='Wrong Gaussian random number sequence generated when applied to ImageF.') - - # GaussianNoise is equivalent, but no mean allowed. - gn = galsim.GaussianNoise(galsim.BaseDeviate(testseed), sigma=gSigma) - testimage = galsim.ImageD(10, 10) - testimage.addNoise(gn) - np.testing.assert_array_almost_equal( - testimage.array, gResult, precision, - err_msg="GaussianNoise applied to Images does not reproduce expected sequence") - - # Test filling an image with Fortran ordering - gn.rng.seed(testseed) - testimage = galsim.ImageD(np.zeros((10, 10)).T) - testimage.addNoise(gn) - np.testing.assert_array_almost_equal( - testimage.array, gResult, precision, - err_msg="Wrong Gaussian noise generated for Fortran-ordered Image") - - # Check GaussianNoise variance: - np.testing.assert_almost_equal( - gn.getVariance(), gSigma**2, precision, - err_msg="GaussianNoise getVariance returns wrong variance") - np.testing.assert_almost_equal( - gn.sigma, gSigma, precision, - err_msg="GaussianNoise sigma returns wrong value") - - # Check that the noise model really does produce this variance. - big_im = galsim.Image(2048, 2048, dtype=float) - gn.rng.seed(testseed) - big_im.addNoise(gn) - var = np.var(big_im.array) - print('variance = ', var) - print('getVar = ', gn.getVariance()) - np.testing.assert_almost_equal( - var, gn.getVariance(), 1, - err_msg='Realized variance for GaussianNoise did not match getVariance()') - - # Check that GaussianNoise adds to the image, not overwrites the image. - gal = galsim.Exponential(half_light_radius=2.3, flux=1.e4) - gal.drawImage(image=big_im) - gn.rng.seed(testseed) - big_im.addNoise(gn) - gal.withFlux(-1.e4).drawImage(image=big_im, add_to_image=True) - var = np.var(big_im.array) - np.testing.assert_almost_equal( - var, gn.getVariance(), 1, - err_msg='GaussianNoise wrong when already an object drawn on the image') - - # Check that DeviateNoise adds to the image, not overwrites the image. - gal.drawImage(image=big_im) - gn.rng.seed(testseed) - big_im.addNoise(gn) - gal.withFlux(-1.e4).drawImage(image=big_im, add_to_image=True) - var = np.var(big_im.array) - np.testing.assert_almost_equal( - var, gn.getVariance(), 1, - err_msg='DeviateNoise wrong when already an object drawn on the image') - - # Check withVariance - gn = gn.withVariance(9.) - np.testing.assert_almost_equal( - gn.getVariance(), 9, precision, - err_msg="GaussianNoise withVariance results in wrong variance") - np.testing.assert_almost_equal( - gn.sigma, 3., precision, - err_msg="GaussianNoise withVariance results in wrong sigma") - - # Check withScaledVariance - gn = gn.withScaledVariance(4.) - np.testing.assert_almost_equal( - gn.getVariance(), 36., precision, - err_msg="GaussianNoise withScaledVariance results in wrong variance") - np.testing.assert_almost_equal( - gn.sigma, 6., precision, - err_msg="GaussianNoise withScaledVariance results in wrong sigma") - - # Check arithmetic - gn = gn.withVariance(0.5) - gn2 = gn * 3 - np.testing.assert_almost_equal( - gn2.getVariance(), 1.5, precision, - err_msg="GaussianNoise gn*3 results in wrong variance") - np.testing.assert_almost_equal( - gn.getVariance(), 0.5, precision, - err_msg="GaussianNoise gn*3 results in wrong variance for original gn") - gn2 = 5 * gn - np.testing.assert_almost_equal( - gn2.getVariance(), 2.5, precision, - err_msg="GaussianNoise 5*gn results in wrong variance") - np.testing.assert_almost_equal( - gn.getVariance(), 0.5, precision, - err_msg="GaussianNoise 5*gn results in wrong variance for original gn") - gn2 = gn / 2 - np.testing.assert_almost_equal( - gn2.getVariance(), 0.25, precision, - err_msg="GaussianNoise gn/2 results in wrong variance") - np.testing.assert_almost_equal( - gn.getVariance(), 0.5, precision, - err_msg="GaussianNoise 5*gn results in wrong variance for original gn") - gn *= 3 - np.testing.assert_almost_equal( - gn.getVariance(), 1.5, precision, - err_msg="GaussianNoise gn*=3 results in wrong variance") - gn /= 2 - np.testing.assert_almost_equal( - gn.getVariance(), 0.75, precision, - err_msg="GaussianNoise gn/=2 results in wrong variance") - - # Check starting with GaussianNoise() - gn2 = galsim.GaussianNoise() - gn2 = gn2.withVariance(9.) - np.testing.assert_almost_equal( - gn2.getVariance(), 9, precision, - err_msg="GaussianNoise().withVariance results in wrong variance") - np.testing.assert_almost_equal( - gn2.sigma, 3., precision, - err_msg="GaussianNoise().withVariance results in wrong sigma") - - gn2 = galsim.GaussianNoise() - gn2 = gn2.withScaledVariance(4.) - np.testing.assert_almost_equal( - gn2.getVariance(), 4., precision, - err_msg="GaussianNoise().withScaledVariance results in wrong variance") - np.testing.assert_almost_equal( - gn2.sigma, 2., precision, - err_msg="GaussianNoise().withScaledVariance results in wrong sigma") - - # Check picklability - check_pickle(gn, lambda x: (x.rng.serialize(), x.sigma)) - check_pickle(gn, drawNoise) - check_pickle(gn) - - # Check copy, eq and ne - gn = gn.withVariance(gSigma**2) - gn2 = galsim.GaussianNoise(gn.rng.duplicate(), gSigma) - gn3 = gn.copy() - gn4 = gn.copy(rng=galsim.BaseDeviate(11)) - gn5 = galsim.GaussianNoise(gn.rng, 2. * gSigma) - assert gn == gn2 - assert gn == gn3 - assert gn != gn4 - assert gn != gn5 - assert gn.rng.raw() == gn2.rng.raw() - assert gn == gn2 - assert gn == gn3 - gn.rng.raw() - assert gn != gn2 - assert gn == gn3 - - -@timer -def test_variable_gaussian_noise(): - """Test VariableGaussian random number generator - """ - # Make a checkerboard image with two values for the variance - gSigma1 = 17.23 - gSigma2 = 28.55 - var_image = galsim.ImageD(galsim.BoundsI(0, 9, 0, 9)) - coords = np.ogrid[0:10, 0:10] - var_image._array = var_image.array.at[(coords[0] + coords[1]) % 2 == 1].set(gSigma1**2) - var_image._array = var_image.array.at[(coords[0] + coords[1]) % 2 == 0].set(gSigma2**2) - print('var_image.array = ', var_image.array) - - g = galsim.GaussianDeviate(testseed, sigma=1.) - vgResult = np.empty((10, 10)) - vgResult = g.generate(vgResult) - vgResult *= np.sqrt(var_image.array) - - # Test filling an image - vgn = galsim.VariableGaussianNoise(galsim.BaseDeviate(testseed), var_image) - testimage = galsim.ImageD(10, 10) - testimage.addNoise(vgn) - np.testing.assert_array_almost_equal( - testimage.array, vgResult, precision, - err_msg="VariableGaussianNoise applied to Images does not reproduce expected sequence") - - # Test filling an image with Fortran ordering - vgn.rng.seed(testseed) - testimage = galsim.ImageD(np.zeros((10, 10)).T) - testimage.addNoise(vgn) - np.testing.assert_array_almost_equal( - testimage.array, vgResult, precision, - err_msg="Wrong VariableGaussian noise generated for Fortran-ordered Image") - - # Check var_image property - np.testing.assert_array_almost_equal( - vgn.var_image.array, var_image.array, precision, - err_msg="VariableGaussianNoise var_image returns wrong var_image") - - # Check that the noise model really does produce this variance. - big_var_image = galsim.ImageD(galsim.BoundsI(0, 2047, 0, 2047)) - big_coords = np.ogrid[0:2048, 0:2048] - mask1 = (big_coords[0] + big_coords[1]) % 2 == 0 - mask2 = (big_coords[0] + big_coords[1]) % 2 == 1 - big_var_image._array = big_var_image.array.at[mask1].set(gSigma1**2) - big_var_image._array = big_var_image.array.at[mask2].set(gSigma2**2) - big_vgn = galsim.VariableGaussianNoise(galsim.BaseDeviate(testseed), big_var_image) - - big_im = galsim.Image(2048, 2048, dtype=float) - big_im.addNoise(big_vgn) - var = np.var(big_im.array) - print('variance = ', var) - print('getVar = ', big_vgn.var_image.array.mean()) - # NOTE had to turn down precision to 0 due to different RNG - np.testing.assert_almost_equal( - var, big_vgn.var_image.array.mean(), 0, - err_msg='Realized variance for VariableGaussianNoise did not match var_image') - - # Check realized variance in each mask - print('rms1 = ', np.std(big_im.array[mask1])) - print('rms2 = ', np.std(big_im.array[mask2])) - np.testing.assert_almost_equal(np.std(big_im.array[mask1]), gSigma1, decimal=1) - np.testing.assert_almost_equal(np.std(big_im.array[mask2]), gSigma2, decimal=1) - - # Check that VariableGaussianNoise adds to the image, not overwrites the image. - gal = galsim.Exponential(half_light_radius=2.3, flux=1.e4) - gal.drawImage(image=big_im) - big_vgn.rng.seed(testseed) - big_im.addNoise(big_vgn) - gal.withFlux(-1.e4).drawImage(image=big_im, add_to_image=True) - var = np.var(big_im.array) - # NOTE had to turn down precision to 0 due to different RNG - np.testing.assert_almost_equal( - var, big_vgn.var_image.array.mean(), 0, - err_msg='VariableGaussianNoise wrong when already an object drawn on the image') - - # Check picklability - check_pickle(vgn, lambda x: (x.rng.serialize(), x.var_image)) - check_pickle(vgn, drawNoise) - check_pickle(vgn) - - # Check copy, eq and ne - vgn2 = galsim.VariableGaussianNoise(vgn.rng.duplicate(), var_image) - vgn3 = vgn.copy() - vgn4 = vgn.copy(rng=galsim.BaseDeviate(11)) - vgn5 = galsim.VariableGaussianNoise(vgn.rng, 2. * var_image) - assert vgn == vgn2 - assert vgn == vgn3 - assert vgn != vgn4 - assert vgn != vgn5 - assert vgn.rng.raw() == vgn2.rng.raw() - assert vgn == vgn2 - assert vgn == vgn3 - vgn.rng.raw() - assert vgn != vgn2 - assert vgn == vgn3 - - assert_raises(TypeError, vgn.applyTo, 23) - assert_raises(ValueError, vgn.applyTo, galsim.ImageF(3, 3)) - assert_raises(galsim.GalSimError, vgn.getVariance) - assert_raises(galsim.GalSimError, vgn.withVariance, 23) - assert_raises(galsim.GalSimError, vgn.withScaledVariance, 23) - - -@timer -def test_poisson_noise(): - """Test Poisson random number generator - """ - pMean = 17 - p = galsim.PoissonDeviate(testseed, mean=pMean) - pResult = np.empty((10, 10)) - pResult = p.generate(pResult) - noise = galsim.DeviateNoise(p) - - # Test filling an image - noise.rng.seed(testseed) - testimage = galsim.ImageI(10, 10) - # NOTE - this line changed since it appeared to be buggy in galsim - testimage.addNoise(noise) - np.testing.assert_array_equal( - testimage.array, pResult, - err_msg='Wrong poisson random number sequence generated when applied to image.') - - # The PoissonNoise version also subtracts off the mean value - pn = galsim.PoissonNoise(galsim.BaseDeviate(testseed), sky_level=pMean) - testimage.fill(0) - testimage.addNoise(pn) - np.testing.assert_array_equal( - testimage.array, pResult - pMean, - err_msg='Wrong poisson random number sequence generated using PoissonNoise') - - # Test filling a single-precision image - pn.rng.seed(testseed) - testimage = galsim.ImageF(10, 10) - testimage.addNoise(pn) - np.testing.assert_array_almost_equal( - testimage.array, pResult - pMean, precisionF, - err_msg='Wrong Poisson random number sequence generated when applied to ImageF.') - - # Test filling an image with Fortran ordering - pn.rng.seed(testseed) - testimage = galsim.ImageD(10, 10) - testimage.addNoise(pn) - np.testing.assert_array_almost_equal( - testimage.array, pResult - pMean, - err_msg="Wrong Poisson noise generated for Fortran-ordered Image") - - # Check PoissonNoise variance: - np.testing.assert_almost_equal( - pn.getVariance(), pMean, precision, - err_msg="PoissonNoise getVariance returns wrong variance") - np.testing.assert_almost_equal( - pn.sky_level, pMean, precision, - err_msg="PoissonNoise sky_level returns wrong value") - - # Check that the noise model really does produce this variance. - big_im = galsim.Image(2048, 2048, dtype=float) - big_im.addNoise(pn) - var = np.var(big_im.array) - print('variance = ', var) - print('getVar = ', pn.getVariance()) - np.testing.assert_almost_equal( - var, pn.getVariance(), 1, - err_msg='Realized variance for PoissonNoise did not match getVariance()') - - # Check that PoissonNoise adds to the image, not overwrites the image. - gal = galsim.Exponential(half_light_radius=2.3, flux=0.3) - # Note: in this case, flux/size^2 needs to be << sky_level or it will mess up the statistics. - gal.drawImage(image=big_im) - big_im.addNoise(pn) - gal.withFlux(-0.3).drawImage(image=big_im, add_to_image=True) - var = np.var(big_im.array) - np.testing.assert_almost_equal( - var, pn.getVariance(), 1, - err_msg='PoissonNoise wrong when already an object drawn on the image') - - # Check withVariance - pn = pn.withVariance(9.) - np.testing.assert_almost_equal( - pn.getVariance(), 9., precision, - err_msg="PoissonNoise withVariance results in wrong variance") - np.testing.assert_almost_equal( - pn.sky_level, 9., precision, - err_msg="PoissonNoise withVariance results in wrong sky_level") - - # Check withScaledVariance - pn = pn.withScaledVariance(4.) - np.testing.assert_almost_equal( - pn.getVariance(), 36, precision, - err_msg="PoissonNoise withScaledVariance results in wrong variance") - np.testing.assert_almost_equal( - pn.sky_level, 36., precision, - err_msg="PoissonNoise withScaledVariance results in wrong sky_level") - - # Check arithmetic - pn = pn.withVariance(0.5) - pn2 = pn * 3 - np.testing.assert_almost_equal( - pn2.getVariance(), 1.5, precision, - err_msg="PoissonNoise pn*3 results in wrong variance") - np.testing.assert_almost_equal( - pn.getVariance(), 0.5, precision, - err_msg="PoissonNoise pn*3 results in wrong variance for original pn") - pn2 = 5 * pn - np.testing.assert_almost_equal( - pn2.getVariance(), 2.5, precision, - err_msg="PoissonNoise 5*pn results in wrong variance") - np.testing.assert_almost_equal( - pn.getVariance(), 0.5, precision, - err_msg="PoissonNoise 5*pn results in wrong variance for original pn") - pn2 = pn / 2 - np.testing.assert_almost_equal( - pn2.getVariance(), 0.25, precision, - err_msg="PoissonNoise pn/2 results in wrong variance") - np.testing.assert_almost_equal( - pn.getVariance(), 0.5, precision, - err_msg="PoissonNoise 5*pn results in wrong variance for original pn") - pn *= 3 - np.testing.assert_almost_equal( - pn.getVariance(), 1.5, precision, - err_msg="PoissonNoise pn*=3 results in wrong variance") - pn /= 2 - np.testing.assert_almost_equal( - pn.getVariance(), 0.75, precision, - err_msg="PoissonNoise pn/=2 results in wrong variance") - - # Check starting with PoissonNoise() - pn = galsim.PoissonNoise() - pn = pn.withVariance(9.) - np.testing.assert_almost_equal( - pn.getVariance(), 9., precision, - err_msg="PoissonNoise().withVariance results in wrong variance") - np.testing.assert_almost_equal( - pn.sky_level, 9., precision, - err_msg="PoissonNoise().withVariance results in wrong sky_level") - pn = pn.withScaledVariance(4.) - np.testing.assert_almost_equal( - pn.getVariance(), 36, precision, - err_msg="PoissonNoise().withScaledVariance results in wrong variance") - np.testing.assert_almost_equal( - pn.sky_level, 36., precision, - err_msg="PoissonNoise().withScaledVariance results in wrong sky_level") - - # Check picklability - check_pickle(pn, lambda x: (x.rng.serialize(), x.sky_level)) - check_pickle(pn, drawNoise) - check_pickle(pn) - - # Check copy, eq and ne - pn = pn.withVariance(pMean) - pn2 = galsim.PoissonNoise(pn.rng.duplicate(), pMean) - pn3 = pn.copy() - pn4 = pn.copy(rng=galsim.BaseDeviate(11)) - pn5 = galsim.PoissonNoise(pn.rng, 2 * pMean) - assert pn == pn2 - assert pn == pn3 - assert pn != pn4 - assert pn != pn5 - assert pn.rng.raw() == pn2.rng.raw() - assert pn == pn2 - assert pn == pn3 - pn.rng.raw() - assert pn != pn2 - assert pn == pn3 - - -@timer -def test_ccdnoise(): - """Test CCD Noise generator - """ - # Start with some regression tests where we have known values that we expect to generate: - - types = (jnp.int16, jnp.int32, jnp.float32, jnp.float64) - typestrings = ("S", "I", "F", "D") - - testseed = 1000 - gain = 3. - read_noise = 5. - sky = 50 - - # Tabulated results for the above settings and testseed value. - cResultS = np.array([[42, 52], [49, 45]], dtype=np.int16) # noqa: F841 - cResultI = np.array([[42, 52], [49, 45]], dtype=np.int32) # noqa: F841 - cResultF = np.array([ # noqa: F841 - [42.4286994934082, 52.42875671386719], - [49.016048431396484, 45.61003875732422] - ], dtype=np.float32) - cResultD = np.array([ # noqa: F841 - [42.42870031326479, 52.42875718917211], - [49.016050296441094, 45.61003745208172] - ], dtype=np.float64) - - for i in range(4): - prec = eval("precision" + typestrings[i]) - cResult = eval("cResult" + typestrings[i]) - - rng = galsim.BaseDeviate(testseed) - ccdnoise = galsim.CCDNoise(rng, gain=gain, read_noise=read_noise) - testImage = galsim.Image((np.zeros((2, 2)) + sky).astype(types[i])) - ccdnoise.applyTo(testImage) - np.testing.assert_array_almost_equal( - testImage.array, cResult, prec, - err_msg="Wrong CCD noise random sequence generated for Image" + typestrings[i] + ".") - - # Check that reseeding the rng reseeds the internal deviate in CCDNoise - rng.seed(testseed) - testImage.fill(sky) - ccdnoise.applyTo(testImage) - np.testing.assert_array_almost_equal( - testImage.array, cResult, prec, - err_msg=( - "Wrong CCD noise random sequence generated for Image" + typestrings[i] - + " after seed" - ), - ) - - # Check using addNoise - rng.seed(testseed) - testImage.fill(sky) - testImage.addNoise(ccdnoise) - np.testing.assert_array_almost_equal( - testImage.array, cResult, prec, - err_msg=( - "Wrong CCD noise random sequence generated for Image" + typestrings[i] - + " using addNoise" - ), - ) - - # Test filling an image with Fortran ordering - rng.seed(testseed) - testImageF = galsim.Image(np.zeros((2, 2)).T, dtype=types[i]) - testImageF.fill(sky) - testImageF.addNoise(ccdnoise) - np.testing.assert_array_almost_equal( - testImageF.array, cResult, prec, - err_msg="Wrong CCD noise generated for Fortran-ordered Image" + typestrings[i]) - - # Now include sky_level in ccdnoise - rng.seed(testseed) - ccdnoise = galsim.CCDNoise(rng, sky_level=sky, gain=gain, read_noise=read_noise) - testImage.fill(0) - ccdnoise.applyTo(testImage) - np.testing.assert_array_almost_equal( - testImage.array, cResult - sky, prec, - err_msg=( - "Wrong CCD noise random sequence generated for Image" + typestrings[i] - + " with sky_level included in noise" - ), - ) - - rng.seed(testseed) - testImage.fill(0) - testImage.addNoise(ccdnoise) - np.testing.assert_array_almost_equal( - testImage.array, cResult - sky, prec, - err_msg=( - "Wrong CCD noise random sequence generated for Image" + typestrings[i] - + " using addNoise with sky_level included in noise" - ), - ) - - # Check CCDNoise variance: - var1 = sky / gain + (read_noise / gain)**2 - np.testing.assert_almost_equal( - ccdnoise.getVariance(), var1, precision, - err_msg="CCDNoise getVariance returns wrong variance") - np.testing.assert_almost_equal( - ccdnoise.sky_level, sky, precision, - err_msg="CCDNoise sky_level returns wrong value") - np.testing.assert_almost_equal( - ccdnoise.gain, gain, precision, - err_msg="CCDNoise gain returns wrong value") - np.testing.assert_almost_equal( - ccdnoise.read_noise, read_noise, precision, - err_msg="CCDNoise read_noise returns wrong value") - - # Check that the noise model really does produce this variance. - # NB. If default float32 is used here, older versions of numpy will compute the variance - # in single precision, and with 2048^2 values, the final answer comes out significantly - # wrong (19.33 instead of 19.42, which gets compared to the nominal value of 19.44). - big_im = galsim.Image(2048, 2048, dtype=float) - big_im.addNoise(ccdnoise) - var = np.var(big_im.array) - print('variance = ', var) - print('getVar = ', ccdnoise.getVariance()) - np.testing.assert_almost_equal( - var, ccdnoise.getVariance(), 1, - err_msg='Realized variance for CCDNoise did not match getVariance()') - - # Check that CCDNoise adds to the image, not overwrites the image. - gal = galsim.Exponential(half_light_radius=2.3, flux=0.3) - # Note: again, flux/size^2 needs to be << sky_level or it will mess up the statistics. - gal.drawImage(image=big_im) - big_im.addNoise(ccdnoise) - gal.withFlux(-0.3).drawImage(image=big_im, add_to_image=True) - var = np.var(big_im.array) - np.testing.assert_almost_equal( - var, ccdnoise.getVariance(), 1, - err_msg='CCDNoise wrong when already an object drawn on the image') - - # Check using a non-integer sky level which does some slightly different calculations. - rng.seed(testseed) - big_im_int = galsim.Image(2048, 2048, dtype=int) - ccdnoise = galsim.CCDNoise(rng, sky_level=34.42, gain=1.6, read_noise=11.2) - big_im_int.fill(0) - big_im_int.addNoise(ccdnoise) - var = np.var(big_im_int.array) - np.testing.assert_almost_equal(var / ccdnoise.getVariance(), 1., decimal=2, - err_msg='CCDNoise wrong when sky_level is not an integer') - - # Using gain=0 means the read_noise is in ADU, not e- - rng.seed(testseed) - ccdnoise = galsim.CCDNoise(rng, gain=0., read_noise=read_noise) - var2 = read_noise**2 - np.testing.assert_almost_equal( - ccdnoise.getVariance(), var2, precision, - err_msg="CCDNoise getVariance returns wrong variance with gain=0") - np.testing.assert_almost_equal( - ccdnoise.sky_level, 0., precision, - err_msg="CCDNoise sky_level returns wrong value with gain=0") - np.testing.assert_almost_equal( - ccdnoise.gain, 0., precision, - err_msg="CCDNoise gain returns wrong value with gain=0") - np.testing.assert_almost_equal( - ccdnoise.read_noise, read_noise, precision, - err_msg="CCDNoise read_noise returns wrong value with gain=0") - big_im.fill(0) - big_im.addNoise(ccdnoise) - var = np.var(big_im.array) - np.testing.assert_almost_equal(var, ccdnoise.getVariance(), 1, - err_msg='CCDNoise wrong when gain=0') - - # Check withVariance - ccdnoise = galsim.CCDNoise(rng, sky_level=sky, gain=gain, read_noise=read_noise) - ccdnoise = ccdnoise.withVariance(9.) - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 9., precision, - err_msg="CCDNoise withVariance results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.sky_level, (9. / var1) * sky, precision, - err_msg="CCDNoise withVariance results in wrong sky_level") - np.testing.assert_almost_equal( - ccdnoise.gain, gain, precision, - err_msg="CCDNoise withVariance results in wrong gain") - np.testing.assert_almost_equal( - ccdnoise.read_noise, np.sqrt(9. / var1) * read_noise, precision, - err_msg="CCDNoise withVariance results in wrong ReadNoise") - - # Check withScaledVariance - ccdnoise = ccdnoise.withScaledVariance(4.) - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 36., precision, - err_msg="CCDNoise withVariance results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.sky_level, (36. / var1) * sky, precision, - err_msg="CCDNoise withVariance results in wrong sky_level") - np.testing.assert_almost_equal( - ccdnoise.gain, gain, precision, - err_msg="CCDNoise withVariance results in wrong gain") - np.testing.assert_almost_equal( - ccdnoise.read_noise, np.sqrt(36. / var1) * read_noise, precision, - err_msg="CCDNoise withVariance results in wrong ReadNoise") - - # Check arithmetic - ccdnoise = ccdnoise.withVariance(0.5) - ccdnoise2 = ccdnoise * 3 - np.testing.assert_almost_equal( - ccdnoise2.getVariance(), 1.5, precision, - err_msg="CCDNoise ccdnoise*3 results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 0.5, precision, - err_msg="CCDNoise ccdnoise*3 results in wrong variance for original ccdnoise") - ccdnoise2 = 5 * ccdnoise - np.testing.assert_almost_equal( - ccdnoise2.getVariance(), 2.5, precision, - err_msg="CCDNoise 5*ccdnoise results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 0.5, precision, - err_msg="CCDNoise 5*ccdnoise results in wrong variance for original ccdnoise") - ccdnoise2 = ccdnoise / 2 - np.testing.assert_almost_equal( - ccdnoise2.getVariance(), 0.25, precision, - err_msg="CCDNoise ccdnoise/2 results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 0.5, precision, - err_msg="CCDNoise 5*ccdnoise results in wrong variance for original ccdnoise") - ccdnoise *= 3 - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 1.5, precision, - err_msg="CCDNoise ccdnoise*=3 results in wrong variance") - ccdnoise /= 2 - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 0.75, precision, - err_msg="CCDNoise ccdnoise/=2 results in wrong variance") - - # Check starting with CCDNoise() - ccdnoise = galsim.CCDNoise() - ccdnoise = ccdnoise.withVariance(9.) - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 9., precision, - err_msg="CCDNoise().withVariance results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.sky_level, 9., precision, - err_msg="CCDNoise().withVariance results in wrong sky_level") - np.testing.assert_almost_equal( - ccdnoise.gain, 1., precision, - err_msg="CCDNoise().withVariance results in wrong gain") - np.testing.assert_almost_equal( - ccdnoise.read_noise, 0., precision, - err_msg="CCDNoise().withVariance results in wrong ReadNoise") - ccdnoise = ccdnoise.withScaledVariance(4.) - np.testing.assert_almost_equal( - ccdnoise.getVariance(), 36., precision, - err_msg="CCDNoise().withScaledVariance results in wrong variance") - np.testing.assert_almost_equal( - ccdnoise.sky_level, 36., precision, - err_msg="CCDNoise().withScaledVariance results in wrong sky_level") - np.testing.assert_almost_equal( - ccdnoise.gain, 1., precision, - err_msg="CCDNoise().withScaledVariance results in wrong gain") - np.testing.assert_almost_equal( - ccdnoise.read_noise, 0., precision, - err_msg="CCDNoise().withScaledVariance results in wrong ReadNoise") - - # Check picklability - check_pickle(ccdnoise, lambda x: (x.rng.serialize(), x.sky_level, x.gain, x.read_noise)) - check_pickle(ccdnoise, drawNoise) - check_pickle(ccdnoise) - - # Check copy, eq and ne - ccdnoise = galsim.CCDNoise(rng, sky, gain, read_noise) - ccdnoise2 = galsim.CCDNoise(ccdnoise.rng.duplicate(), gain=gain, read_noise=read_noise, - sky_level=sky) - ccdnoise3 = ccdnoise.copy() - ccdnoise4 = ccdnoise.copy(rng=galsim.BaseDeviate(11)) - ccdnoise5 = galsim.CCDNoise(ccdnoise.rng, gain=2 * gain, read_noise=read_noise, sky_level=sky) - ccdnoise6 = galsim.CCDNoise(ccdnoise.rng, gain=gain, read_noise=2 * read_noise, sky_level=sky) - ccdnoise7 = galsim.CCDNoise(ccdnoise.rng, gain=gain, read_noise=read_noise, sky_level=2 * sky) - assert ccdnoise == ccdnoise2 - assert ccdnoise == ccdnoise3 - assert ccdnoise != ccdnoise4 - assert ccdnoise != ccdnoise5 - assert ccdnoise != ccdnoise6 - assert ccdnoise != ccdnoise7 - assert ccdnoise.rng.raw() == ccdnoise2.rng.raw() - assert ccdnoise == ccdnoise2 - assert ccdnoise == ccdnoise3 - ccdnoise.rng.raw() - assert ccdnoise != ccdnoise2 - assert ccdnoise == ccdnoise3 - - -@timer -def test_addnoisesnr(): - """Test that addNoiseSNR is behaving sensibly. - """ - # Rather than reproducing the S/N calculation in addNoiseSNR(), we'll just check for - # self-consistency of the behavior with / without flux preservation. - # Begin by making some object that we draw into an Image. - gal_sigma = 3.7 - pix_scale = 0.6 - test_snr = 73. - gauss = galsim.Gaussian(sigma=gal_sigma) - im = gauss.drawImage(scale=pix_scale, dtype=np.float64) - - # Now make the noise object to use. - # Use a default-constructed rng (i.e. rng=None) since we had initially had trouble - # with that. And use the duplicate feature to get a second copy of this rng. - gn = galsim.GaussianNoise() - rng2 = gn.rng.duplicate() - - # Try addNoiseSNR with preserve_flux=True, so the RNG needs a different variance. - # Check what variance was added for this SNR, and that the RNG still has its original variance - # after this call. - var_out = im.addNoiseSNR(gn, test_snr, preserve_flux=True) - assert gn.getVariance() == 1.0 - max_val = im.array.max() - - # Now apply addNoiseSNR to another (clean) image with preserve_flux=False, so we use the noise - # variance in the original RNG, i.e., 1. Check that the returned variance is 1, and that the - # value of the maximum pixel (presumably the peak of the galaxy light profile) is scaled as we - # expect for this SNR. - im2 = gauss.drawImage(scale=pix_scale, dtype=np.float64) - gn2 = galsim.GaussianNoise(rng=rng2) - var_out2 = im2.addNoiseSNR(gn2, test_snr, preserve_flux=False) - assert var_out2 == 1.0 - expect_max_val2 = max_val * np.sqrt(var_out2 / var_out) - np.testing.assert_almost_equal( - im2.array.max(), expect_max_val2, decimal=8, - err_msg='addNoiseSNR with preserve_flux = True and False give inconsistent results') diff --git a/tests/jax/galsim/test_photon_array_jax.py b/tests/jax/galsim/test_photon_array_jax.py deleted file mode 100644 index 63fe3868..00000000 --- a/tests/jax/galsim/test_photon_array_jax.py +++ /dev/null @@ -1,1873 +0,0 @@ -# Copyright (c) 2012-2023 by the GalSim developers team on GitHub -# https://github.com/GalSim-developers -# -# This file is part of GalSim: The modular galaxy image simulation toolkit. -# https://github.com/GalSim-developers/GalSim -# -# GalSim is free software: redistribution and use in source and binary forms, -# with or without modification, are permitted provided that the following -# conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions, and the disclaimer given in the accompanying LICENSE -# file. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the disclaimer given in the documentation -# and/or other materials provided with the distribution. -# - -import unittest -import numpy as np -import os -import sys -import warnings - -# We don't require astroplan. So check if it's installed. -try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - import astroplan - no_astroplan = False -except ImportError: - no_astroplan = True - -import galsim -from galsim_test_helpers import * - -bppath = os.path.join(galsim.meta_data.share_dir, "bandpasses") -sedpath = os.path.join(galsim.meta_data.share_dir, "SEDs") - - -@timer -def test_photon_array(): - """Test the basic methods of PhotonArray class""" - nphotons = 1000 - - # First create from scratch - photon_array = galsim.PhotonArray(nphotons) - assert len(photon_array.x) == nphotons - assert len(photon_array.y) == nphotons - assert len(photon_array.flux) == nphotons - assert not photon_array.hasAllocatedWavelengths() - assert not photon_array.hasAllocatedAngles() - - # Initial values should all be 0 - np.testing.assert_array_equal(photon_array.x, 0.0) - np.testing.assert_array_equal(photon_array.y, 0.0) - np.testing.assert_array_equal(photon_array.flux, 0.0) - - # Check picklability - check_pickle(photon_array) - - # JAX does not support this way of assignement - # # Check assignment via numpy [:] - # photon_array.x[:] = 5 - # photon_array.y[:] = 17 - # photon_array.flux[:] = 23 - # np.testing.assert_array_equal(photon_array.x, 5.) - # np.testing.assert_array_equal(photon_array.y, 17.) - # np.testing.assert_array_equal(photon_array.flux, 23.) - - # Check assignment directly to the attributes - photon_array.x = 25 - photon_array.y = 37 - photon_array.flux = 53 - np.testing.assert_array_equal(photon_array.x, 25.0) - np.testing.assert_array_equal(photon_array.y, 37.0) - np.testing.assert_array_equal(photon_array.flux, 53.0) - - # Now create from shooting a profile - obj = galsim.Exponential(flux=1.7, scale_radius=2.3) - rng = galsim.UniformDeviate(1234) - photon_array = obj.shoot(nphotons, rng) - orig_x = photon_array.x.copy() - orig_y = photon_array.y.copy() - orig_flux = photon_array.flux.copy() - assert len(photon_array.x) == nphotons - assert len(photon_array.y) == nphotons - assert len(photon_array.flux) == nphotons - assert not photon_array.hasAllocatedWavelengths() - assert not photon_array.hasAllocatedAngles() - assert not photon_array.hasAllocatedPupil() - assert not photon_array.hasAllocatedTimes() - - # Check arithmetic ops - photon_array.x *= 5 - photon_array.y += 17 - photon_array.flux /= 23 - np.testing.assert_array_almost_equal(photon_array.x, orig_x * 5.0) - np.testing.assert_array_almost_equal(photon_array.y, orig_y + 17.0) - np.testing.assert_array_almost_equal(photon_array.flux, orig_flux / 23.0) - - # Check picklability again with non-zero values - check_pickle(photon_array) - - # Now assign to the optional arrays - photon_array.dxdz = 0.17 - assert photon_array.hasAllocatedAngles() - assert not photon_array.hasAllocatedWavelengths() - np.testing.assert_array_equal(photon_array.dxdz, 0.17) - np.testing.assert_array_equal(photon_array.dydz, 0.0) - - photon_array.dydz = 0.59 - np.testing.assert_array_equal(photon_array.dxdz, 0.17) - np.testing.assert_array_equal(photon_array.dydz, 0.59) - - # Check shooting negative flux - obj = galsim.Exponential(flux=-1.7, scale_radius=2.3) - rng = galsim.UniformDeviate(1234) - neg_photon_array = obj.shoot(nphotons, rng) - np.testing.assert_array_equal(neg_photon_array.x, orig_x) - np.testing.assert_array_equal(neg_photon_array.y, orig_y) - np.testing.assert_array_equal(neg_photon_array.flux, -orig_flux) - - # Start over to check that assigning to wavelength leaves dxdz, dydz alone. - photon_array = obj.shoot(nphotons, rng) - photon_array.wavelength = 500.0 - assert photon_array.hasAllocatedWavelengths() - assert not photon_array.hasAllocatedAngles() - assert not photon_array.hasAllocatedPupil() - assert not photon_array.hasAllocatedTimes() - np.testing.assert_array_equal(photon_array.wavelength, 500) - - photon_array.dxdz = 0.23 - photon_array.dydz = 0.88 - photon_array.wavelength = 912.0 - assert photon_array.hasAllocatedWavelengths() - assert photon_array.hasAllocatedAngles() - assert not photon_array.hasAllocatedPupil() - assert not photon_array.hasAllocatedTimes() - np.testing.assert_array_equal(photon_array.dxdz, 0.23) - np.testing.assert_array_equal(photon_array.dydz, 0.88) - np.testing.assert_array_equal(photon_array.wavelength, 912) - - # Add pupil coords - photon_array.pupil_u = 6.0 - assert photon_array.hasAllocatedWavelengths() - assert photon_array.hasAllocatedAngles() - assert photon_array.hasAllocatedPupil() - assert not photon_array.hasAllocatedTimes() - np.testing.assert_array_equal(photon_array.dxdz, 0.23) - np.testing.assert_array_equal(photon_array.dydz, 0.88) - np.testing.assert_array_equal(photon_array.wavelength, 912) - np.testing.assert_array_equal(photon_array.pupil_u, 6.0) - np.testing.assert_array_equal(photon_array.pupil_v, 0.0) - - # Add time stamps - photon_array.time = 0.0 - assert photon_array.hasAllocatedWavelengths() - assert photon_array.hasAllocatedAngles() - assert photon_array.hasAllocatedPupil() - assert photon_array.hasAllocatedTimes() - np.testing.assert_array_equal(photon_array.dxdz, 0.23) - np.testing.assert_array_equal(photon_array.dydz, 0.88) - np.testing.assert_array_equal(photon_array.wavelength, 912) - np.testing.assert_array_equal(photon_array.pupil_u, 6.0) - np.testing.assert_array_equal(photon_array.pupil_v, 0.0) - np.testing.assert_array_equal(photon_array.time, 0.0) - - # Check toggling is_corr - assert not photon_array.isCorrelated() - photon_array.setCorrelated() - assert photon_array.isCorrelated() - photon_array.setCorrelated(False) - assert not photon_array.isCorrelated() - photon_array.setCorrelated(True) - assert photon_array.isCorrelated() - - # Check rescaling the total flux - flux = photon_array.flux.sum() - np.testing.assert_almost_equal(photon_array.getTotalFlux(), flux) - photon_array.scaleFlux(17) - np.testing.assert_almost_equal(photon_array.getTotalFlux(), 17 * flux) - photon_array.setTotalFlux(199) - np.testing.assert_almost_equal(photon_array.getTotalFlux(), 199) - photon_array.scaleFlux(-1.7) - np.testing.assert_almost_equal(photon_array.getTotalFlux(), -1.7 * 199) - photon_array.setTotalFlux(-199) - np.testing.assert_almost_equal(photon_array.getTotalFlux(), -199) - - # Check rescaling the positions - x = photon_array.x.copy() - y = photon_array.y.copy() - photon_array.scaleXY(1.9) - np.testing.assert_array_almost_equal(photon_array.x, 1.9 * x) - np.testing.assert_array_almost_equal(photon_array.y, 1.9 * y) - - # Check ways to assign to photons - pa1 = galsim.PhotonArray(50) - pa1.x = photon_array.x[:50] - pa1.y = photon_array.y[:50] - pa1.flux = photon_array.flux[:50] - # for i in range(50): - # pa1.y[i] = photon_array.y[i] - # pa1.flux[0:50] = photon_array.flux[:50] - pa1.dxdz = photon_array.dxdz[:50] - pa1.dydz = photon_array.dydz[:50] - pa1.wavelength = photon_array.wavelength[:50] - pa1.pupil_u = photon_array.pupil_u[:50] - pa1.pupil_v = photon_array.pupil_v[:50] - pa1.time = photon_array.time[:50] - np.testing.assert_array_almost_equal(pa1.x, photon_array.x[:50]) - np.testing.assert_array_almost_equal(pa1.y, photon_array.y[:50]) - np.testing.assert_array_almost_equal(pa1.flux, photon_array.flux[:50]) - np.testing.assert_array_almost_equal(pa1.dxdz, photon_array.dxdz[:50]) - np.testing.assert_array_almost_equal(pa1.dydz, photon_array.dydz[:50]) - np.testing.assert_array_almost_equal(pa1.wavelength, photon_array.wavelength[:50]) - np.testing.assert_array_almost_equal(pa1.pupil_u, photon_array.pupil_u[:50]) - np.testing.assert_array_almost_equal(pa1.pupil_v, photon_array.pupil_v[:50]) - np.testing.assert_array_almost_equal(pa1.time, photon_array.time[:50]) - - # Check assignAt - pa2 = galsim.PhotonArray(100) - pa2.assignAt(0, pa1) - pa2.assignAt(50, pa1) - np.testing.assert_array_almost_equal(pa2.x[:50], pa1.x) - np.testing.assert_array_almost_equal(pa2.y[:50], pa1.y) - np.testing.assert_array_almost_equal(pa2.flux[:50], pa1.flux) - np.testing.assert_array_almost_equal(pa2.dxdz[:50], pa1.dxdz) - np.testing.assert_array_almost_equal(pa2.dydz[:50], pa1.dydz) - np.testing.assert_array_almost_equal(pa2.wavelength[:50], pa1.wavelength) - np.testing.assert_array_almost_equal(pa2.pupil_u[:50], pa1.pupil_u) - np.testing.assert_array_almost_equal(pa2.pupil_v[:50], pa1.pupil_v) - np.testing.assert_array_almost_equal(pa2.time[:50], pa1.time) - np.testing.assert_array_almost_equal(pa2.x[50:], pa1.x) - np.testing.assert_array_almost_equal(pa2.y[50:], pa1.y) - np.testing.assert_array_almost_equal(pa2.flux[50:], pa1.flux) - np.testing.assert_array_almost_equal(pa2.dxdz[50:], pa1.dxdz) - np.testing.assert_array_almost_equal(pa2.dydz[50:], pa1.dydz) - np.testing.assert_array_almost_equal(pa2.wavelength[50:], pa1.wavelength) - np.testing.assert_array_almost_equal(pa2.pupil_u[50:], pa1.pupil_u) - np.testing.assert_array_almost_equal(pa2.pupil_v[50:], pa1.pupil_v) - np.testing.assert_array_almost_equal(pa2.time[50:], pa1.time) - - # Error if it doesn't fit. - assert_raises(ValueError, pa2.assignAt, 90, pa1) - - # Test some trivial usage of makeFromImage - zero = galsim.Image(4, 4, init_value=0) - photons = galsim.PhotonArray.makeFromImage(zero) - print("photons = ", photons) - assert len(photons) == 16 - np.testing.assert_array_equal(photons.flux, 0.0) - - ones = galsim.Image(4, 4, init_value=1) - photons = galsim.PhotonArray.makeFromImage(ones) - print("photons = ", photons) - assert len(photons) == 16 - np.testing.assert_array_almost_equal(photons.flux, 1.0) - - tens = galsim.Image(4, 4, init_value=8) - photons = galsim.PhotonArray.makeFromImage(tens, max_flux=5.0) - print("photons = ", photons) - assert len(photons) == 32 - np.testing.assert_array_almost_equal(photons.flux, 4.0) - - assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=0.0) - assert_raises(ValueError, galsim.PhotonArray.makeFromImage, zero, max_flux=-2) - - # Check some other errors - undef = galsim.Image() - assert_raises(galsim.GalSimUndefinedBoundsError, pa2.addTo, undef) - - # Check picklability again with non-zero values for everything - check_pickle(photon_array) - - -@timer -def test_convolve(): - nphotons = 1000000 - - obj = galsim.Gaussian(flux=1.7, sigma=2.3) - rng = galsim.UniformDeviate(1234) - pa1 = obj.shoot(nphotons, rng) - rng2 = rng.duplicate() # Save this state. - pa2 = obj.shoot(nphotons, rng) - - # If not correlated then convolve is deterministic - conv_x = pa1.x + pa2.x - conv_y = pa1.y + pa2.y - conv_flux = pa1.flux * pa2.flux * nphotons - - np.testing.assert_allclose(np.sum(pa1.flux), 1.7) - np.testing.assert_allclose(np.sum(pa2.flux), 1.7) - np.testing.assert_allclose(np.sum(conv_flux), 1.7 * 1.7) - - np.testing.assert_allclose(np.sum(pa1.x**2) / nphotons, 2.3**2, rtol=0.01) - np.testing.assert_allclose(np.sum(pa2.x**2) / nphotons, 2.3**2, rtol=0.01) - np.testing.assert_allclose( - np.sum(conv_x**2) / nphotons, 2.0 * 2.3**2, rtol=0.01 - ) - - np.testing.assert_allclose(np.sum(pa1.y**2) / nphotons, 2.3**2, rtol=0.01) - np.testing.assert_allclose(np.sum(pa2.y**2) / nphotons, 2.3**2, rtol=0.01) - np.testing.assert_allclose( - np.sum(conv_y**2) / nphotons, 2.0 * 2.3**2, rtol=0.01 - ) - - pa3 = galsim.PhotonArray(nphotons) - pa3.assignAt(0, pa1) # copy from pa1 - pa3.convolve(pa2) - np.testing.assert_allclose(pa3.x, conv_x) - np.testing.assert_allclose(pa3.y, conv_y) - np.testing.assert_allclose(pa3.flux, conv_flux) - - # If one of them is correlated, it is still deterministic. - pa3.assignAt(0, pa1) - pa3.setCorrelated() - pa3.convolve(pa2) - np.testing.assert_allclose(pa3.x, conv_x) - np.testing.assert_allclose(pa3.y, conv_y) - np.testing.assert_allclose(pa3.flux, conv_flux) - - pa3.assignAt(0, pa1) - pa3.setCorrelated(False) - pa2.setCorrelated() - pa3.convolve(pa2) - np.testing.assert_allclose(pa3.x, conv_x) - np.testing.assert_allclose(pa3.y, conv_y) - np.testing.assert_allclose(pa3.flux, conv_flux) - - # But if both are correlated, then it's not this simple. - pa3.assignAt(0, pa1) - pa3.setCorrelated() - assert pa3.isCorrelated() - assert pa2.isCorrelated() - pa3.convolve(pa2) - with assert_raises(AssertionError): - np.testing.assert_allclose(pa3.x, conv_x) - with assert_raises(AssertionError): - np.testing.assert_allclose(pa3.y, conv_y) - np.testing.assert_allclose(np.sum(pa3.flux), 1.7 * 1.7) - np.testing.assert_allclose(np.sum(pa3.x**2) / nphotons, 2 * 2.3**2, rtol=0.01) - np.testing.assert_allclose(np.sum(pa3.y**2) / nphotons, 2 * 2.3**2, rtol=0.01) - - # Can also effect the convolution by treating the psf as a PhotonOp - pa3.assignAt(0, pa1) - pa3.setCorrelated() - obj.applyTo(pa3, rng=rng2) - np.testing.assert_allclose(pa3.x, conv_x) - np.testing.assert_allclose(pa3.y, conv_y) - np.testing.assert_allclose(pa3.flux, conv_flux) - - # Error to have different lengths - pa4 = galsim.PhotonArray(50, pa1.x[:50], pa1.y[:50], pa1.flux[:50]) - assert_raises(galsim.GalSimError, pa1.convolve, pa4) - - # Check propagation of dxdz, dydz, wavelength, pupil_u, pupil_v - for attr, checkFn in zip( - ["dxdz", "dydz", "wavelength", "pupil_u", "pupil_v", "time"], - [ - "hasAllocatedAngles", - "hasAllocatedAngles", - "hasAllocatedWavelengths", - "hasAllocatedPupil", - "hasAllocatedPupil", - "hasAllocatedTimes", - ], - ): - pa1 = obj.shoot(nphotons, rng) - pa2 = obj.shoot(nphotons, rng) - assert not getattr(pa1, checkFn)() - assert not getattr(pa1, checkFn)() - data = np.linspace(-0.1, 0.1, nphotons) - setattr(pa1, attr, data) - assert getattr(pa1, checkFn)() - assert not getattr(pa2, checkFn)() - pa1.convolve(pa2) - assert getattr(pa1, checkFn)() - assert not getattr(pa2, checkFn)() - np.testing.assert_array_equal(getattr(pa1, attr), data) - pa2.convolve(pa1) - assert getattr(pa1, checkFn)() - assert getattr(pa2, checkFn)() - np.testing.assert_array_equal(getattr(pa2, attr), data) - - # both have data now... - pa1.convolve(pa2) - np.testing.assert_array_equal(getattr(pa1, attr), data) - np.testing.assert_array_equal(getattr(pa2, attr), data) - - # If the second one has different data, the first takes precedence. - setattr(pa2, attr, data * 2) - pa1.convolve(pa2) - np.testing.assert_array_equal(getattr(pa1, attr), data) - np.testing.assert_array_equal(getattr(pa2, attr), 2 * data) - - -@timer -def test_wavelength_sampler(): - nphotons = 1000 - obj = galsim.Exponential(flux=1.7, scale_radius=2.3) - rng = galsim.UniformDeviate(1234) - - photon_array = obj.shoot(nphotons, rng) - - sed = galsim.SED(os.path.join(sedpath, "CWW_E_ext.sed"), "A", "flambda").thin() - bandpass = galsim.Bandpass(os.path.join(bppath, "LSST_r.dat"), "nm").thin() - - sampler = galsim.WavelengthSampler(sed, bandpass) - sampler.applyTo(photon_array, rng=rng) - - # Note: the underlying functionality of the sampleWavelengths function is tested - # in test_sed.py. So here we are really just testing that the wrapper class is - # properly writing to the photon_array.wavelengths array. - - assert photon_array.hasAllocatedWavelengths() - assert not photon_array.hasAllocatedAngles() - - check_pickle(sampler) - - print("mean wavelength = ", np.mean(photon_array.wavelength)) - print("min wavelength = ", np.min(photon_array.wavelength)) - print("max wavelength = ", np.max(photon_array.wavelength)) - - assert np.min(photon_array.wavelength) > bandpass.blue_limit - assert np.max(photon_array.wavelength) < bandpass.red_limit - - # This is a regression test based on the value at commit 134a119 - np.testing.assert_allclose( - np.mean(photon_array.wavelength), 622.755128, rtol=1.0e-4 - ) - - # If we use a flat SED (in photons/nm), then the mean sampled wavelength should very closely - # match the bandpass effective wavelength. - photon_array2 = galsim.PhotonArray(100000) - sed2 = galsim.SED("1", "nm", "fphotons") - sampler2 = galsim.WavelengthSampler(sed2, bandpass) - sampler2.applyTo(photon_array2, rng=rng) - np.testing.assert_allclose( - np.mean(photon_array2.wavelength), - bandpass.effective_wavelength, - rtol=0, - atol=0.2, # 2 Angstrom accuracy is pretty good - err_msg="Mean sampled wavelength not close to effective_wavelength", - ) - - # If the photon array already has wavelengths set, then it proceeds, but gives a warning. - with assert_warns(galsim.GalSimWarning): - sampler2.applyTo(photon_array2, rng=rng) - np.testing.assert_allclose( - np.mean(photon_array2.wavelength), - bandpass.effective_wavelength, - rtol=0, - atol=0.2, - ) - - # Test that using this as a surface op works properly. - - # First do the shooting and clipping manually. - im1 = galsim.Image(64, 64, scale=1) - im1.setCenter(0, 0) - photon_array.flux[photon_array.wavelength < 600] = 0.0 - photon_array.addTo(im1) - - # Make a dummy surface op that clips any photons with lambda < 600 - class Clip600: - def applyTo(self, photon_array, local_wcs=None, rng=None): - photon_array.flux[photon_array.wavelength < 600] = 0.0 - - # Use (a new) sampler and clip600 as photon_ops in drawImage - im2 = galsim.Image(64, 64, scale=1) - im2.setCenter(0, 0) - clip600 = Clip600() - rng2 = galsim.BaseDeviate(1234) - sampler2 = galsim.WavelengthSampler(sed, bandpass) - obj.drawImage( - im2, - method="phot", - n_photons=nphotons, - use_true_center=False, - photon_ops=[sampler2, clip600], - rng=rng2, - save_photons=True, - ) - print("sum = ", im1.array.sum(), im2.array.sum()) - np.testing.assert_array_equal(im1.array, im2.array) - - # Equivalent version just getting photons back - rng2.seed(1234) - photons = obj.makePhot(n_photons=nphotons, photon_ops=[sampler2, clip600], rng=rng2) - print("phot.x = ", photons.x) - print("im2.photons.x = ", im2.photons.x) - assert photons == im2.photons - - # Base class is invalid to try to use. - op = galsim.PhotonOp() - with assert_raises(NotImplementedError): - op.applyTo(photon_array) - - -@timer -def test_photon_angles(): - """Test the photon_array function""" - # Make a photon array - seed = 12345 - ud = galsim.UniformDeviate(seed) - gal = galsim.Sersic(n=4, half_light_radius=1) - photon_array = gal.shoot(100000, ud) - - # Add the directions (N.B. using the same seed as for generating the photon array - # above. The fact that it is the same does not matter here; the testing routine - # only needs to have a definite seed value so the consistency of the results with - # expectations can be evaluated precisely - fratio = 1.2 - obscuration = 0.2 - - # rng can be None, an existing BaseDeviate, or an integer - for rng in [None, ud, 12345]: - assigner = galsim.FRatioAngles(fratio, obscuration) - assigner.applyTo(photon_array, rng=rng) - - check_pickle(assigner) - - dxdz = photon_array.dxdz - dydz = photon_array.dydz - - phi = np.arctan2(dydz, dxdz) - tantheta = np.sqrt(np.square(dxdz) + np.square(dydz)) - sintheta = np.sin(np.arctan(tantheta)) - - # Check that the values are within the ranges expected - # (The test on phi really can't fail, because it is only testing the range of the - # arctan2 function.) - np.testing.assert_array_less( - -phi, np.pi, "Azimuth angles outside possible range" - ) - np.testing.assert_array_less( - phi, np.pi, "Azimuth angles outside possible range" - ) - - fov_angle = np.arctan(0.5 / fratio) - obscuration_angle = obscuration * fov_angle - np.testing.assert_array_less( - -sintheta, - -np.sin(obscuration_angle), - "Inclination angles outside possible range", - ) - np.testing.assert_array_less( - sintheta, np.sin(fov_angle), "Inclination angles outside possible range" - ) - - # Compare these slopes with the expected distributions (uniform in azimuth - # over all azimiths and uniform in sin(inclination) over the range of - # allowed inclinations - # Only test this for the last one, so we make sure we have a deterministic result. - # (The above tests should be reliable even for the default rng.) - phi_histo, phi_bins = np.histogram(phi, bins=100) - sintheta_histo, sintheta_bins = np.histogram(sintheta, bins=100) - phi_ref = float(np.sum(phi_histo)) / phi_histo.size - sintheta_ref = float(np.sum(sintheta_histo)) / sintheta_histo.size - - chisqr_phi = np.sum(np.square(phi_histo - phi_ref) / phi_ref) / phi_histo.size - chisqr_sintheta = ( - np.sum(np.square(sintheta_histo - sintheta_ref) / sintheta_ref) - / sintheta_histo.size - ) - - print("chisqr_phi = ", chisqr_phi) - print("chisqr_sintheta = ", chisqr_sintheta) - assert 0.9 < chisqr_phi < 1.1, "Distribution in azimuth is not nearly uniform" - assert ( - 0.9 < chisqr_sintheta < 1.1 - ), "Distribution in sin(inclination) is not nearly uniform" - - # Try some invalid inputs - assert_raises(ValueError, galsim.FRatioAngles, fratio=-0.3) - assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=-0.3) - assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=1.0) - assert_raises(ValueError, galsim.FRatioAngles, fratio=1.2, obscuration=1.9) - - -@timer -def test_photon_io(): - """Test the ability to read and write photons to a file""" - nphotons = 1000 - - obj = galsim.Exponential(flux=1.7, scale_radius=2.3) - rng = galsim.UniformDeviate(1234) - image = obj.drawImage(method="phot", n_photons=nphotons, save_photons=True, rng=rng) - photons = image.photons - assert photons.size() == len(photons) == nphotons - - with assert_raises(galsim.GalSimIncompatibleValuesError): - obj.drawImage(method="phot", n_photons=nphotons, save_photons=True, maxN=1.0e5) - - file_name = "output/photons1.dat" - photons.write(file_name) - - photons1 = galsim.PhotonArray.read(file_name) - - assert photons1.size() == nphotons - assert not photons1.hasAllocatedWavelengths() - assert not photons1.hasAllocatedAngles() - assert not photons1.hasAllocatedPupil() - assert not photons1.hasAllocatedTimes() - - np.testing.assert_array_equal(photons1.x, photons.x) - np.testing.assert_array_equal(photons1.y, photons.y) - np.testing.assert_array_equal(photons1.flux, photons.flux) - - sed = galsim.SED(os.path.join(sedpath, "CWW_E_ext.sed"), "nm", "flambda").thin() - bandpass = galsim.Bandpass(os.path.join(bppath, "LSST_r.dat"), "nm").thin() - - wave_sampler = galsim.WavelengthSampler(sed, bandpass) - angle_sampler = galsim.FRatioAngles(1.3, 0.3) - - ops = [wave_sampler, angle_sampler] - for op in ops: - op.applyTo(photons, rng=rng) - - # Directly inject some pupil coordinates and time stamps - photons.pupil_u = np.linspace(0, 1, nphotons) - photons.pupil_v = np.linspace(1, 2, nphotons) - photons.time = np.linspace(0, 30, nphotons) - - file_name = "output/photons2.dat" - photons.write(file_name) - - photons2 = galsim.PhotonArray.read(file_name) - - assert photons2.size() == nphotons - assert photons2.hasAllocatedWavelengths() - assert photons2.hasAllocatedAngles() - assert photons2.hasAllocatedPupil() - assert photons2.hasAllocatedTimes() - - np.testing.assert_array_equal(photons2.x, photons.x) - np.testing.assert_array_equal(photons2.y, photons.y) - np.testing.assert_array_equal(photons2.flux, photons.flux) - np.testing.assert_array_equal(photons2.dxdz, photons.dxdz) - np.testing.assert_array_equal(photons2.dydz, photons.dydz) - np.testing.assert_array_equal(photons2.wavelength, photons.wavelength) - np.testing.assert_array_equal(photons.pupil_u, photons.pupil_u) - np.testing.assert_array_equal(photons.pupil_v, photons.pupil_v) - np.testing.assert_array_equal(photons.time, photons.time) - - -@timer -def test_dcr(): - """Test the dcr surface op""" - # This tests that implementing DCR with the surface op is equivalent to using - # ChromaticAtmosphere. - # We use fairly extreme choices for the parameters to make the comparison easier, so - # we can still get good discrimination of any errors with only 10^6 photons. - zenith_angle = 45 * galsim.degrees # Larger angle has larger DCR. - parallactic_angle = 129 * galsim.degrees # Something random, not near 0 or 180 - pixel_scale = ( - 0.03 # Small pixel scale means shifts are many pixels, rather than a fraction. - ) - alpha = -1.2 # The normal alpha is -0.2, so this is exaggerates the effect. - - bandpass = galsim.Bandpass("LSST_r.dat", "nm") - base_wavelength = bandpass.effective_wavelength - base_wavelength += 500 # This exaggerates the effects fairly substantially. - - sed = galsim.SED("CWW_E_ext.sed", wave_type="ang", flux_type="flambda") - - flux = 1.0e6 - fwhm = 0.3 - base_PSF = galsim.Kolmogorov(fwhm=fwhm) - - # Use ChromaticAtmosphere - # Note, somewhat gratuitous check that ImageI works with dtype=int in config below. - im1 = galsim.ImageI(50, 50, scale=pixel_scale) - star = galsim.DeltaFunction() * sed - star = star.withFlux(flux, bandpass=bandpass) - chrom_PSF = galsim.ChromaticAtmosphere( - base_PSF, - base_wavelength=base_wavelength, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - alpha=alpha, - ) - chrom = galsim.Convolve(star, chrom_PSF) - chrom.drawImage(bandpass, image=im1) - - # Repeat with config - config = { - "psf": { - "type": "ChromaticAtmosphere", - "base_profile": {"type": "Kolmogorov", "fwhm": fwhm}, - "base_wavelength": base_wavelength, - "zenith_angle": zenith_angle, - "parallactic_angle": parallactic_angle, - "alpha": alpha, - }, - "gal": {"type": "DeltaFunction", "flux": flux, "sed": sed}, - "image": { - "xsize": 50, - "ysize": 50, - "pixel_scale": pixel_scale, - "bandpass": bandpass, - "random_seed": 31415, - "dtype": int, - }, - } - im1c = galsim.config.BuildImage(config) - assert im1c == im1 - - # Use PhotonDCR - im2 = galsim.ImageI(50, 50, scale=pixel_scale) - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - alpha=alpha, - ) - achrom = base_PSF.withFlux(flux) - # Because we'll be comparing to config version, get the rng the way it will do it. - rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) - wave_sampler = galsim.WavelengthSampler(sed, bandpass) - photon_ops = [wave_sampler, dcr] - achrom.drawImage(image=im2, method="phot", rng=rng, photon_ops=photon_ops) - - check_pickle(dcr) - - # Repeat with config - config = { - "psf": {"type": "Kolmogorov", "fwhm": fwhm}, - "gal": {"type": "DeltaFunction", "flux": flux}, - "image": { - "xsize": 50, - "ysize": 50, - "pixel_scale": pixel_scale, - "bandpass": bandpass, - "random_seed": 31415, - "dtype": "np.int32", - }, - "stamp": { - "draw_method": "phot", - "photon_ops": [ - {"type": "WavelengthSampler", "sed": sed}, - { - "type": "PhotonDCR", - "base_wavelength": base_wavelength, - "zenith_angle": zenith_angle, - "parallactic_angle": parallactic_angle, - "alpha": alpha, - }, - ], - }, - } - im2c = galsim.config.BuildImage(config) - assert im2c == im2 - - # Should work with fft, but not quite match (because of inexact photon locations). - im3 = galsim.ImageF(50, 50, scale=pixel_scale) - achrom.drawImage(image=im3, method="fft", rng=rng, photon_ops=photon_ops) - printval(im3, im2, show=False) - np.testing.assert_allclose( - im3.array, - im2.array, - atol=0.1 * np.max(im2.array), - err_msg="PhotonDCR on fft image didn't match phot image", - ) - # Moments come out less than 1% different. - res2 = im2.FindAdaptiveMom() - res3 = im3.FindAdaptiveMom() - np.testing.assert_allclose(res3.moments_amp, res2.moments_amp, rtol=1.0e-2) - np.testing.assert_allclose(res3.moments_sigma, res2.moments_sigma, rtol=1.0e-2) - np.testing.assert_allclose( - res3.observed_shape.e1, res2.observed_shape.e1, atol=1.0e-2 - ) - np.testing.assert_allclose( - res3.observed_shape.e2, res2.observed_shape.e2, atol=1.0e-2 - ) - np.testing.assert_allclose( - res3.moments_centroid.x, res2.moments_centroid.x, rtol=1.0e-2 - ) - np.testing.assert_allclose( - res3.moments_centroid.y, res2.moments_centroid.y, rtol=1.0e-2 - ) - - # Repeat with maxN < flux - # Note: Because of the different way this generates the random positions, it's not identical - # to the above run without maxN. Both runs are equally valid realizations of photon - # positions corresponding to the FFT image. But not the same realization. - achrom.drawImage( - image=im3, method="auto", rng=rng, photon_ops=photon_ops, maxN=10**4 - ) - printval(im3, im2, show=False) - np.testing.assert_allclose( - im3.array, - im2.array, - atol=0.2 * np.max(im2.array), - err_msg="PhotonDCR on fft image with maxN didn't match phot image", - ) - res3 = im3.FindAdaptiveMom() - np.testing.assert_allclose(res3.moments_amp, res2.moments_amp, rtol=1.0e-2) - np.testing.assert_allclose(res3.moments_sigma, res2.moments_sigma, rtol=1.0e-2) - np.testing.assert_allclose( - res3.observed_shape.e1, res2.observed_shape.e1, atol=1.0e-2 - ) - np.testing.assert_allclose( - res3.observed_shape.e2, res2.observed_shape.e2, atol=1.0e-2 - ) - np.testing.assert_allclose( - res3.moments_centroid.x, res2.moments_centroid.x, rtol=1.0e-2 - ) - np.testing.assert_allclose( - res3.moments_centroid.y, res2.moments_centroid.y, rtol=1.0e-2 - ) - - # Compare ChromaticAtmosphere image with PhotonDCR image. - printval(im2, im1, show=False) - # tolerace for photon shooting is ~sqrt(flux) = 1.e3 - np.testing.assert_allclose( - im2.array, - im1.array, - atol=1.0e3, - err_msg="PhotonDCR didn't match ChromaticAtmosphere", - ) - - # Use ChromaticAtmosphere in photon_ops - im3 = galsim.ImageI(50, 50, scale=pixel_scale) - photon_ops = [chrom_PSF] - star.drawImage(bandpass, image=im3, method="phot", rng=rng, photon_ops=photon_ops) - printval(im3, im1, show=False) - np.testing.assert_allclose( - im3.array, - im1.array, - atol=1.0e3, - err_msg="ChromaticAtmosphere in photon_ops didn't match", - ) - - # Repeat with thinned bandpass and SED to check that thin still works well. - im3 = galsim.ImageI(50, 50, scale=pixel_scale) - thin = 0.1 # Even higher also works. But this is probably enough. - thin_bandpass = bandpass.thin(thin) - thin_sed = sed.thin(thin) - print("len bp = %d => %d" % (len(bandpass.wave_list), len(thin_bandpass.wave_list))) - print("len sed = %d => %d" % (len(sed.wave_list), len(thin_sed.wave_list))) - wave_sampler = galsim.WavelengthSampler(thin_sed, thin_bandpass) - photon_ops = [wave_sampler, dcr] - achrom.drawImage(image=im3, method="phot", rng=rng, photon_ops=photon_ops) - - printval(im3, im1, show=False) - np.testing.assert_allclose( - im3.array, - im1.array, - atol=1.0e3, - err_msg="thinning factor %f led to 1.e-4 level inaccuracy" % thin, - ) - - # Check scale_unit - im4 = galsim.ImageI(50, 50, scale=pixel_scale / 60) - wave_sampler = galsim.WavelengthSampler(sed, bandpass) - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - scale_unit="arcmin", - alpha=alpha, - ) - photon_ops = [wave_sampler, dcr] - rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) - achrom.dilate(1.0 / 60).drawImage( - image=im4, method="phot", rng=rng, photon_ops=photon_ops - ) - printval(im4, im1, show=False) - np.testing.assert_allclose( - im4.array, - im1.array, - atol=1.0e3, - err_msg="PhotonDCR with scale_unit=arcmin, didn't match", - ) - - galsim.config.RemoveCurrent(config) - del config["stamp"]["photon_ops"][1]["_get"] - config["stamp"]["photon_ops"][1]["scale_unit"] = "arcmin" - config["image"]["pixel_scale"] = pixel_scale / 60 - config["psf"]["fwhm"] = fwhm / 60 - im4c = galsim.config.BuildImage(config) - assert im4c == im4 - - # Check some other valid options - # alpha = 0 means don't do any size scaling. - # obj_coord, HA and latitude are another option for setting the angles - # pressure, temp, and water pressure are settable. - # Also use a non-trivial WCS. - wcs = galsim.FitsWCS("des_data/DECam_00154912_12_header.fits") - image = galsim.Image(50, 50, wcs=wcs) - bandpass = galsim.Bandpass("LSST_r.dat", wave_type="nm").thin(0.1) - base_wavelength = bandpass.effective_wavelength - lsst_lat = galsim.Angle.from_dms("-30:14:23.76") - lsst_long = galsim.Angle.from_dms("-70:44:34.67") - local_sidereal_time = ( - 3.14 * galsim.hours - ) # Not pi. This is the time for this observation. - - im5 = galsim.ImageI(50, 50, wcs=wcs) - obj_coord = wcs.toWorld(im5.true_center) - base_PSF = galsim.Kolmogorov(fwhm=0.9) - achrom = base_PSF.withFlux(flux) - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - obj_coord=obj_coord, - HA=local_sidereal_time - obj_coord.ra, - latitude=lsst_lat, - pressure=72, # default is 69.328 - temperature=290, # default is 293.15 - H2O_pressure=0.9, - ) # default is 1.067 - # alpha=0) # default is 0, so don't need to set it. - wave_sampler = galsim.WavelengthSampler(sed, bandpass) - photon_ops = [wave_sampler, dcr] - rng = galsim.BaseDeviate(galsim.BaseDeviate(31415).raw() + 1) - achrom.drawImage(image=im5, method="phot", rng=rng, photon_ops=photon_ops) - - check_pickle(dcr) - - galsim.config.RemoveCurrent(config) - config["psf"]["fwhm"] = 0.9 - config["image"] = { - "xsize": 50, - "ysize": 50, - "wcs": {"type": "Fits", "file_name": "des_data/DECam_00154912_12_header.fits"}, - "bandpass": bandpass, - "random_seed": 31415, - "dtype": "np.int32", - "world_pos": obj_coord, - } - config["stamp"]["photon_ops"][1] = { - "type": "PhotonDCR", - "base_wavelength": base_wavelength, - "HA": local_sidereal_time - obj_coord.ra, - "latitude": "-30:14:23.76 deg", - "pressure": 72, - "temperature": 290, - "H2O_pressure": 0.9, - } - im5c = galsim.config.BuildImage(config) - assert im5c == im5 - - # Also one using zenith_coord = (lst, lat) - config["stamp"]["photon_ops"][1] = { - "type": "PhotonDCR", - "base_wavelength": base_wavelength, - "zenith_coord": { - "type": "RADec", - "ra": local_sidereal_time, - "dec": lsst_lat, - }, - "pressure": 72, - "temperature": 290, - "H2O_pressure": 0.9, - } - im5d = galsim.config.BuildImage(config) - assert im5d == im5 - - im6 = galsim.ImageI(50, 50, wcs=wcs) - star = galsim.DeltaFunction() * sed - star = star.withFlux(flux, bandpass=bandpass) - chrom_PSF = galsim.ChromaticAtmosphere( - base_PSF, - base_wavelength=bandpass.effective_wavelength, - obj_coord=obj_coord, - HA=local_sidereal_time - obj_coord.ra, - latitude=lsst_lat, - pressure=72, - temperature=290, - H2O_pressure=0.9, - alpha=0, - ) - chrom = galsim.Convolve(star, chrom_PSF) - chrom.drawImage(bandpass, image=im6) - - printval(im5, im6, show=False) - np.testing.assert_allclose( - im5.array, im6.array, atol=1.0e3, err_msg="PhotonDCR with alpha=0 didn't match" - ) - - # Use ChromaticAtmosphere in photon_ops - im7 = galsim.ImageI(50, 50, wcs=wcs) - photon_ops = [chrom_PSF] - star.drawImage(bandpass, image=im7, method="phot", rng=rng, photon_ops=photon_ops) - printval(im7, im6, show=False) - np.testing.assert_allclose( - im7.array, - im6.array, - atol=1.0e3, - err_msg="ChromaticAtmosphere in photon_ops didn't match", - ) - - # ChromaticAtmosphere in photon_ops is almost trivially equal to base_psf and dcr in photon_ops. - im8 = galsim.ImageI(50, 50, wcs=wcs) - photon_ops = [base_PSF, dcr] - star.drawImage(bandpass, image=im8, method="phot", rng=rng, photon_ops=photon_ops) - printval(im8, im6, show=False) - np.testing.assert_allclose( - im8.array, - im6.array, - atol=1.0e3, - err_msg="base_psf + dcr in photon_ops didn't match", - ) - - # Including the wavelength sampler with chromatic drawing is not necessary, but is allowed. - # (Mostly in case someone wants to do something a little different w.r.t. wavelength sampling. - photon_ops = [wave_sampler, base_PSF, dcr] - star.drawImage(bandpass, image=im8, method="phot", rng=rng, photon_ops=photon_ops) - printval(im8, im6, show=False) - np.testing.assert_allclose( - im8.array, - im6.array, - atol=1.0e3, - err_msg="wave_sampler,base_psf,dcr in photon_ops didn't match", - ) - - # Also check invalid parameters - zenith_coord = galsim.CelestialCoord(13.54 * galsim.hours, lsst_lat) - assert_raises( - TypeError, - galsim.PhotonDCR, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - ) # base_wavelength is required - assert_raises( - TypeError, - galsim.PhotonDCR, - base_wavelength=500, - parallactic_angle=parallactic_angle, - ) # zenith_angle (somehow) is required - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - zenith_angle=34.4, - parallactic_angle=parallactic_angle, - ) # zenith_angle must be Angle - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - zenith_angle=zenith_angle, - parallactic_angle=34.5, - ) # parallactic_angle must be Angle - assert_raises( - TypeError, galsim.PhotonDCR, 500, obj_coord=obj_coord, latitude=lsst_lat - ) # Missing HA - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - obj_coord=obj_coord, - HA=local_sidereal_time - obj_coord.ra, - ) # Missing latitude - assert_raises( - TypeError, galsim.PhotonDCR, 500, obj_coord=obj_coord - ) # Need either zenith_coord, or (HA,lat) - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - obj_coord=obj_coord, - zenith_coord=zenith_coord, - HA=local_sidereal_time - obj_coord.ra, - ) # Can't have both HA and zenith_coord - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - obj_coord=obj_coord, - zenith_coord=zenith_coord, - latitude=lsst_lat, - ) # Can't have both lat and zenith_coord - assert_raises( - TypeError, - galsim.PhotonDCR, - 500, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - H20_pressure=1.0, - ) # invalid (misspelled) - assert_raises( - ValueError, - galsim.PhotonDCR, - 500, - zenith_angle=zenith_angle, - parallactic_angle=parallactic_angle, - scale_unit="inches", - ) # invalid scale_unit - photons = galsim.PhotonArray(2, flux=1) - assert_raises( - galsim.GalSimError, dcr.applyTo, photons - ) # Requires wavelengths to be set - assert_raises( - galsim.GalSimError, chrom_PSF.applyTo, photons - ) # Requires wavelengths to be set - photons = galsim.PhotonArray(2, flux=1, wavelength=500) - assert_raises(TypeError, dcr.applyTo, photons) # Requires local_wcs - - # Invalid to use dcr without some way of setting wavelengths. - assert_raises( - galsim.GalSimError, achrom.drawImage, im2, method="phot", photon_ops=[dcr] - ) - - -@unittest.skipIf(no_astroplan, "Unable to import astroplan") -@timer -def test_dcr_angles(): - """Check the DCR angle calculations by comparing to astroplan's calculations of the same.""" - # Note: test_chromatic.py and test_sed.py both also test aspects of the dcr module, so - # this particular test could belong in either of them too. But I (MJ) put it here, since - # I wrote it in conjunction with the tests of PhotonDCR to try to make sure that code - # is working properly. - import astropy.time - - # Set up an observation date, time, location, coordinate - # These are arbitrary, so ripped from astroplan's docs - # https://media.readthedocs.org/pdf/astroplan/latest/astroplan.pdf - subaru = astroplan.Observer.at_site("subaru") - time = astropy.time.Time("2015-06-16 12:00:00") - - # Stars that are visible from the north in summer time. - names = [ - "Vega", - "Polaris", - "Altair", - "Regulus", - "Spica", - "Algol", - "Fomalhaut", - "Markab", - "Deneb", - "Mizar", - "Dubhe", - "Sirius", - "Rigel", - "Alderamin", - ] - - for name in names: - try: - star = astroplan.FixedTarget.from_name(name) - except Exception as e: - print("Caught exception trying to make star from name ", name) - print(e) - print("Aborting. (Probably some kind of network problem.)") - return - print(star) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - ap_z = subaru.altaz(time, star).zen - ap_q = subaru.parallactic_angle(time, star) - local_sidereal_time = subaru.local_sidereal_time(time) - print("According to astroplan:") - print(" z,q = ", ap_z.deg, ap_q.deg) - - # Repeat with GalSim - coord = galsim.CelestialCoord( - star.ra.deg * galsim.degrees, star.dec.deg * galsim.degrees - ) - lat = subaru.location.lat.deg * galsim.degrees - ha = local_sidereal_time.deg * galsim.degrees - coord.ra - zenith = galsim.CelestialCoord(local_sidereal_time.deg * galsim.degrees, lat) - - # Two ways to calculate it - # 1. From coord, ha, lat - z, q, _ = galsim.dcr.parse_dcr_angles(obj_coord=coord, HA=ha, latitude=lat) - print("According to GalSim:") - print(" z,q = ", z / galsim.degrees, q / galsim.degrees) - - np.testing.assert_almost_equal( - z.rad, - ap_z.rad, - 2, - "zenith angle doesn't agree with astroplan's calculation.", - ) - - # Unfortunately, at least as of version 0.4, astroplan's parallactic angle calculation - # has a bug. It computes it as the arctan of some value, but doesn't use arctan2. - # So whenever |q| > 90 degrees, it gets it wrong by 180 degrees. Therefore, we only - # test that tan(q) is right. We'll check the quadrant below in test_dcr_moments(). - np.testing.assert_almost_equal( - np.tan(q), - np.tan(ap_q), - 2, - "parallactic angle doesn't agree with astroplan's calculation.", - ) - - # 2. From coord, zenith_coord - z, q, _ = galsim.dcr.parse_dcr_angles(obj_coord=coord, zenith_coord=zenith) - print(" z,q = ", z / galsim.degrees, q / galsim.degrees) - - np.testing.assert_almost_equal( - z.rad, - ap_z.rad, - 2, - "zenith angle doesn't agree with astroplan's calculation.", - ) - np.testing.assert_almost_equal( - np.tan(q), - np.tan(ap_q), - 2, - "parallactic angle doesn't agree with astroplan's calculation.", - ) - - -def test_dcr_moments(): - """Check that DCR gets the direction of the moment changes correct for some simple geometries. - i.e. Basically check the sign conventions used in the DCR code. - """ - # First, the basics. - # 1. DCR shifts blue photons closer to zenith, because the index of refraction larger. - # cf. http://lsstdesc.github.io/chroma/ - # 2. Galsim models profiles as seen from Earth with North up (and therefore East left). - # 3. Hour angle is negative when the object is in the east (soon after rising, say), - # zero when crossing the zenith meridian, and then positive to the west. - - # Use g-band, where the effect is more dramatic across the band than in redder bands. - # Also use a reference wavelength significantly to the red, so there should be a net - # overall shift towards zenith as well as a shear along the line to zenith. - bandpass = galsim.Bandpass("LSST_g.dat", "nm").thin(0.1) - base_wavelength = 600 # > red end of g band - - # Uniform across the band is fine for this. - sed = galsim.SED("1", wave_type="nm", flux_type="fphotons") - rng = galsim.BaseDeviate(31415) - wave_sampler = galsim.WavelengthSampler(sed, bandpass) - - star = galsim.Kolmogorov(fwhm=0.3, flux=1.0e6) # 10^6 photons should be enough. - im = galsim.ImageD( - 50, 50, scale=0.05 - ) # Small pixel scale, so shift is many pixels. - ra = 0 * galsim.degrees # Completely irrelevant here. - lat = -20 * galsim.degrees # Also doesn't really matter much. - - # 1. HA < 0, Dec < lat Spot should be shifted up and right. e2 > 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=-2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("1. HA < 0, Dec < lat: ", moments) - assert moments["My"] > 0 # up - assert moments["Mx"] > 0 # right - assert moments["Mxy"] > 0 # e2 > 0 - - # 2. HA = 0, Dec < lat Spot should be shifted up. e1 < 0, e2 ~= 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=0 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("2. HA = 0, Dec < lat: ", moments) - assert moments["My"] > 0 # up - assert abs(moments["Mx"]) < 0.05 # not left or right - assert moments["Mxx"] < moments["Myy"] # e1 < 0 - assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 - - # 3. HA > 0, Dec < lat Spot should be shifted up and left. e2 < 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat - 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("3. HA > 0, Dec < lat: ", moments) - assert moments["My"] > 0 # up - assert moments["Mx"] < 0 # left - assert moments["Mxy"] < 0 # e2 < 0 - - # 4. HA < 0, Dec = lat Spot should be shifted right. e1 > 0, e2 ~= 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=-2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("4. HA < 0, Dec = lat: ", moments) - assert ( - abs(moments["My"]) < 1.0 - ) # not up or down (Actually slightly down in the south.) - assert moments["Mx"] > 0 # right - assert moments["Mxx"] > moments["Myy"] # e1 > 0 - assert ( - abs(moments["Mxy"]) < 2.0 - ) # e2 ~= 0 (Actually slightly negative because of curvature.) - - # 5. HA = 0, Dec = lat Spot should not be shifted. e1 ~= 0, e2 ~= 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=0 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("5. HA = 0, Dec = lat: ", moments) - assert abs(moments["My"]) < 0.05 # not up or down - assert abs(moments["Mx"]) < 0.05 # not left or right - assert abs(moments["Mxx"] - moments["Myy"]) < 0.1 # e1 ~= 0 - assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 - - # 6. HA > 0, Dec = lat Spot should be shifted left. e1 > 0, e2 ~= 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("6. HA > 0, Dec = lat: ", moments) - assert ( - abs(moments["My"]) < 1.0 - ) # not up or down (Actually slightly down in the south.) - assert moments["Mx"] < 0 # left - assert moments["Mxx"] > moments["Myy"] # e1 > 0 - assert ( - abs(moments["Mxy"]) < 2.0 - ) # e2 ~= 0 (Actually slgihtly positive because of curvature.) - - # 7. HA < 0, Dec > lat Spot should be shifted down and right. e2 < 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=-2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("7. HA < 0, Dec > lat: ", moments) - assert moments["My"] < 0 # down - assert moments["Mx"] > 0 # right - assert moments["Mxy"] < 0 # e2 < 0 - - # 8. HA = 0, Dec > lat Spot should be shifted down. e1 < 0, e2 ~= 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=0 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("8. HA = 0, Dec > lat: ", moments) - assert moments["My"] < 0 # down - assert abs(moments["Mx"]) < 0.05 # not left or right - assert moments["Mxx"] < moments["Myy"] # e1 < 0 - assert abs(moments["Mxy"]) < 0.1 # e2 ~= 0 - - # 9. HA > 0, Dec > lat Spot should be shifted down and left. e2 > 0. - dcr = galsim.PhotonDCR( - base_wavelength=base_wavelength, - HA=2 * galsim.hours, - latitude=lat, - obj_coord=galsim.CelestialCoord(ra, lat + 20 * galsim.degrees), - ) - photon_ops = [wave_sampler, dcr] - star.drawImage(image=im, method="phot", rng=rng, photon_ops=photon_ops) - moments = galsim.utilities.unweighted_moments(im, origin=im.true_center) - print("9. HA > 0, Dec > lat: ", moments) - assert moments["My"] < 0 # down - assert moments["Mx"] < 0 # left - assert moments["Mxy"] > 0 # e2 > 0 - - -@timer -def test_refract(): - ud = galsim.UniformDeviate(57721) - for _ in range(1000): - photon_array = galsim.PhotonArray(1000, flux=1) - photon_array.allocateAngles() - ud.generate(photon_array.dxdz) - ud.generate(photon_array.dydz) - photon_array.dxdz *= 1.2 # -0.6 to 0.6 - photon_array.dydz *= 1.2 - photon_array.dxdz -= 0.6 - photon_array.dydz -= 0.6 - # copy for testing later - dxdz0 = np.array(photon_array.dxdz) - dydz0 = np.array(photon_array.dydz) - index_ratio = ud() * 4 + 0.25 # 0.25 to 4.25 - refract = galsim.Refraction(index_ratio) - refract.applyTo(photon_array) - - check_pickle(refract) - - # Triangle is length 1 in the z direction and length sqrt(dxdz**2+dydz**2) - # in the 'r' direction. - rsqr0 = dxdz0**2 + dydz0**2 - sintheta0 = np.sqrt(rsqr0) / np.sqrt(1 + rsqr0) - # See if total internal reflection applies - w = sintheta0 < index_ratio - np.testing.assert_array_equal(photon_array.dxdz[~w], np.nan) - np.testing.assert_array_equal(photon_array.dydz[~w], np.nan) - np.testing.assert_array_equal(photon_array.flux, np.where(w, 1.0, 0.0)) - - sintheta0 = sintheta0[w] - dxdz0 = dxdz0[w] - dydz0 = dydz0[w] - dxdz1 = photon_array.dxdz[w] - dydz1 = photon_array.dydz[w] - rsqr1 = dxdz1**2 + dydz1**2 - sintheta1 = np.sqrt(rsqr1) / np.sqrt(1 + rsqr1) - # Check Snell's law - np.testing.assert_allclose(sintheta0, index_ratio * sintheta1) - - # Check azimuthal angle stays constant - phi0 = np.arctan2(dydz0, dxdz0) - phi1 = np.arctan2(dydz1, dxdz1) - np.testing.assert_allclose(phi0, phi1) - - # Check plane of refraction is perpendicular to (0,0,1) - np.testing.assert_allclose( - np.dot( - np.cross( - np.stack([dxdz0, dydz0, -np.ones(len(dxdz0))], axis=1), - np.stack([dxdz1, dydz1, -np.ones(len(dxdz1))], axis=1), - ), - [0, 0, 1], - ), - 0.0, - rtol=0, - atol=1e-13, - ) - - # Try a wavelength dependent index_ratio - index_ratio = lambda w: np.where(w < 1, 1.1, 2.2) - photon_array = galsim.PhotonArray(100) - photon_array.allocateWavelengths() - photon_array.allocateAngles() - ud.generate(photon_array.wavelength) - ud.generate(photon_array.dxdz) - ud.generate(photon_array.dydz) - photon_array.dxdz *= 1.2 # -0.6 to 0.6 - photon_array.dydz *= 1.2 - photon_array.dxdz -= 0.6 - photon_array.dydz -= 0.6 - photon_array.wavelength *= 2 # 0 to 2 - dxdz0 = photon_array.dxdz.copy() - dydz0 = photon_array.dydz.copy() - - refract_func = galsim.Refraction(index_ratio=index_ratio) - refract_func.applyTo(photon_array) - dxdz_func = photon_array.dxdz.copy() - dydz_func = photon_array.dydz.copy() - - photon_array.dxdz = dxdz0.copy() - photon_array.dydz = dydz0.copy() - refract11 = galsim.Refraction(index_ratio=1.1) - refract11.applyTo(photon_array) - dxdz11 = photon_array.dxdz.copy() - dydz11 = photon_array.dydz.copy() - - photon_array.dxdz = dxdz0.copy() - photon_array.dydz = dydz0.copy() - refract22 = galsim.Refraction(index_ratio=2.2) - refract22.applyTo(photon_array) - dxdz22 = photon_array.dxdz.copy() - dydz22 = photon_array.dydz.copy() - - w = photon_array.wavelength < 1 - np.testing.assert_allclose(dxdz_func, np.where(w, dxdz11, dxdz22)) - np.testing.assert_allclose(dydz_func, np.where(w, dydz11, dydz22)) - - -@timer -def test_focus_depth(): - bd = galsim.BaseDeviate(1234) - for _ in range(100): - # Test that FocusDepth is additive - photon_array = galsim.PhotonArray(1000) - photon_array2 = galsim.PhotonArray(1000) - photon_array.x = 0.0 - photon_array.y = 0.0 - photon_array2.x = 0.0 - photon_array2.y = 0.0 - galsim.FRatioAngles(1.234, obscuration=0.606).applyTo(photon_array, rng=bd) - photon_array2.dxdz = photon_array.dxdz - photon_array2.dydz = photon_array.dydz - fd1 = galsim.FocusDepth(1.1) - fd2 = galsim.FocusDepth(2.2) - fd3 = galsim.FocusDepth(3.3) - fd1.applyTo(photon_array) - fd2.applyTo(photon_array) - fd3.applyTo(photon_array2) - - check_pickle(fd1) - - np.testing.assert_allclose(photon_array.x, photon_array2.x, rtol=0, atol=1e-15) - np.testing.assert_allclose(photon_array.y, photon_array2.y, rtol=0, atol=1e-15) - # Assuming focus is at x=y=0, then - # intrafocal (depth < 0) => (x > 0 => dxdz < 0) - # extrafocal (depth > 0) => (x > 0 => dxdz > 0) - # We applied an extrafocal operation above, so check for corresponding - # relation between x, dxdz - np.testing.assert_array_less(0, photon_array.x * photon_array.dxdz) - - # transforming by depth and -depth is null - fd4 = galsim.FocusDepth(-3.3) - fd4.applyTo(photon_array) - np.testing.assert_allclose(photon_array.x, 0.0, rtol=0, atol=1e-15) - np.testing.assert_allclose(photon_array.y, 0.0, rtol=0, atol=1e-15) - - # Check that invalid photon array is trapped - pa = galsim.PhotonArray(10) - fd = galsim.FocusDepth(1.0) - with np.testing.assert_raises(galsim.GalSimError): - fd.applyTo(pa) - - # Check that we can infer depth from photon positions before and after... - for _ in range(100): - photon_array = galsim.PhotonArray(1000) - photon_array2 = galsim.PhotonArray(1000) - ud = galsim.UniformDeviate(bd) - ud.generate(photon_array.x) - ud.generate(photon_array.y) - photon_array.x -= 0.5 - photon_array.y -= 0.5 - galsim.FRatioAngles(1.234, obscuration=0.606).applyTo(photon_array, rng=bd) - photon_array2.x = photon_array.x - photon_array2.y = photon_array.y - photon_array2.dxdz = photon_array.dxdz - photon_array2.dydz = photon_array.dydz - depth = ud() - 0.5 - galsim.FocusDepth(depth).applyTo(photon_array2) - np.testing.assert_allclose( - (photon_array2.x - photon_array.x) / photon_array.dxdz, depth - ) - np.testing.assert_allclose( - (photon_array2.y - photon_array.y) / photon_array.dydz, depth - ) - np.testing.assert_allclose(photon_array.dxdz, photon_array2.dxdz) - np.testing.assert_allclose(photon_array.dydz, photon_array2.dydz) - - -@timer -def test_lsst_y_focus(): - # Check that applying reasonable focus depth (from O'Connor++06) indeed leads to smaller spot - # size for LSST y-band. - rng = galsim.BaseDeviate(9876543210) - bandpass = galsim.Bandpass("LSST_y.dat", wave_type="nm") - sed = galsim.SED("1", wave_type="nm", flux_type="flambda") - obj = galsim.Gaussian(fwhm=1e-5) - oversampling = 32 - photon_ops0 = [ - galsim.WavelengthSampler(sed, bandpass), - galsim.FRatioAngles(1.234, 0.606), - galsim.FocusDepth(0.0), - galsim.Refraction(3.9), - ] - img0 = obj.drawImage( - sensor=galsim.SiliconSensor(), - method="phot", - n_photons=100000, - photon_ops=photon_ops0, - scale=0.2 / oversampling, - nx=32 * oversampling, - ny=32 * oversampling, - rng=rng, - ) - T0 = img0.calculateMomentRadius() - T0 *= 10 * oversampling / 0.2 # arcsec => microns - - # O'Connor finds minimum spot size when the focus depth is ~ -12 microns. Our sensor isn't - # necessarily the same as the one there though; our minimum seems to be around -6 microns. - # That could be due to differences in the design of the sensor though. We just use -6 microns - # here, which is still useful to test the sign of the `depth` parameter and the interaction of - # the 4 different surface operators required to produce this effect, and is roughly consistent - # with O'Connor. - - depth1 = -6.0 # microns, negative means surface is intrafocal - depth1 /= 10 # microns => pixels - photon_ops1 = [ - galsim.WavelengthSampler(sed, bandpass), - galsim.FRatioAngles(1.234, 0.606), - galsim.FocusDepth(depth1), - galsim.Refraction(3.9), - ] - img1 = obj.drawImage( - sensor=galsim.SiliconSensor(), - method="phot", - n_photons=100000, - photon_ops=photon_ops1, - scale=0.2 / oversampling, - nx=32 * oversampling, - ny=32 * oversampling, - rng=rng, - ) - T1 = img1.calculateMomentRadius() - T1 *= 10 * oversampling / 0.2 # arcsec => microns - np.testing.assert_array_less(T1, T0) - - -@timer -def test_fromArrays(): - """Check that fromArrays constructor catches errors and ALWAYS copies.""" - - rng = galsim.BaseDeviate(123) - - x = np.empty(1000) - y = np.empty(1000) - flux = np.empty(1000) - - Nsplit = 444 - - pa_batch = galsim.PhotonArray.fromArrays(x, y, flux) - pa_1 = galsim.PhotonArray.fromArrays(x[:Nsplit], y[:Nsplit], flux[:Nsplit]) - pa_2 = galsim.PhotonArray.fromArrays(x[Nsplit:], y[Nsplit:], flux[Nsplit:]) - - assert pa_batch.x is not x - assert pa_batch.y is not y - assert pa_batch.flux is not flux - np.testing.assert_array_equal(pa_batch.x, x) - np.testing.assert_array_equal(pa_batch.y, y) - np.testing.assert_array_equal(pa_batch.flux, flux) - np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) - np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) - np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) - np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) - np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) - np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) - - # Do some manipulation and check views are still equivalent - obj1 = galsim.Gaussian(fwhm=0.1) * 64 - obj2 = galsim.Kolmogorov(fwhm=0.2) * 23 - - obj1._shoot(pa_1, rng) - obj2._shoot(pa_2, rng) - - assert pa_batch.x is x - assert pa_batch.y is y - assert pa_batch.flux is flux - np.testing.assert_array_equal(pa_batch.x, x) - np.testing.assert_array_equal(pa_batch.y, y) - np.testing.assert_array_equal(pa_batch.flux, flux) - np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) - np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) - np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) - np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) - np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) - np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) - - # Add some optional args and apply PhotonOps to the batch this time. - dxdz = np.empty(1000) - dydz = np.empty(1000) - wavelength = np.empty(1000) - pupil_u = np.empty(1000) - pupil_v = np.empty(1000) - time = np.empty(1000) - pa_batch = galsim.PhotonArray.fromArrays( - x, y, flux, dxdz, dydz, wavelength, pupil_u, pupil_v, time - ) - pa_1 = galsim.PhotonArray.fromArrays( - x[:Nsplit], - y[:Nsplit], - flux[:Nsplit], - dxdz[:Nsplit], - dydz[:Nsplit], - wavelength[:Nsplit], - pupil_u[:Nsplit], - pupil_v[:Nsplit], - time[:Nsplit], - ) - pa_2 = galsim.PhotonArray.fromArrays( - x[Nsplit:], - y[Nsplit:], - flux[Nsplit:], - dxdz[Nsplit:], - dydz[Nsplit:], - wavelength[Nsplit:], - pupil_u[Nsplit:], - pupil_v[Nsplit:], - time[Nsplit:], - ) - - sed = galsim.SED("vega.txt", wave_type="nm", flux_type="flambda") - bp = galsim.Bandpass("LSST_r.dat", wave_type="nm") - with assert_warns(galsim.GalSimWarning): - galsim.WavelengthSampler(sed, bp).applyTo(pa_batch, rng=rng) - galsim.FRatioAngles(1.2, 0.61).applyTo(pa_batch, rng=rng) - galsim.TimeSampler(0.0, 30.0).applyTo(pa_batch, rng=rng) - - assert pa_batch.x is x - assert pa_batch.y is y - assert pa_batch.flux is flux - assert pa_batch.dxdz is dxdz - assert pa_batch.dydz is dydz - assert pa_batch.wavelength is wavelength - assert pa_batch.pupil_u is pupil_u - assert pa_batch.pupil_v is pupil_v - assert pa_batch.time is time - np.testing.assert_array_equal(pa_batch.x, x) - np.testing.assert_array_equal(pa_batch.y, y) - np.testing.assert_array_equal(pa_batch.flux, flux) - np.testing.assert_array_equal(pa_batch.dxdz, dxdz) - np.testing.assert_array_equal(pa_batch.dydz, dydz) - np.testing.assert_array_equal(pa_batch.wavelength, wavelength) - np.testing.assert_array_equal(pa_batch.pupil_u, pupil_u) - np.testing.assert_array_equal(pa_batch.pupil_v, pupil_v) - np.testing.assert_array_equal(pa_batch.time, time) - np.testing.assert_array_equal(pa_1.x, pa_batch.x[:Nsplit]) - np.testing.assert_array_equal(pa_1.y, pa_batch.y[:Nsplit]) - np.testing.assert_array_equal(pa_1.flux, pa_batch.flux[:Nsplit]) - np.testing.assert_array_equal(pa_1.dxdz, pa_batch.dxdz[:Nsplit]) - np.testing.assert_array_equal(pa_1.dydz, pa_batch.dydz[:Nsplit]) - np.testing.assert_array_equal(pa_1.wavelength, pa_batch.wavelength[:Nsplit]) - np.testing.assert_array_equal(pa_1.pupil_u, pa_batch.pupil_u[:Nsplit]) - np.testing.assert_array_equal(pa_1.pupil_v, pa_batch.pupil_v[:Nsplit]) - np.testing.assert_array_equal(pa_1.time, pa_batch.time[:Nsplit]) - np.testing.assert_array_equal(pa_2.x, pa_batch.x[Nsplit:]) - np.testing.assert_array_equal(pa_2.y, pa_batch.y[Nsplit:]) - np.testing.assert_array_equal(pa_2.flux, pa_batch.flux[Nsplit:]) - np.testing.assert_array_equal(pa_2.dxdz, pa_batch.dxdz[Nsplit:]) - np.testing.assert_array_equal(pa_2.dydz, pa_batch.dydz[Nsplit:]) - np.testing.assert_array_equal(pa_2.wavelength, pa_batch.wavelength[Nsplit:]) - np.testing.assert_array_equal(pa_2.pupil_u, pa_batch.pupil_u[Nsplit:]) - np.testing.assert_array_equal(pa_2.pupil_v, pa_batch.pupil_v[Nsplit:]) - np.testing.assert_array_equal(pa_2.time, pa_batch.time[Nsplit:]) - - # Check the is_corr flag gets set - assert not pa_batch.isCorrelated() - pa_batch = galsim.PhotonArray.fromArrays( - x, y, flux, dxdz, dydz, wavelength, is_corr=True - ) - assert pa_batch.isCorrelated() - - # Check some invalid inputs are caught - with np.testing.assert_raises(TypeError): - galsim.PhotonArray.fromArrays(list(x), y, flux, dxdz, dydz, wavelength) - with np.testing.assert_raises(TypeError): - galsim.PhotonArray.fromArrays( - np.empty(1000, dtype=int), y, flux, dxdz, dydz, wavelength - ) - with np.testing.assert_raises(ValueError): - galsim.PhotonArray.fromArrays(x[:10], y, flux, dxdz, dydz, wavelength) - with np.testing.assert_raises(ValueError): - galsim.PhotonArray.fromArrays( - np.empty(2000)[::2], y, flux, dxdz, dydz, wavelength - ) - - -def test_pupil_annulus_sampler(): - """Check that we get a uniform distribution from PupilAnnulusSampler""" - seed = 54321 - sampler = galsim.PupilAnnulusSampler(1.0, 0.5) - pa = galsim.PhotonArray(1_000_000) - sampler.applyTo(pa, rng=seed) - r = np.hypot(pa.pupil_u, pa.pupil_v) - assert np.min(r) > 0.5 - assert np.max(r) < 1.0 - h, edges = np.histogram( - r, - bins=10, - range=(0.5, 1.0), - ) - areas = np.pi * (edges[1:] ** 2 - edges[:-1] ** 2) - # each bin should have ~100_000 photons, so +/- 0.3%. Test at 1%. - assert np.std(h / areas) / np.mean(h / areas) < 0.01 - - phi = np.arctan2(pa.pupil_v, pa.pupil_u) - phi[phi < 0] += 2 * np.pi - h, edges = np.histogram(phi, bins=10, range=(0.0, 2 * np.pi)) - assert np.std(h) / np.mean(h) < 0.01 - - check_pickle(sampler) - - -def test_time_sampler(): - """Check TimeSampler build arguments""" - seed = 97531 - sampler = galsim.TimeSampler() - assert sampler.t0 == 0 - assert sampler.exptime == 0 - pa = galsim.PhotonArray(1_000_000) - sampler.applyTo(pa, rng=seed) - np.testing.assert_array_equal(pa.time, 0.0) - check_pickle(sampler) - - sampler = galsim.TimeSampler(t0=1.0) - assert sampler.t0 == 1 - assert sampler.exptime == 0 - sampler.applyTo(pa, rng=seed) - np.testing.assert_array_equal(pa.time, 1.0) - check_pickle(sampler) - - sampler = galsim.TimeSampler(exptime=30.0) - assert sampler.t0 == 0 - assert sampler.exptime == 30 - sampler.applyTo(pa, rng=seed) - np.testing.assert_array_less(pa.time, 30) - np.testing.assert_array_less(-pa.time, 0) - check_pickle(sampler) - - sampler = galsim.TimeSampler(t0=10, exptime=30.0) - assert sampler.t0 == 10 - assert sampler.exptime == 30 - sampler.applyTo(pa, rng=seed) - np.testing.assert_array_less(pa.time, 40) - np.testing.assert_array_less(-pa.time, 10) - check_pickle(sampler) - - -def test_setFromImage_crash(): - """Geri Braunlich ran into a seg fault where the photon array was not allocated to be - sufficiently large for the photons it got from an image. - This test reproduces the error for version 2.4.8 for the purpose of fixing it. - - The bug turned out to be that some pixel values were (slightly) negative from the FFT, - and the total flux was estimated as np.sum(image.array). The negative pixels added - negatively to this sum, so the calculated total flux wasn't quite enough to hold all the - required photons. - - The fix was to use the absolute value of the image for this calculation. - """ - # These are (approximately) the specific values for one case where the code used to crash. - prof = galsim.Gaussian(sigma=0.13).withFlux(3972551) - wcs = galsim.JacobianWCS(-0.170, -0.106, 0.106, -0.170) - image = galsim.Image(1000, 1000, wcs=wcs, dtype=float) - - # Start with a simple draw with no photons - im1 = prof.drawImage(image=image.copy()) - - # Now with photon_ops. - # This had been sufficient to trigger the bug, but now photon_ops=[] is the same as None. - im2 = prof.drawImage(image=image.copy(), photon_ops=[], n_subsample=1) - assert im1 == im2 - - # Repeat with a non-empty, but still trivial, photon_ops. - im3 = prof.drawImage( - image=image.copy(), photon_ops=[galsim.FRatioAngles(1.2)], n_subsample=1 - ) - - # They aren't quite identical because of numerical rounding issues from going through - # a sum of fluxes on individual photons. - # In particular, we want to make sure negative pixels stay negative through this process. - assert im1 != im3 - np.testing.assert_allclose(im1.array, im3.array, rtol=1.0e-11) - w = np.where(im1.array != im3.array) - print("diff in ", len(w[0]), "pixels") - assert ( - len(w[0]) < 100 - ) # I find it to be different in only 39 photons on my machine. - - -if __name__ == "__main__": - testfns = [v for k, v in vars().items() if k[:5] == "test_" and callable(v)] - if no_astroplan: - print("Skipping test_dcr_angles, since astroplan not installed.") - testfns.remove(test_dcr_angles) - for testfn in testfns: - testfn() diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py deleted file mode 100644 index 30941daa..00000000 --- a/tests/jax/galsim/test_random_jax.py +++ /dev/null @@ -1,2002 +0,0 @@ -import math -import numpy as np -import os -import galsim -from galsim.utilities import single_threaded -from galsim_test_helpers import timer, check_pickle # noqa: E402 - -precision = 10 -# decimal point at which agreement is required for all double precision tests - -precisionD = precision -precisionF = 5 # precision=10 does not make sense at single precision -precisionS = 1 # "precision" also a silly concept for ints, but allows all 4 tests to run in one go -precisionI = 1 - -# The number of values to generate when checking the mean and variance calculations. -# This is currently low enough to not dominate the time of the unit tests, but when changing -# something, it may be useful to add a couple zeros while testing. -nvals = 100000 - -testseed = 1000 # seed used for UniformDeviate for all tests -# Warning! If you change testseed, then all of the *Result variables below must change as well. - -# the right answer for the first three uniform deviates produced from testseed -uResult = (0.0160653916, 0.228817832, 0.1609966951) - -# mean, sigma to use for Gaussian tests -gMean = 4.7 -gSigma = 3.2 -# the right answer for the first three Gaussian deviates produced from testseed -gResult = (-2.1568953985, 2.3232138032, 1.5308165692) - -# N, p to use for binomial tests -bN = 10 -bp = 0.7 -# the right answer for the first three binomial deviates produced from testseed -bResult = (5, 8, 7) - -# mean to use for Poisson tests -pMean = 7 -# the right answer for the first three Poisson deviates produced from testseed -pResult = (6, 11, 4) - -# a & b to use for Weibull tests -wA = 4.0 -wB = 9.0 -# Tabulated results for Weibull -wResult = (3.2106530102, 6.4256210259, 5.8255498741) - -# k & theta to use for Gamma tests -gammaK = 1.5 -gammaTheta = 4.5 -# Tabulated results for Gamma -gammaResult = (10.9318881415, 7.6074550007, 2.0526795529) - -# n to use for Chi2 tests -chi2N = 30 -# Tabulated results for Chi2 -chi2Result = (36.7583415337, 32.7223187231, 23.1555198334) - -# function and min&max to use for DistDeviate function call tests -dmin = 0.0 -dmax = 2.0 - - -def dfunction(x): - return x * x - - -# Tabulated results for DistDeviate function call -dFunctionResult = (0.9826461346196363, 1.1973307331701328, 1.5105900949284945) - -# x and p arrays and interpolant to use for DistDeviate LookupTable tests -dx = [0.0, 1.0, 1.000000001, 2.999999999, 3.0, 4.0] -dp = [0.1, 0.1, 0.0, 0.0, 0.1, 0.1] -dLookupTable = galsim.LookupTable(x=dx, f=dp, interpolant="linear") -# Tabulated results for DistDeviate LookupTable call -dLookupTableResult = (0.23721845680847731, 0.42913599265739233, 0.86176396813243539) -# File with the same values -dLookupTableFile = os.path.join("random_data", "dLookupTable.dat") - - -@timer -def test_uniform(): - """Test uniform random number generator""" - u = galsim.UniformDeviate(testseed) - u2 = u.duplicate() - u3 = galsim.UniformDeviate(u.serialize()) - testResult = (u(), u(), u()) - np.testing.assert_array_almost_equal( - np.array(testResult), - np.array(uResult), - precision, - err_msg="Wrong uniform random number sequence generated", - ) - testResult = (u2(), u2(), u2()) - np.testing.assert_array_almost_equal( - np.array(testResult), - np.array(uResult), - precision, - err_msg="Wrong uniform random number sequence generated with duplicate", - ) - testResult = (u3(), u3(), u3()) - np.testing.assert_array_almost_equal( - np.array(testResult), - np.array(uResult), - precision, - err_msg="Wrong uniform random number sequence generated from serialize", - ) - - # Check that the mean and variance come out right - u = galsim.UniformDeviate(testseed) - vals = [u() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = 1.0 / 2.0 - v = 1.0 / 12.0 - print("mean = ", mean, " true mean = ", mu) - print("var = ", var, " true var = ", v) - np.testing.assert_almost_equal( - mean, mu, 1, err_msg="Wrong mean from UniformDeviate" - ) - np.testing.assert_almost_equal( - var, v, 1, err_msg="Wrong variance from UniformDeviate" - ) - - # Check discard - u2 = galsim.UniformDeviate(testseed) - u2.discard(nvals) - v1, v2 = u(), u2() - print("after %d vals, next one is %s, %s" % (nvals, v1, v2)) - assert v1 == v2 - assert u.has_reliable_discard - assert not u.generates_in_pairs - - # Check seed, reset - u.seed(testseed) - testResult2 = (u(), u(), u()) - np.testing.assert_array_equal( - np.array(testResult), - np.array(testResult2), - err_msg="Wrong uniform random number sequence generated after seed", - ) - - u.reset(testseed) - testResult2 = (u(), u(), u()) - np.testing.assert_array_equal( - np.array(testResult), - np.array(testResult2), - err_msg="Wrong uniform random number sequence generated after reset(seed)", - ) - - rng = galsim.BaseDeviate(testseed) - u.reset(rng) - testResult2 = (u(), u(), u()) - np.testing.assert_array_equal( - np.array(testResult), - np.array(testResult2), - err_msg="Wrong uniform random number sequence generated after reset(rng)", - ) - - # Check raw - u2.reset(testseed) - u2.discard(3) - np.testing.assert_equal( - u.raw(), u2.raw(), err_msg="Uniform deviates generate different raw values" - ) - - rng2 = galsim.BaseDeviate(testseed) - rng2.discard(4) - np.testing.assert_equal( - rng.raw(), - rng2.raw(), - err_msg="BaseDeviates generate different raw values after discard", - ) - - # Check that two connected uniform deviates work correctly together. - u2 = galsim.UniformDeviate(testseed) - u.reset(u2) - testResult2 = (u(), u2(), u()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong uniform random number sequence generated using two uds') - u.seed(testseed) - testResult2 = (u2(), u(), u2()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong uniform random number sequence generated using two uds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. - u.seed() - testResult2 = (u(), u(), u()) - assert testResult2 != testResult - u.reset() - testResult3 = (u(), u(), u()) - assert testResult3 != testResult - assert testResult3 != testResult2 - u.reset() - testResult4 = (u(), u(), u()) - assert testResult4 != testResult - assert testResult4 != testResult2 - assert testResult4 != testResult3 - u = galsim.UniformDeviate() - testResult5 = (u(), u(), u()) - assert testResult5 != testResult - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - - # NOTE: these tests differ since we cannot edit arrays in-place in JAX - # Test generate - u.seed(testseed) - test_array = np.empty(3) - test_array = u.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, - np.array(uResult), - precision, - err_msg="Wrong uniform random number sequence from generate.", - ) - - # Test add_generate - u.seed(testseed) - test_array = u.add_generate(test_array) - np.testing.assert_array_almost_equal( - test_array, - 2.0 * np.array(uResult), - precision, - err_msg="Wrong uniform random number sequence from generate.", - ) - - # Test generate with a float32 array - u.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array = u.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, - np.array(uResult), - precisionF, - err_msg="Wrong uniform random number sequence from generate.", - ) - - # Test add_generate - u.seed(testseed) - test_array = u.add_generate(test_array) - np.testing.assert_array_almost_equal( - test_array, - 2.0 * np.array(uResult), - precisionF, - err_msg="Wrong uniform random number sequence from generate.", - ) - - # Check that generated values are independent of number of threads. - u1 = galsim.UniformDeviate(testseed) - u2 = galsim.UniformDeviate(testseed) - v1 = np.empty(555) - v2 = np.empty(555) - with single_threaded(): - v1 = u1.generate(v1) - with single_threaded(num_threads=10): - v2 = u2.generate(v2) - np.testing.assert_array_equal(v1, v2) - with single_threaded(): - v1 = u1.add_generate(v1) - with single_threaded(num_threads=10): - v2 = u2.add_generate(v2) - np.testing.assert_array_equal(v1, v2) - - # Check picklability - check_pickle(u, lambda x: x.serialize()) - check_pickle(u, lambda x: (x(), x(), x(), x())) - check_pickle(u) - check_pickle(rng) - assert "UniformDeviate" in repr(u) - assert "UniformDeviate" in str(u) - assert isinstance(eval(repr(u)), galsim.UniformDeviate) - assert isinstance(eval(str(u)), galsim.UniformDeviate) - assert isinstance(eval(repr(rng)), galsim.BaseDeviate) - assert isinstance(eval(str(rng)), galsim.BaseDeviate) - - # Check that we can construct a UniformDeviate from None, and that it depends on dev/random. - u1 = galsim.UniformDeviate(None) - u2 = galsim.UniformDeviate(None) - assert u1 != u2, "Consecutive UniformDeviate(None) compared equal!" - - # NOTE: We do not test for these since we do no type checking in JAX - # # We shouldn't be able to construct a UniformDeviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.UniformDeviate, dict()) - # assert_raises(TypeError, galsim.UniformDeviate, list()) - # assert_raises(TypeError, galsim.UniformDeviate, set()) - - # assert_raises(TypeError, u.seed, '123') - # assert_raises(TypeError, u.seed, 12.3) - - -@timer -def test_gaussian(): - """Test Gaussian random number generator - """ - g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - g2 = g.duplicate() - g3 = galsim.GaussianDeviate(g.serialize(), mean=gMean, sigma=gSigma) - testResult = (g(), g(), g()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gResult), precision, - err_msg='Wrong Gaussian random number sequence generated') - testResult = (g2(), g2(), g2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gResult), precision, - err_msg='Wrong Gaussian random number sequence generated with duplicate') - testResult = (g3(), g3(), g3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gResult), precision, - err_msg='Wrong Gaussian random number sequence generated from serialize') - - # Check that the mean and variance come out right - g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - vals = [g() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = gMean - v = gSigma**2 - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from GaussianDeviate') - np.testing.assert_almost_equal( - var, v, 0, - err_msg='Wrong variance from GaussianDeviate') - - # Check discard - g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - g2.discard(nvals) - v1, v2 = g(), g2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - assert v1 == v2 - # NOTE: JAX doesn't appear to have this issue - # Note: For Gaussian, this only works if nvals is even. - # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - # g2.discard(nvals+1, suppress_warnings=True) - # v1,v2 = g(),g2() - # print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) - # assert v1 != v2 - assert g.has_reliable_discard - # NOTE changed to NOT here for JAX - assert not g.generates_in_pairs - - # NOTE: JAX doesn't warn for this - # If don't explicitly suppress the warning, then a warning is emitted when n is odd. - # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - # with assert_warns(galsim.GalSimWarning): - # g2.discard(nvals+1) - - # Check seed, reset - g.seed(testseed) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated after seed') - - g.reset(testseed) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - g.reset(rng) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - g.reset(ud) - testResult = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated after reset(ud)') - - # Check that two connected Gaussian deviates work correctly together. - g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - g.reset(g2) - # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. - testResult2 = (g(), g(), g2()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated using two gds') - g.seed(testseed) - # For the same reason, after seeding one, we need to manually clear the other's cache: - g2.clearCache() - testResult2 = (g2(), g2(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Gaussian random number sequence generated using two gds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. - g.seed() - testResult2 = (g(), g(), g()) - assert testResult2 != testResult - g.reset() - testResult3 = (g(), g(), g()) - assert testResult3 != testResult - assert testResult3 != testResult2 - g.reset() - testResult4 = (g(), g(), g()) - assert testResult4 != testResult - assert testResult4 != testResult2 - assert testResult4 != testResult3 - g = galsim.GaussianDeviate(mean=gMean, sigma=gSigma) - testResult5 = (g(), g(), g()) - assert testResult5 != testResult - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - - # Test generate - g.seed(testseed) - test_array = np.empty(3) - test_array.fill(np.nan) - test_array = g.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gResult), precision, - err_msg='Wrong Gaussian random number sequence from generate.') - - # Test generate_from_variance. - g2 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) - g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) - test_array = np.empty(3) - test_array.fill(gSigma**2) - test_array = g2.generate_from_variance(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gResult) - gMean, precision, - err_msg='Wrong Gaussian random number sequence from generate_from_variance.') - # NOTE JAX can use the array shape here - # After running generate_from_variance, it should be back to using the specified mean, sigma. - # Note: need to round up to even number for discard, since gd generates 2 at a time. - g3.discard(len(test_array)) - print('g2,g3 = ', g2(), g3()) - assert g2() == g3() - - # Test generate with a float32 array. - g.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array.fill(np.nan) - test_array = g.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gResult), precisionF, - err_msg='Wrong Gaussian random number sequence from generate.') - - # Test generate_from_variance. - g2.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array.fill(gSigma**2) - test_array = g2.generate_from_variance(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gResult) - gMean, precisionF, - err_msg='Wrong Gaussian random number sequence from generate_from_variance.') - - # Check that generated values are independent of number of threads. - g1 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) - g2 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) - v1 = np.empty(555) - v2 = np.empty(555) - with single_threaded(): - v1 = g1.generate(v1) - with single_threaded(num_threads=10): - v2 = g2.generate(v2) - np.testing.assert_array_equal(v1, v2) - with single_threaded(): - v1 = g1.add_generate(v1) - with single_threaded(num_threads=10): - v2 = g2.add_generate(v2) - np.testing.assert_array_equal(v1, v2) - ud = galsim.UniformDeviate(testseed + 3) - ud.generate(v1) - v1 += 6.7 - v2 = v1.copy() - with single_threaded(): - v1 = g1.generate_from_variance(v1) - with single_threaded(num_threads=10): - v2 = g2.generate_from_variance(v2) - np.testing.assert_array_equal(v1, v2) - - # Check picklability - check_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) - check_pickle(g, lambda x: (x(), x(), x(), x())) - check_pickle(g) - assert 'GaussianDeviate' in repr(g) - assert 'GaussianDeviate' in str(g) - assert isinstance(eval(repr(g)), galsim.GaussianDeviate) - assert isinstance(eval(str(g)), galsim.GaussianDeviate) - - # Check that we can construct a GaussianDeviate from None, and that it depends on dev/random. - g1 = galsim.GaussianDeviate(None) - g2 = galsim.GaussianDeviate(None) - assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" - - # NOTE: We do not test for these since we do no type checking in JAX - # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, - # or None. - # assert_raises(TypeError, galsim.GaussianDeviate, dict()) - # assert_raises(TypeError, galsim.GaussianDeviate, list()) - # assert_raises(TypeError, galsim.GaussianDeviate, set()) - # assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) - - -@timer -def test_binomial(): - """Test binomial random number generator - """ - - b = galsim.BinomialDeviate(testseed, N=bN, p=bp) - b2 = b.duplicate() - b3 = galsim.BinomialDeviate(b.serialize(), N=bN, p=bp) - testResult = (b(), b(), b()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(bResult), precision, - err_msg='Wrong binomial random number sequence generated') - testResult = (b2(), b2(), b2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(bResult), precision, - err_msg='Wrong binomial random number sequence generated with duplicate') - testResult = (b3(), b3(), b3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(bResult), precision, - err_msg='Wrong binomial random number sequence generated from serialize') - - # Check that the mean and variance come out right - b = galsim.BinomialDeviate(testseed, N=bN, p=bp) - vals = [b() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = bN * bp - v = bN * bp * (1. - bp) - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from BinomialDeviate') - np.testing.assert_almost_equal( - var, v, 1, - err_msg='Wrong variance from BinomialDeviate') - - # Check discard - b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) - b2.discard(nvals) - v1, v2 = b(), b2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - assert v1 == v2 - assert b.has_reliable_discard - assert not b.generates_in_pairs - - # Check seed, reset - b.seed(testseed) - testResult2 = (b(), b(), b()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated after seed') - - b.reset(testseed) - testResult2 = (b(), b(), b()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - b.reset(rng) - testResult2 = (b(), b(), b()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - b.reset(ud) - testResult = (b(), b(), b()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated after reset(ud)') - - # Check that two connected binomial deviates work correctly together. - b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) - b.reset(b2) - testResult2 = (b(), b2(), b()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated using two bds') - b.seed(testseed) - testResult2 = (b2(), b(), b2()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong binomial random number sequence generated using two bds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. However, in this case, there are few enough options - # for the output that occasionally two of these match. So we don't do the normal - # testResult2 != testResult, etc. - b.seed() - testResult2 = (b(), b(), b()) - b.reset() - testResult3 = (b(), b(), b()) - b.reset() - testResult4 = (b(), b(), b()) - b = galsim.BinomialDeviate(testseed, N=bN, p=bp) - testResult5 = (b(), b(), b()) - assert ( - (testResult2 != testResult) - or (testResult3 != testResult) - or (testResult4 != testResult) - or (testResult5 != testResult) - ) - try: - assert testResult3 != testResult2 - assert testResult4 != testResult2 - assert testResult4 != testResult3 - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - except AssertionError: - print("one of the poisson results was equal but this can happen occasionally") - - # Test generate - b.seed(testseed) - test_array = np.empty(3) - test_array = b.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(bResult), precision, - err_msg='Wrong binomial random number sequence from generate.') - - # Test generate with an int array - b.seed(testseed) - test_array = np.empty(3, dtype=int) - test_array = b.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(bResult), precisionI, - err_msg='Wrong binomial random number sequence from generate.') - - # Check that generated values are independent of number of threads. - b1 = galsim.BinomialDeviate(testseed, N=17, p=0.7) - b2 = galsim.BinomialDeviate(testseed, N=17, p=0.7) - v1 = np.empty(555) - v2 = np.empty(555) - with single_threaded(): - v1 = b1.generate(v1) - with single_threaded(num_threads=10): - v2 = b2.generate(v2) - np.testing.assert_array_equal(v1, v2) - with single_threaded(): - v1 = b1.add_generate(v1) - with single_threaded(num_threads=10): - v2 = b2.add_generate(v2) - np.testing.assert_array_equal(v1, v2) - - # Check picklability - check_pickle(b, lambda x: (x.serialize(), x.n, x.p)) - check_pickle(b, lambda x: (x(), x(), x(), x())) - check_pickle(b) - assert 'BinomialDeviate' in repr(b) - assert 'BinomialDeviate' in str(b) - assert isinstance(eval(repr(b)), galsim.BinomialDeviate) - assert isinstance(eval(str(b)), galsim.BinomialDeviate) - - # Check that we can construct a BinomialDeviate from None, and that it depends on dev/random. - b1 = galsim.BinomialDeviate(None) - b2 = galsim.BinomialDeviate(None) - assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" - # NOTE JAX does not do type checking - # # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.BinomialDeviate, dict()) - # assert_raises(TypeError, galsim.BinomialDeviate, list()) - # assert_raises(TypeError, galsim.BinomialDeviate, set()) - - -@timer -def test_poisson(): - """Test Poisson random number generator - """ - p = galsim.PoissonDeviate(testseed, mean=pMean) - p2 = p.duplicate() - p3 = galsim.PoissonDeviate(p.serialize(), mean=pMean) - testResult = (p(), p(), p()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(pResult), precision, - err_msg='Wrong Poisson random number sequence generated') - testResult = (p2(), p2(), p2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(pResult), precision, - err_msg='Wrong Poisson random number sequence generated with duplicate') - testResult = (p3(), p3(), p3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(pResult), precision, - err_msg='Wrong Poisson random number sequence generated from serialize') - - # Check that the mean and variance come out right - p = galsim.PoissonDeviate(testseed, mean=pMean) - vals = [p() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = pMean - v = pMean - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from PoissonDeviate') - np.testing.assert_almost_equal( - var, v, 1, - err_msg='Wrong variance from PoissonDeviate') - - # Check discard - p2 = galsim.PoissonDeviate(testseed, mean=pMean) - p2.discard(nvals, suppress_warnings=True) - v1, v2 = p(), p2() - print('With mean = %d, after %d vals, next one is %s, %s' % (pMean, nvals, v1, v2)) - assert v1 == v2 - - # NOTE: the JAX RNGs have reliabel discard - # With a very small mean value, Poisson reliably only uses 1 rng per value. - # But at only slightly larger means, it sometimes uses two rngs for a single value. - # Basically anything >= 10 causes this next test to have v1 != v2 - high_mean = 10 - p = galsim.PoissonDeviate(testseed, mean=high_mean) - p2 = galsim.PoissonDeviate(testseed, mean=high_mean) - vals = [p() for i in range(nvals)] - p2.discard(nvals, suppress_warnings=True) - v1, v2 = p(), p2() - print('With mean = %d, after %d vals, next one is %s, %s' % (high_mean, nvals, v1, v2)) - assert v1 == v2 - assert p.has_reliable_discard - assert not p.generates_in_pairs - - # NOTE jax-galsim doesn't do this - # Discard normally emits a warning for Poisson - # p2 = galsim.PoissonDeviate(testseed, mean=pMean) - # with assert_warns(galsim.GalSimWarning): - # p2.discard(nvals) - - # Check seed, reset - p = galsim.PoissonDeviate(testseed, mean=pMean) - p.seed(testseed) - testResult2 = (p(), p(), p()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated after seed') - - p.reset(testseed) - testResult2 = (p(), p(), p()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - p.reset(rng) - testResult2 = (p(), p(), p()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - p.reset(ud) - testResult = (p(), p(), p()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated after reset(ud)') - - # Check that two connected poisson deviates work correctly together. - p2 = galsim.PoissonDeviate(testseed, mean=pMean) - p.reset(p2) - testResult2 = (p(), p2(), p()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated using two pds') - p.seed(testseed) - testResult2 = (p2(), p(), p2()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong poisson random number sequence generated using two pds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. However, in this case, there are few enough options - # for the output that occasionally two of these match. So we don't do the normal - # testResult2 != testResult, etc. - p.seed() - testResult2 = (p(), p(), p()) - p.reset() - testResult3 = (p(), p(), p()) - p.reset() - testResult4 = (p(), p(), p()) - p = galsim.PoissonDeviate(mean=pMean) - testResult5 = (p(), p(), p()) - assert ( - (testResult2 != testResult) - or (testResult3 != testResult) - or (testResult4 != testResult) - or (testResult5 != testResult) - ) - try: - assert testResult3 != testResult2 - assert testResult4 != testResult2 - assert testResult4 != testResult3 - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - except AssertionError: - print("one of the poisson results was equal but this can happen occasionally") - - # Test generate - p.seed(testseed) - test_array = np.empty(3) - test_array.fill(np.nan) - test_array = p.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(pResult), precision, - err_msg='Wrong poisson random number sequence from generate.') - - # Test generate with an int array - p.seed(testseed) - test_array = np.empty(3, dtype=int) - test_array = p.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(pResult), precisionI, - err_msg='Wrong poisson random number sequence from generate.') - - # Test generate_from_expectation - p2 = galsim.PoissonDeviate(testseed, mean=77) - test_array = np.array([pMean] * 3, dtype=int) - test_array = p2.generate_from_expectation(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(pResult), precisionI, - err_msg='Wrong poisson random number sequence from generate_from_expectation.') - # After generating, it should be back to mean=77 - test_array2 = np.array([p2() for i in range(100)]) - print('test_array2 = ', test_array2) - print('mean = ', test_array2.mean()) - assert np.isclose(test_array2.mean(), 77, atol=2) - - # Check that generated values are independent of number of threads. - # This should be trivial, since Poisson disables multi-threading, but check anyway. - p1 = galsim.PoissonDeviate(testseed, mean=77) - p2 = galsim.PoissonDeviate(testseed, mean=77) - v1 = np.empty(555) - v2 = np.empty(555) - with single_threaded(): - v1 = p1.generate(v1) - with single_threaded(num_threads=10): - v2 = p2.generate(v2) - np.testing.assert_array_equal(v1, v2) - with single_threaded(): - v1 = p1.add_generate(v1) - with single_threaded(num_threads=10): - v2 = p2.add_generate(v2) - np.testing.assert_array_equal(v1, v2) - - # Check picklability - check_pickle(p, lambda x: (x.serialize(), x.mean)) - check_pickle(p, lambda x: (x(), x(), x(), x())) - check_pickle(p) - assert 'PoissonDeviate' in repr(p) - assert 'PoissonDeviate' in str(p) - assert isinstance(eval(repr(p)), galsim.PoissonDeviate) - assert isinstance(eval(str(p)), galsim.PoissonDeviate) - - # Check that we can construct a PoissonDeviate from None, and that it depends on dev/random. - p1 = galsim.PoissonDeviate(None) - p2 = galsim.PoissonDeviate(None) - assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" - # NOTE we do not use type checking in JAX - # # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.PoissonDeviate, dict()) - # assert_raises(TypeError, galsim.PoissonDeviate, list()) - # assert_raises(TypeError, galsim.PoissonDeviate, set()) - - -@timer -def test_poisson_highmean(): - # NOTE JAX has the same issues as boost with high mean poisson RVs but - # it happens at lower mean values - mean_vals = [ - 2**17 - 50, # Uses Poisson - 2**17, # Uses Poisson (highest value of mean that does) - 2**17 + 50, # Uses Gaussian - 2**18, # This is where problems happen if not using Gaussian - 5.e20, # Definitely would have problems with normal implementation. - ] - - nvals = 100000 - rtol_var = 1.e-2 - - for mean in mean_vals: - print('Test PoissonDeviate with mean = ', np.log(mean) / np.log(2)) - p = galsim.PoissonDeviate(testseed, mean=mean) - p2 = p.duplicate() - p3 = galsim.PoissonDeviate(p.serialize(), mean=mean) - testResult = (p(), p(), p()) - testResult2 = (p2(), p2(), p2()) - testResult3 = (p3(), p3(), p3()) - np.testing.assert_allclose( - testResult2, testResult, rtol=1.e-8, - err_msg='PoissonDeviate.duplicate not equivalent for mean=%s' % mean) - np.testing.assert_allclose( - testResult3, testResult, rtol=1.e-8, - err_msg='PoissonDeviate from serialize not equivalent for mean=%s' % mean) - - # Check that the mean and variance come out right - p = galsim.PoissonDeviate(testseed, mean=mean) - vals = [p() for i in range(nvals)] - mu = np.mean(vals) - var = np.var(vals) - print("rtol = ", 3 * np.sqrt(mean / nvals) / mean) - print('mean = ', mu, ' true mean = ', mean) - print('var = ', var, ' true var = ', mean) - np.testing.assert_allclose( - mu, mean, rtol=3 * np.sqrt(mean / nvals) / mean, - err_msg='Wrong mean from PoissonDeviate with mean=%s' % mean) - np.testing.assert_allclose( - var, mean, rtol=rtol_var, - err_msg='Wrong variance from PoissonDeviate with mean=%s' % mean) - - # Check discard - p2 = galsim.PoissonDeviate(testseed, mean=mean) - p2.discard(nvals, suppress_warnings=True) - v1, v2 = p(), p2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - assert v1 == v2 - - # Check that two connected poisson deviates work correctly together. - p2 = galsim.PoissonDeviate(testseed, mean=mean) - p.reset(p2) - testResult2 = (p(), p(), p2()) - np.testing.assert_array_equal( - testResult2, testResult, - err_msg='Wrong poisson random number sequence generated using two pds') - p.seed(testseed) - p2.clearCache() - testResult2 = (p2(), p2(), p()) - np.testing.assert_array_equal( - testResult2, testResult, - err_msg='Wrong poisson random number sequence generated using two pds after seed') - - # Test filling an image - p.seed(testseed) - testimage = galsim.ImageD(np.zeros((3, 1))) - testimage.addNoise(galsim.DeviateNoise(p)) - np.testing.assert_array_equal( - testimage.array.flatten(), testResult, - err_msg='Wrong poisson random number sequence generated when applied to image.') - - rng = galsim.BaseDeviate(testseed) - pn = galsim.PoissonNoise(rng, sky_level=mean) - testimage.fill(0) - testimage.addNoise(pn) - np.testing.assert_array_equal( - testimage.array.flatten(), np.array(testResult) - mean, - err_msg='Wrong poisson random number sequence generated using PoissonNoise') - - # Check PoissonNoise variance: - np.testing.assert_allclose( - pn.getVariance(), mean, rtol=1.e-8, - err_msg="PoissonNoise getVariance returns wrong variance") - np.testing.assert_allclose( - pn.sky_level, mean, rtol=1.e-8, - err_msg="PoissonNoise sky_level returns wrong value") - - # Check that the noise model really does produce this variance. - big_im = galsim.Image(2048, 2048, dtype=float) - big_im.addNoise(pn) - var = np.var(big_im.array) - print('variance = ', var) - print('getVar = ', pn.getVariance()) - np.testing.assert_allclose( - var, pn.getVariance(), rtol=rtol_var, - err_msg='Realized variance for PoissonNoise did not match getVariance()') - - -@timer -def test_poisson_zeromean(): - """Make sure Poisson Deviate behaves sensibly when mean=0. - """ - p = galsim.PoissonDeviate(testseed, mean=0) - p2 = p.duplicate() - p3 = galsim.PoissonDeviate(p.serialize(), mean=0) - check_pickle(p) - - # Test direct draws - testResult = (p(), p(), p()) - testResult2 = (p2(), p2(), p2()) - testResult3 = (p3(), p3(), p3()) - np.testing.assert_array_equal(testResult, 0) - np.testing.assert_array_equal(testResult2, 0) - np.testing.assert_array_equal(testResult3, 0) - - # Test generate - test_array = np.empty(3, dtype=int) - p.generate(test_array) - np.testing.assert_array_equal(test_array, 0) - p2.generate(test_array) - np.testing.assert_array_equal(test_array, 0) - p3.generate(test_array) - np.testing.assert_array_equal(test_array, 0) - - # Test generate_from_expectation - test_array = np.array([0, 0, 0]) - np.testing.assert_allclose(test_array, 0) - test_array = np.array([1, 0, 4]) - assert test_array[0] != 0 - assert test_array[1] == 0 - assert test_array[2] != 0 - - # NOTE JAX does not raise an error for this - # # Error raised if mean<0 - # with assert_raises(ValueError): - # p = galsim.PoissonDeviate(testseed, mean=-0.1) - # with assert_raises(ValueError): - # p = galsim.PoissonDeviate(testseed, mean=-10) - # test_array = np.array([-1,1,4]) - # with assert_raises(ValueError): - # p.generate_from_expectation(test_array) - # test_array = np.array([1,-1,-4]) - # with assert_raises(ValueError): - # p.generate_from_expectation(test_array) - - -@timer -def test_weibull(): - """Test Weibull random number generator - """ - w = galsim.WeibullDeviate(testseed, a=wA, b=wB) - w2 = w.duplicate() - w3 = galsim.WeibullDeviate(w.serialize(), a=wA, b=wB) - testResult = (w(), w(), w()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(wResult), precision, - err_msg='Wrong Weibull random number sequence generated') - testResult = (w2(), w2(), w2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(wResult), precision, - err_msg='Wrong Weibull random number sequence generated with duplicate') - testResult = (w3(), w3(), w3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(wResult), precision, - err_msg='Wrong Weibull random number sequence generated from serialize') - - # Check that the mean and variance come out right - w = galsim.WeibullDeviate(testseed, a=wA, b=wB) - vals = [w() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - gammaFactor1 = math.gamma(1. + 1. / wA) - gammaFactor2 = math.gamma(1. + 2. / wA) - mu = wB * gammaFactor1 - v = wB**2 * gammaFactor2 - mu**2 - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from WeibullDeviate') - np.testing.assert_almost_equal( - var, v, 1, - err_msg='Wrong variance from WeibullDeviate') - - # Check discard - w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) - w2.discard(nvals) - v1, v2 = w(), w2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - assert v1 == v2 - assert w.has_reliable_discard - assert not w.generates_in_pairs - - # Check seed, reset - w.seed(testseed) - testResult2 = (w(), w(), w()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong weibull random number sequence generated after seed') - - w.reset(testseed) - testResult2 = (w(), w(), w()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong weibull random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - w.reset(rng) - testResult2 = (w(), w(), w()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong weibull random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - w.reset(ud) - testResult = (w(), w(), w()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong weibull random number sequence generated after reset(ud)') - - # NOTE JAX does not allow connected deviates - # # Check that two connected weibull deviates work correctly together. - # w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) - # w.reset(w2) - # testResult2 = (w(), w2(), w()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong weibull random number sequence generated using two wds') - # w.seed(testseed) - # testResult2 = (w2(), w(), w2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong weibull random number sequence generated using two wds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. - w.seed() - testResult2 = (w(), w(), w()) - assert testResult2 != testResult - w.reset() - testResult3 = (w(), w(), w()) - assert testResult3 != testResult - assert testResult3 != testResult2 - w.reset() - testResult4 = (w(), w(), w()) - assert testResult4 != testResult - assert testResult4 != testResult2 - assert testResult4 != testResult3 - w = galsim.WeibullDeviate(a=wA, b=wB) - testResult5 = (w(), w(), w()) - assert testResult5 != testResult - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - - # Test generate - w.seed(testseed) - test_array = np.empty(3) - test_array = w.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(wResult), precision, - err_msg='Wrong weibull random number sequence from generate.') - - # Test generate with a float32 array - w.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array = w.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(wResult), precisionF, - err_msg='Wrong weibull random number sequence from generate.') - - # Check that generated values are independent of number of threads. - w1 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) - w2 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) - v1 = np.empty(555) - v2 = np.empty(555) - with single_threaded(): - v1 = w1.generate(v1) - with single_threaded(num_threads=10): - v2 = w2.generate(v2) - np.testing.assert_array_equal(v1, v2) - with single_threaded(): - v1 = w1.add_generate(v1) - with single_threaded(num_threads=10): - v2 = w2.add_generate(v2) - np.testing.assert_array_equal(v1, v2) - - # Check picklability - check_pickle(w, lambda x: (x.serialize(), x.a, x.b)) - check_pickle(w, lambda x: (x(), x(), x(), x())) - check_pickle(w) - assert 'WeibullDeviate' in repr(w) - assert 'WeibullDeviate' in str(w) - assert isinstance(eval(repr(w)), galsim.WeibullDeviate) - assert isinstance(eval(str(w)), galsim.WeibullDeviate) - - # Check that we can construct a WeibullDeviate from None, and that it depends on dev/random. - w1 = galsim.WeibullDeviate(None) - w2 = galsim.WeibullDeviate(None) - assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" - # NOTE JAX does not do type checking - # # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.WeibullDeviate, dict()) - # assert_raises(TypeError, galsim.WeibullDeviate, list()) - # assert_raises(TypeError, galsim.WeibullDeviate, set()) - - -@timer -def test_gamma(): - """Test Gamma random number generator - """ - g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - g2 = g.duplicate() - g3 = galsim.GammaDeviate(g.serialize(), k=gammaK, theta=gammaTheta) - testResult = (g(), g(), g()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gammaResult), precision, - err_msg='Wrong Gamma random number sequence generated') - testResult = (g2(), g2(), g2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gammaResult), precision, - err_msg='Wrong Gamma random number sequence generated with duplicate') - testResult = (g3(), g3(), g3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(gammaResult), precision, - err_msg='Wrong Gamma random number sequence generated from serialize') - - # Check that the mean and variance come out right - g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - vals = [g() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = gammaK * gammaTheta - v = gammaK * gammaTheta**2 - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from GammaDeviate') - np.testing.assert_almost_equal( - var, v, 0, - err_msg='Wrong variance from GammaDeviate') - - # NOTE jax has a reliable discard - # Check discard - g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - g2.discard(nvals, suppress_warnings=True) - v1, v2 = g(), g2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. - assert v1 == v2 - assert g.has_reliable_discard - assert not g.generates_in_pairs - - # NOTE jax has a reliable discard - # Discard normally emits a warning for Gamma - # g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - # with assert_warns(galsim.GalSimWarning): - # g2.discard(nvals) - - # Check seed, reset - g.seed(testseed) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong gamma random number sequence generated after seed') - - g.reset(testseed) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong gamma random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - g.reset(rng) - testResult2 = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong gamma random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - g.reset(ud) - testResult = (g(), g(), g()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong gamma random number sequence generated after reset(ud)') - - # NOTE jax cannot connect RNGs - # # Check that two connected gamma deviates work correctly together. - # g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - # g.reset(g2) - # testResult2 = (g(), g2(), g()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong gamma random number sequence generated using two gds') - # g.seed(testseed) - # testResult2 = (g2(), g(), g2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong gamma random number sequence generated using two gds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. - g.seed() - testResult2 = (g(), g(), g()) - assert testResult2 != testResult - g.reset() - testResult3 = (g(), g(), g()) - assert testResult3 != testResult - assert testResult3 != testResult2 - g.reset() - testResult4 = (g(), g(), g()) - assert testResult4 != testResult - assert testResult4 != testResult2 - assert testResult4 != testResult3 - g = galsim.GammaDeviate(k=gammaK, theta=gammaTheta) - testResult5 = (g(), g(), g()) - assert testResult5 != testResult - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - - # Test generate - g.seed(testseed) - test_array = np.empty(3) - test_array = g.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gammaResult), precision, - err_msg='Wrong gamma random number sequence from generate.') - - # Test generate with a float32 array - g.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array = g.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(gammaResult), precisionF, - err_msg='Wrong gamma random number sequence from generate.') - - # Check picklability - check_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) - check_pickle(g, lambda x: (x(), x(), x(), x())) - check_pickle(g) - assert 'GammaDeviate' in repr(g) - assert 'GammaDeviate' in str(g) - assert isinstance(eval(repr(g)), galsim.GammaDeviate) - assert isinstance(eval(str(g)), galsim.GammaDeviate) - - # Check that we can construct a GammaDeviate from None, and that it depends on dev/random. - g1 = galsim.GammaDeviate(None) - g2 = galsim.GammaDeviate(None) - assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" - # NOTE jax does not raise for type errors - # # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.GammaDeviate, dict()) - # assert_raises(TypeError, galsim.GammaDeviate, list()) - # assert_raises(TypeError, galsim.GammaDeviate, set()) - - -@timer -def test_chi2(): - """Test Chi^2 random number generator - """ - c = galsim.Chi2Deviate(testseed, n=chi2N) - c2 = c.duplicate() - c3 = galsim.Chi2Deviate(c.serialize(), n=chi2N) - testResult = (c(), c(), c()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(chi2Result), precision, - err_msg='Wrong Chi^2 random number sequence generated') - testResult = (c2(), c2(), c2()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(chi2Result), precision, - err_msg='Wrong Chi^2 random number sequence generated with duplicate') - testResult = (c3(), c3(), c3()) - np.testing.assert_array_almost_equal( - np.array(testResult), np.array(chi2Result), precision, - err_msg='Wrong Chi^2 random number sequence generated from serialize') - - # Check that the mean and variance come out right - c = galsim.Chi2Deviate(testseed, n=chi2N) - vals = [c() for i in range(nvals)] - mean = np.mean(vals) - var = np.var(vals) - mu = chi2N - v = 2. * chi2N - print('mean = ', mean, ' true mean = ', mu) - print('var = ', var, ' true var = ', v) - np.testing.assert_almost_equal( - mean, mu, 1, - err_msg='Wrong mean from Chi2Deviate') - np.testing.assert_almost_equal( - var, v, 0, - err_msg='Wrong variance from Chi2Deviate') - - # NOTE: - # Check discard - c2 = galsim.Chi2Deviate(testseed, n=chi2N) - c2.discard(nvals) - v1, v2 = c(), c2() - print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) - assert v1 == v2 - assert c.has_reliable_discard - assert not c.generates_in_pairs - - # NOTE jax has reliable discard - # # Discard normally emits a warning for Chi2 - # c2 = galsim.Chi2Deviate(testseed, n=chi2N) - # with assert_warns(galsim.GalSimWarning): - # c2.discard(nvals) - - # Check seed, reset - c.seed(testseed) - testResult2 = (c(), c(), c()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Chi^2 random number sequence generated after seed') - - c.reset(testseed) - testResult2 = (c(), c(), c()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Chi^2 random number sequence generated after reset(seed)') - - rng = galsim.BaseDeviate(testseed) - c.reset(rng) - testResult2 = (c(), c(), c()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Chi^2 random number sequence generated after reset(rng)') - - ud = galsim.UniformDeviate(testseed) - c.reset(ud) - testResult = (c(), c(), c()) - np.testing.assert_array_equal( - np.array(testResult), np.array(testResult2), - err_msg='Wrong Chi^2 random number sequence generated after reset(ud)') - - # NOTE we cannot connect two deviates in JAX - # # Check that two connected Chi^2 deviates work correctly together. - # c2 = galsim.Chi2Deviate(testseed, n=chi2N) - # c.reset(c2) - # testResult2 = (c(), c2(), c()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong Chi^2 random number sequence generated using two cds') - # c.seed(testseed) - # testResult2 = (c2(), c(), c2()) - # np.testing.assert_array_equal( - # np.array(testResult), np.array(testResult2), - # err_msg='Wrong Chi^2 random number sequence generated using two cds after seed') - - # Check that seeding with the time works (although we cannot check the output). - # We're mostly just checking that this doesn't raise an exception. - # The output could be anything. - c.seed() - testResult2 = (c(), c(), c()) - assert testResult2 != testResult - c.reset() - testResult3 = (c(), c(), c()) - assert testResult3 != testResult - assert testResult3 != testResult2 - c.reset() - testResult4 = (c(), c(), c()) - assert testResult4 != testResult - assert testResult4 != testResult2 - assert testResult4 != testResult3 - c = galsim.Chi2Deviate(n=chi2N) - testResult5 = (c(), c(), c()) - assert testResult5 != testResult - assert testResult5 != testResult2 - assert testResult5 != testResult3 - assert testResult5 != testResult4 - - # Test generate - c.seed(testseed) - test_array = np.empty(3) - test_array = c.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(chi2Result), precision, - err_msg='Wrong Chi^2 random number sequence from generate.') - - # Test generate with a float32 array - c.seed(testseed) - test_array = np.empty(3, dtype=np.float32) - test_array = c.generate(test_array) - np.testing.assert_array_almost_equal( - test_array, np.array(chi2Result), precisionF, - err_msg='Wrong Chi^2 random number sequence from generate.') - - # Check picklability - check_pickle(c, lambda x: (x.serialize(), x.n)) - check_pickle(c, lambda x: (x(), x(), x(), x())) - check_pickle(c) - assert 'Chi2Deviate' in repr(c) - assert 'Chi2Deviate' in str(c) - assert isinstance(eval(repr(c)), galsim.Chi2Deviate) - assert isinstance(eval(str(c)), galsim.Chi2Deviate) - - # Check that we can construct a Chi2Deviate from None, and that it depends on dev/random. - c1 = galsim.Chi2Deviate(None) - c2 = galsim.Chi2Deviate(None) - assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" - # NOTE jax does not raise - # # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, - # # or None. - # assert_raises(TypeError, galsim.Chi2Deviate, dict()) - # assert_raises(TypeError, galsim.Chi2Deviate, list()) - # assert_raises(TypeError, galsim.Chi2Deviate, set()) - - -# @timer -# def test_distfunction(): -# """Test distribution-defined random number generator with a function -# """ -# # Make sure it requires an input function in order to work. -# assert_raises(TypeError, galsim.DistDeviate) -# # Make sure it does appropriate input sanity checks. -# assert_raises(TypeError, galsim.DistDeviate, -# function='../examples/data/cosmo-fid.zmed1.00_smoothed.out', -# x_min=1.) -# assert_raises(TypeError, galsim.DistDeviate, function=1.0) -# assert_raises(ValueError, galsim.DistDeviate, function='foo.dat') -# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x, interpolant='linear') -# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x) -# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x, x_min=1.) -# test_vals = range(10) -# assert_raises(TypeError, galsim.DistDeviate, -# function=galsim.LookupTable(test_vals, test_vals), -# x_min = 1.) -# foo = galsim.DistDeviate(10, galsim.LookupTable(test_vals, test_vals)) -# assert_raises(ValueError, foo.val, -1.) -# assert_raises(ValueError, galsim.DistDeviate, function = lambda x : -1, x_min=dmin, x_max=dmax) -# assert_raises(ValueError, galsim.DistDeviate, function = lambda x : x**2-1, x_min=dmin, x_max=dmax) - -# d = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) -# d2 = d.duplicate() -# d3 = galsim.DistDeviate(d.serialize(), function=dfunction, x_min=dmin, x_max=dmax) -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated') -# testResult = (d2(), d2(), d2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated with duplicate') -# testResult = (d3(), d3(), d3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated from serialize') - -# # Check val() method -# # pdf(x) = x^2 -# # cdf(x) = (x/2)^3 -# # val(y) = 2 y^(1/3) -# np.testing.assert_almost_equal(d.val(0), 0, 4) -# np.testing.assert_almost_equal(d.val(1), 2, 4) -# np.testing.assert_almost_equal(d.val(0.125), 1, 4) -# np.testing.assert_almost_equal(d.val(0.027), 0.6, 4) -# np.testing.assert_almost_equal(d.val(0.512), 1.6, 4) -# u = galsim.UniformDeviate(testseed) -# testResult = (d.val(u()), d.val(u()), d.val(u())) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate sequence using d.val(u())') - -# # Check that the mean and variance come out right -# d = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) -# vals = [d() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = 3./2. -# v = 3./20. -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from DistDeviate random numbers using function') -# np.testing.assert_almost_equal(var, v, 1, -# err_msg='Wrong variance from DistDeviate random numbers using function') - -# # Check discard -# d2 = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) -# d2.discard(nvals) -# v1,v2 = d(),d2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# assert v1 == v2 -# assert d.has_reliable_discard -# assert not d.generates_in_pairs - -# # Check seed, reset -# d.seed(testseed) -# testResult2 = (d(), d(), d()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated after seed') - -# d.reset(testseed) -# testResult2 = (d(), d(), d()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated after reset(seed)') - -# rng = galsim.BaseDeviate(testseed) -# d.reset(rng) -# testResult2 = (d(), d(), d()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated after reset(rng)') - -# ud = galsim.UniformDeviate(testseed) -# d.reset(ud) -# testResult = (d(), d(), d()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated after reset(ud)') - -# # Check that two connected DistDeviate deviates work correctly together. -# d2 = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) -# d.reset(d2) -# testResult2 = (d(), d2(), d()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated using two dds') -# d.seed(testseed) -# testResult2 = (d2(), d(), d2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong DistDeviate random number sequence generated using two dds after seed') - -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. -# d.seed() -# testResult2 = (d(), d(), d()) -# assert testResult2 != testResult -# d.reset() -# testResult3 = (d(), d(), d()) -# assert testResult3 != testResult -# assert testResult3 != testResult2 -# d.reset() -# testResult4 = (d(), d(), d()) -# assert testResult4 != testResult -# assert testResult4 != testResult2 -# assert testResult4 != testResult3 -# d = galsim.DistDeviate(function=dfunction, x_min=dmin, x_max=dmax) -# testResult5 = (d(), d(), d()) -# assert testResult5 != testResult -# assert testResult5 != testResult2 -# assert testResult5 != testResult3 -# assert testResult5 != testResult4 - -# # Check with lambda function -# d = galsim.DistDeviate(testseed, function=lambda x: x*x, x_min=dmin, x_max=dmax) -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated with lambda function') - -# # Check auto-generated lambda function -# d = galsim.DistDeviate(testseed, function='x*x', x_min=dmin, x_max=dmax) -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated with auto-lambda function') - -# # Test generate -# d.seed(testseed) -# test_array = np.empty(3) -# d.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence from generate.') - -# # Test add_generate -# d.seed(testseed) -# d.add_generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, 2*np.array(dFunctionResult), precision, -# err_msg='Wrong DistDeviate random number sequence from add_generate.') - -# # Test generate with a float32 array -# d.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# d.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(dFunctionResult), precisionF, -# err_msg='Wrong DistDeviate random number sequence from generate.') - -# # Test add_generate with a float32 array -# d.seed(testseed) -# d.add_generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, 2*np.array(dFunctionResult), precisionF, -# err_msg='Wrong DistDeviate random number sequence from add_generate.') - -# # Check that generated values are independent of number of threads. -# d1 = galsim.DistDeviate(testseed, function=lambda x: np.exp(-x**3), x_min=0, x_max=2) -# d2 = galsim.DistDeviate(testseed, function=lambda x: np.exp(-x**3), x_min=0, x_max=2) -# v1 = np.empty(555) -# v2 = np.empty(555) -# with single_threaded(): -# d1.generate(v1) -# with single_threaded(num_threads=10): -# d2.generate(v2) -# np.testing.assert_array_equal(v1, v2) -# with single_threaded(): -# d1.add_generate(v1) -# with single_threaded(num_threads=10): -# d2.add_generate(v2) -# np.testing.assert_array_equal(v1, v2) - -# # Check picklability -# check_pickle(d, lambda x: (x(), x(), x(), x())) -# check_pickle(d) -# assert 'DistDeviate' in repr(d) -# assert 'DistDeviate' in str(d) -# assert isinstance(eval(repr(d)), galsim.DistDeviate) -# assert isinstance(eval(str(d)), galsim.DistDeviate) - -# # Check that we can construct a DistDeviate from None, and that it depends on dev/random. -# c1 = galsim.DistDeviate(None, lambda x:1, 0, 1) -# c2 = galsim.DistDeviate(None, lambda x:1, 0, 1) -# assert c1 != c2, "Consecutive DistDeviate(None) compared equal!" -# # We shouldn't be able to construct a DistDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.DistDeviate, dict(), lambda x:1, 0, 1) -# assert_raises(TypeError, galsim.DistDeviate, list(), lambda x:1, 0, 1) -# assert_raises(TypeError, galsim.DistDeviate, set(), lambda x:1, 0, 1) - - -# @timer -# def test_distLookupTable(): -# """Test distribution-defined random number generator with a LookupTable -# """ -# precision = 9 -# # Note: 256 used to be the default, so this is a regression test -# # We check below that it works with the default npoints=None -# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) -# d2 = d.duplicate() -# d3 = galsim.DistDeviate(d.serialize(), function=dLookupTable, npoints=256) -# np.testing.assert_equal( -# d.x_min, dLookupTable.x_min, -# err_msg='DistDeviate and the LookupTable passed to it have different lower bounds') -# np.testing.assert_equal( -# d.x_max, dLookupTable.x_max, -# err_msg='DistDeviate and the LookupTable passed to it have different upper bounds') - -# testResult = (d(), d(), d()) -# print('testResult = ',testResult) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence using LookupTable') -# testResult = (d2(), d2(), d2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence using LookupTable with duplicate') -# testResult = (d3(), d3(), d3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence using LookupTable from serialize') - -# # Check that the mean and variance come out right -# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) -# vals = [d() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = 2. -# v = 7./3. -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from DistDeviate random numbers using LookupTable') -# np.testing.assert_almost_equal(var, v, 1, -# err_msg='Wrong variance from DistDeviate random numbers using LookupTable') - -# # Check discard -# d2 = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) -# d2.discard(nvals) -# v1,v2 = d(),d2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# assert v1 == v2 -# assert d.has_reliable_discard -# assert not d.generates_in_pairs - -# # This should give the same values with only 5 points because of the particular nature -# # of these arrays. -# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=5) -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence for LookupTable with 5 points') - -# # And it should also work if npoints is None -# d = galsim.DistDeviate(testseed, function=dLookupTable) -# testResult = (d(), d(), d()) -# assert len(dLookupTable.x) == len(d._inverse_cdf.x) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence for LookupTable with npoints=None') - -# # Also read these values from a file -# d = galsim.DistDeviate(testseed, function=dLookupTableFile, interpolant='linear') -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence for LookupTable from file') - -# d = galsim.DistDeviate(testseed, function=dLookupTableFile) -# testResult = (d(), d(), d()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence for LookupTable with default ' -# 'interpolant') - -# # Test generate -# d.seed(testseed) -# test_array = np.empty(3) -# d.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence from generate.') - -# # Test filling an image -# d.seed(testseed) -# testimage = galsim.ImageD(np.zeros((3, 1))) -# testimage.addNoise(galsim.DeviateNoise(d)) -# np.testing.assert_array_almost_equal( -# testimage.array.flatten(), np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence generated when applied to image.') - -# # Test generate -# d.seed(testseed) -# test_array = np.empty(3) -# d.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(dLookupTableResult), precision, -# err_msg='Wrong DistDeviate random number sequence from generate.') - -# # Test generate with a float32 array -# d.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# d.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(dLookupTableResult), precisionF, -# err_msg='Wrong DistDeviate random number sequence from generate.') - -# # Test a case with nearly flat probabilities -# # x and p arrays with and without a small (epsilon) step -# dx_eps = np.arange(6) -# dp1_eps = np.zeros(dx_eps.shape) -# dp2_eps = np.zeros(dx_eps.shape) -# eps = np.finfo(dp1_eps[0].dtype).eps -# dp1_eps[0] = 0.5 -# dp2_eps[0] = 0.5 -# dp1_eps[-1] = 0.5 -# dp2_eps[-2] = eps -# dp2_eps[-1] = 0.5-eps -# dLookupTableEps1 = galsim.LookupTable(x=dx_eps, f=dp1_eps, interpolant='linear') -# dLookupTableEps2 = galsim.LookupTable(x=dx_eps, f=dp2_eps, interpolant='linear') -# d1 = galsim.DistDeviate(testseed, function=dLookupTableEps1, npoints=len(dx_eps)) -# d2 = galsim.DistDeviate(testseed, function=dLookupTableEps2, npoints=len(dx_eps)) -# # If these were successfully created everything is probably fine, but check they create the same -# # internal LookupTable -# np.testing.assert_array_almost_equal( -# d1._inverse_cdf.getArgs(), d2._inverse_cdf.getArgs(), precision, -# err_msg='DistDeviate with near-flat probabilities incorrectly created ' -# 'a monotonic version of the CDF') -# np.testing.assert_array_almost_equal( -# d1._inverse_cdf.getVals(), d2._inverse_cdf.getVals(), precision, -# err_msg='DistDeviate with near-flat probabilities incorrectly created ' -# 'a monotonic version of the CDF') - -# # And that they generate the same values -# ar1 = np.empty(100); d1.generate(ar1) -# ar2 = np.empty(100); d2.generate(ar2) -# np.testing.assert_array_almost_equal(ar1, ar2, precision, -# err_msg='Two DistDeviates with near-flat probabilities generated different values.') - -# # Check picklability -# check_pickle(d, lambda x: (x(), x(), x(), x())) -# check_pickle(d) -# assert 'DistDeviate' in repr(d) -# assert 'DistDeviate' in str(d) -# assert isinstance(eval(repr(d)), galsim.DistDeviate) -# assert isinstance(eval(str(d)), galsim.DistDeviate) - - -@timer -def test_multiprocess(): - """Test that the same random numbers are generated in single-process and multi-process modes. - """ - from multiprocessing import current_process - from multiprocessing import get_context - ctx = get_context('fork') - Process = ctx.Process - Queue = ctx.Queue - - def generate_list(seed): - """Given a particular seed value, generate a list of random numbers. - Should be deterministic given the input seed value. - """ - rng = galsim.UniformDeviate(seed) - out = [] - for i in range(20): - out.append(rng()) - return out - - def worker(input, output): - """input is a queue with seed values - output is a queue storing the results of the tasks along with the process name, - and which args the result is for. - """ - for args in iter(input.get, 'STOP'): - result = generate_list(*args) - output.put((result, current_process().name, args)) - - # Use sequential numbers. - # On inspection, can see that even the first value in each list is random with - # respect to the other lists. i.e. "nearby" inputs do not produce nearby outputs. - # I don't know of an actual assert to do for this, but it is clearly true. - seeds = [1532424 + i for i in range(16)] - - nproc = 4 # Each process will do 4 lists (typically) - - # First make lists in the single process: - ref_lists = dict() - for seed in seeds: - list = generate_list(seed) - ref_lists[seed] = list - - # Now do this with multiprocessing - # Put the seeds in a queue - task_queue = Queue() - for seed in seeds: - task_queue.put([seed]) - - # Run the tasks: - done_queue = Queue() - for k in range(nproc): - Process(target=worker, args=(task_queue, done_queue)).start() - - # Check the results in the order they finished - for i in range(len(seeds)): - _list, proc, args = done_queue.get() - seed = args[0] - np.testing.assert_array_equal( - _list, ref_lists[seed], - err_msg="Random numbers are different when using multiprocessing") - - # Stop the processes: - for k in range(nproc): - task_queue.put('STOP') - - -@timer -def test_permute(): - """Simple tests of the permute() function.""" - # Make a fake list, and another list consisting of indices. - my_list = [3.7, 4.1, 1.9, 11.1, 378.3, 100.0] - import copy - my_list_copy = copy.deepcopy(my_list) - n_list = len(my_list) - ind_list = list(range(n_list)) - - # Permute both at the same time. - galsim.random.permute(312, np.array(my_list), np.array(ind_list)) - - # Make sure that everything is sensible - for ind in range(n_list): - assert my_list_copy[ind_list[ind]] == my_list[ind] - - # Repeat with same seed, should do same permutation. - my_list = copy.deepcopy(my_list_copy) - galsim.random.permute(312, np.array(my_list)) - for ind in range(n_list): - assert my_list_copy[ind_list[ind]] == my_list[ind] - - # NOTE no errors raised in JAX - # # permute with no lists should raise TypeError - # with assert_raises(TypeError): - # galsim.random.permute(312) - - -@timer -def test_int64(): - # cf. #1009 - # Check that various possible integer types work as seeds. - - rng1 = galsim.BaseDeviate(int(123)) - # cf. https://www.numpy.org/devdocs/user/basics.types.html - ivalues = [ - np.int8(123), # Note this one requires i < 128 - np.int16(123), - np.int32(123), - np.int64(123), - np.uint8(123), - np.uint16(123), - np.uint32(123), - np.uint64(123), - np.short(123), - np.ushort(123), - np.intc(123), - np.uintc(123), - np.intp(123), - np.uintp(123), - np.int_(123), - np.longlong(123), - np.ulonglong(123), - np.array(123).astype(np.int64)] - - for i in ivalues: - rng2 = galsim.BaseDeviate(i) - assert rng2 == rng1 diff --git a/tests/jax/galsim/test_shear_jax.py b/tests/jax/galsim/test_shear_jax.py deleted file mode 100644 index 890baea7..00000000 --- a/tests/jax/galsim/test_shear_jax.py +++ /dev/null @@ -1,377 +0,0 @@ -from __future__ import print_function - -import os -import sys - -import galsim -import numpy as np -from galsim_test_helpers import * - -# Below are a set of tests to make sure that we have achieved consistency in defining shears and -# ellipses using different conventions. The underlying idea is that in other files we already -# have plenty of tests to ensure that a given Shear can be properly applied and gives the -# expected result. So here, we just work at the level of Shears that we've defined, -# and make sure that they have the properties we expect given the values that were used to -# initialize them. For that, we have some sets of fiducial shears/dilations/shifts for which -# calculations were done independently (e.g., to get what is eta given the value of g). We go over -# the various way to initialize the shears, and make sure that their different values are properly -# set. We also test the methods of the python Shear classes to make sure that they give -# the expected results. - -##### set up necessary info for tests -# a few shear values over which we will loop so we can check them all -# note: Rachel started with these q and beta, and then calculated all the other numbers in IDL using -# the standard formulae -q = [1.0, 0.5, 0.3, 0.1, 0.7, 0.9, 0.99, 1.0 - 8.75e-5] -n_shear = len(q) -beta = [ - 0.0, - 0.5 * np.pi, - 0.25 * np.pi, - 0.0 * np.pi, - np.pi / 3.0, - np.pi, - -0.25 * np.pi, - -0.5 * np.pi, -] -g = [ - 0.0, - 0.333333, - 0.538462, - 0.818182, - 0.176471, - 0.05263157897, - 0.005025125626, - 4.375191415e-5, -] -g1 = [ - 0.0, - -0.33333334, - 0.0, - 0.81818175, - -0.088235296, - 0.05263157897, - 0.0, - -4.375191415e-5, -] -g2 = [0.0, 0.0, 0.53846157, 0.0, 0.15282802, 0.0, -0.005025125626, 0.0] -e = [ - 0.0, - 0.600000, - 0.834862, - 0.980198, - 0.342282, - 0.1049723757, - 0.01004999747, - 8.750382812e-5, -] -e1 = [0.0, -0.6000000, 0.0, 0.98019803, -0.17114094, 0.1049723757, 0.0, -8.750382812e-5] -e2 = [0.0, 0.0, 0.83486235, 0.0, 0.29642480, 0.0, -0.01004999747, 0.0] -eta = [ - 0.0, - 0.693147, - 1.20397, - 2.30259, - 0.356675, - 0.1053605157, - 0.01005033585, - 8.750382835e-5, -] -eta1 = [ - 0.0, - -0.69314718, - 0.0, - 2.3025851, - -0.17833748, - 0.1053605157, - 0.0, - -8.750382835e-5, -] -eta2 = [0.0, 0.0, 1.2039728, 0.0, 0.30888958, 0.0, -0.01005033585, 0.0] -decimal = 5 - - -#### some helper functions -def all_shear_vals(test_shear, index, mult_val=1.0): - print("test_shear = ", repr(test_shear)) - # this function tests that all values of some Shear object are consistent with the tabulated - # values, given the appropriate index against which to test, and properly accounting for the - # fact that sometimes the angle is in the range [pi, 2*pi) - ### note: can only use mult_val = 1, 0, -1 - if mult_val != -1.0 and mult_val != 0.0 and mult_val != 1.0: - raise ValueError("Cannot put multiplier that is not -1, 0, or 1!") - beta_rad = test_shear.beta.rad - while beta_rad < 0.0: - beta_rad += np.pi - - test_beta = beta[index] - if mult_val < 0.0: - test_beta -= 0.5 * np.pi - while test_beta < 0.0: - test_beta += np.pi - # Special, if g == 0 exactly, beta is undefined, so just set it to zero. - if test_shear.g == 0.0: - test_beta = beta_rad = 0.0 - - vec = [ - test_shear.g, - test_shear.g1, - test_shear.g2, - test_shear.e, - test_shear.e1, - test_shear.e2, - test_shear.eta, - test_shear.eta1, - test_shear.eta2, - test_shear.esq, - test_shear.q, - beta_rad % np.pi, - ] - test_vec = [ - np.abs(mult_val) * g[index], - mult_val * g1[index], - mult_val * g2[index], - np.abs(mult_val) * e[index], - mult_val * e1[index], - mult_val * e2[index], - np.abs(mult_val) * eta[index], - mult_val * eta1[index], - mult_val * eta2[index], - mult_val * mult_val * e[index] * e[index], - q[index], - test_beta % np.pi, - ] - np.testing.assert_array_almost_equal( - vec, test_vec, decimal=decimal, err_msg="Incorrectly initialized Shear" - ) - if index == n_shear - 1: - # On the last one with values < 1.e-4, multiply everything by 1.e4 and check again. - vec = [1.0e4 * v for v in vec[:-2]] # don't include q or beta now. - test_vec = [1.0e4 * v for v in test_vec[:-2]] - np.testing.assert_array_almost_equal( - vec, test_vec, decimal=decimal, err_msg="Incorrectly initialized Shear" - ) - # Test that the utiltiy function g1g2_to_e1e2 is equivalent to the Shear calculation. - test_e1, test_e2 = galsim.utilities.g1g2_to_e1e2(test_shear.g1, test_shear.g2) - np.testing.assert_almost_equal( - test_e1, test_shear.e1, err_msg="Incorrect e1 calculation" - ) - np.testing.assert_almost_equal( - test_e2, test_shear.e2, err_msg="Incorrect e2 calculation" - ) - - -def add_distortions(d1, d2, d1app, d2app): - # add the distortions - denom = 1.0 + d1 * d1app + d2 * d2app - dapp_sq = d1app**2 + d2app**2 - if dapp_sq == 0: - return d1, d2 - else: - factor = (1.0 - np.sqrt(1.0 - dapp_sq)) * (d2 * d1app - d1 * d2app) / dapp_sq - d1tot = (d1 + d1app + d2app * factor) / denom - d2tot = (d2 + d2app - d1app * factor) / denom - return d1tot, d2tot - - -@timer -def test_shear_initialization(): - """Test that Shears can be initialized in a variety of ways and get the expected results.""" - # first make an empty Shear and make sure that it has zeros in the right places - s = galsim.Shear() - vec = [s.g, s.g1, s.g2, s.e, s.e1, s.e2, s.eta, s.eta1, s.eta2, s.esq] - vec_ideal = np.zeros(len(vec)) - np.testing.assert_array_almost_equal( - vec, vec_ideal, decimal=decimal, err_msg="Incorrectly initialized empty shear" - ) - # JAX specific modification - # ------------------------- - # The line below was: np.testing.assert_equal(s.q, 1.0) - # But because s.q is a jax object and not a float, it doesn't pass the test - np.testing.assert_array_almost_equal(s.q, 1.0) - # now loop over shear values and ways of initializing - for ind in range(n_shear): - # initialize with reduced shear components - s = galsim.Shear(g1=g1[ind], g2=g2[ind]) - all_shear_vals(s, ind) - if g1[ind] == 0.0: - s = galsim.Shear(g2=g2[ind]) - all_shear_vals(s, ind) - if g2[ind] == 0.0: - s = galsim.Shear(g1=g1[ind]) - all_shear_vals(s, ind) - # initialize with distortion components - s = galsim.Shear(e1=e1[ind], e2=e2[ind]) - all_shear_vals(s, ind) - if e1[ind] == 0.0: - s = galsim.Shear(e2=e2[ind]) - all_shear_vals(s, ind) - if e2[ind] == 0.0: - s = galsim.Shear(e1=e1[ind]) - all_shear_vals(s, ind) - # initialize with conformal shear components - s = galsim.Shear(eta1=eta1[ind], eta2=eta2[ind]) - all_shear_vals(s, ind) - if eta1[ind] == 0.0: - s = galsim.Shear(eta2=eta2[ind]) - all_shear_vals(s, ind) - if eta2[ind] == 0.0: - s = galsim.Shear(eta1=eta1[ind]) - all_shear_vals(s, ind) - # initialize with axis ratio and position angle - s = galsim.Shear(q=q[ind], beta=beta[ind] * galsim.radians) - all_shear_vals(s, ind) - # initialize with reduced shear and position angle - s = galsim.Shear(g=g[ind], beta=beta[ind] * galsim.radians) - all_shear_vals(s, ind) - # initialize with distortion and position angle - s = galsim.Shear(e=e[ind], beta=beta[ind] * galsim.radians) - all_shear_vals(s, ind) - # initialize with conformal shear and position angle - s = galsim.Shear(eta=eta[ind], beta=beta[ind] * galsim.radians) - all_shear_vals(s, ind) - # initialize with a complex number g1 + 1j * g2 - s = galsim.Shear(g1[ind] + 1j * g2[ind]) - all_shear_vals(s, ind) - s = galsim._Shear(g1[ind] + 1j * g2[ind]) - all_shear_vals(s, ind) - # which should also be the value of s.shear - s2 = galsim.Shear(s.shear) - all_shear_vals(s2, ind) - - # JAX specific modification - # ------------------------- - # We don't allow jax objects to be pickled. - # check_pickle(s) - - # finally check some examples of invalid initializations for Shear - assert_raises(TypeError, galsim.Shear, 0.3) - assert_raises(TypeError, galsim.Shear, 0.3, 0.3) - assert_raises(TypeError, galsim.Shear, g1=0.3, e2=0.2) - assert_raises(TypeError, galsim.Shear, eta1=0.3, beta=0.0 * galsim.degrees) - assert_raises(TypeError, galsim.Shear, q=0.3) - # JAX specific modification - # ------------------------- - # We do not perform RangeError checks in JAX to preserve jittability. - # assert_raises(galsim.GalSimRangeError, galsim.Shear, q=1.3, beta=0.0 * galsim.degrees) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, g1=0.9, g2=0.6) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, e=-1.3, beta=0.0 * galsim.radians) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, e=1.3, beta=0.0 * galsim.radians) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, e1=0.7, e2=0.9) - assert_raises(TypeError, galsim.Shear, g=0.5) - assert_raises(TypeError, galsim.Shear, e=0.5) - assert_raises(TypeError, galsim.Shear, eta=0.5) - # JAX specific modification - # ------------------------- - # We do not perform RangeError checks in JAX to preserve jittability. - # assert_raises(galsim.GalSimRangeError, galsim.Shear, eta=-0.5, beta=0.0 * galsim.radians) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, g=1.3, beta=0.0 * galsim.radians) - # assert_raises(galsim.GalSimRangeError, galsim.Shear, g=-0.3, beta=0.0 * galsim.radians) - assert_raises(TypeError, galsim.Shear, e=0.3, beta=0.0) - assert_raises(TypeError, galsim.Shear, eta=0.3, beta=0.0) - assert_raises(TypeError, galsim.Shear, randomkwarg=0.1) - assert_raises(TypeError, galsim.Shear, g1=0.1, randomkwarg=0.1) - assert_raises(TypeError, galsim.Shear, g1=0.1, e1=0.1) - assert_raises(TypeError, galsim.Shear, g1=0.1, e=0.1) - assert_raises(TypeError, galsim.Shear, g1=0.1, g=0.1) - assert_raises(TypeError, galsim.Shear, beta=45.0 * galsim.degrees) - assert_raises(TypeError, galsim.Shear, beta=45.0 * galsim.degrees, g=0.3, eta=0.1) - assert_raises(TypeError, galsim.Shear, beta=45.0, g=0.3) - assert_raises(TypeError, galsim.Shear, q=0.1, beta=0.0) - - -@timer -def test_shear_methods(): - """Test that the most commonly-used methods of the Shear class give the expected results.""" - for ind in range(n_shear): - s = galsim.Shear(e1=e1[ind], e2=e2[ind]) - # check negation - s2 = -s - all_shear_vals(s2, ind, mult_val=-1.0) - # check addition - s2 = s + s - exp_e1, exp_e2 = add_distortions(s.e1, s.e2, s.e1, s.e2) - np.testing.assert_array_almost_equal( - [s2.e1, s2.e2], - [exp_e1, exp_e2], - decimal=decimal, - err_msg="Failed to properly add distortions", - ) - # check subtraction - s3 = s - s2 - exp_e1, exp_e2 = add_distortions(s.e1, s.e2, -1.0 * s2.e1, -1.0 * s2.e2) - np.testing.assert_array_almost_equal( - [s3.e1, s3.e2], - [exp_e1, exp_e2], - decimal=decimal, - err_msg="Failed to properly subtract distortions", - ) - # check += - savee1 = s.e1 - savee2 = s.e2 - s += s2 - exp_e1, exp_e2 = add_distortions(savee1, savee2, s2.e1, s2.e2) - np.testing.assert_array_almost_equal( - [s.e1, s.e2], - [exp_e1, exp_e2], - decimal=decimal, - err_msg="Failed to properly += distortions", - ) - # check -= - savee1 = s.e1 - savee2 = s.e2 - s -= s - exp_e1, exp_e2 = add_distortions(savee1, savee2, -1.0 * savee1, -1.0 * savee2) - np.testing.assert_array_almost_equal( - [s.e1, s.e2], - [exp_e1, exp_e2], - decimal=decimal, - err_msg="Failed to properly -= distortions", - ) - - # check == - s = galsim.Shear(g1=g1[ind], g2=g2[ind]) - s2 = galsim.Shear(g1=g1[ind], g2=g2[ind]) - np.testing.assert_equal(s == s2, True, err_msg="Failed to check for equality") - # check != - np.testing.assert_equal(s != s2, False, err_msg="Failed to check for equality") - - -@timer -def test_shear_matrix(): - """Test that the shear matrix is calculated correctly.""" - for ind in range(n_shear): - s1 = galsim.Shear(g1=g1[ind], g2=g2[ind]) - - true_m1 = np.array( - [[1.0 + g1[ind], g2[ind]], [g2[ind], 1.0 - g1[ind]]] - ) / np.sqrt(1.0 - g1[ind] ** 2 - g2[ind] ** 2) - m1 = s1.getMatrix() - - np.testing.assert_array_almost_equal( - m1, true_m1, decimal=12, err_msg="getMatrix returned wrong matrix" - ) - - for ind2 in range(n_shear): - s2 = galsim.Shear(g1=g1[ind2], g2=g2[ind2]) - m2 = s2.getMatrix() - - s3 = s1 + s2 - m3 = s3.getMatrix() - - theta = s1.rotationWith(s2) - r = np.array( - [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] - ) - np.testing.assert_array_almost_equal( - m3.dot(r), - m1.dot(m2), - decimal=12, - err_msg="rotationWith returned wrong angle", - ) - - -if __name__ == "__main__": - testfns = [v for k, v in vars().items() if k[:5] == "test_" and callable(v)] - for testfn in testfns: - testfn() diff --git a/tests/jax/galsim/test_shear_position_jax.py b/tests/jax/galsim/test_shear_position_jax.py deleted file mode 100644 index 1ba07dab..00000000 --- a/tests/jax/galsim/test_shear_position_jax.py +++ /dev/null @@ -1,220 +0,0 @@ -import galsim -import numpy as np -from galsim_test_helpers import assert_raises, timer - - -@timer -def test_shear_position_image_integration_pixelwcs_jax(): - wcs = galsim.PixelScale(0.3) - obj1 = galsim.Gaussian(sigma=3) - obj2 = galsim.Gaussian(sigma=2) - pos2 = galsim.PositionD(3, 5) - sum = obj1 + obj2.shift(pos2) - shear = galsim.Shear(g1=0.1, g2=0.18) - im1 = galsim.Image(50, 50, wcs=wcs) - sum.shear(shear).drawImage(im1, center=im1.center) - - # Equivalent to shear each object separately and drawing at the sheared position. - im2 = galsim.Image(50, 50, wcs=wcs) - obj1.shear(shear).drawImage(im2, center=im2.center) - obj2.shear(shear).drawImage( - im2, - add_to_image=True, - center=im2.center + wcs.toImage(pos2.shear(shear)), - ) - - print("err:", np.max(np.abs(im1.array - im2.array))) - np.testing.assert_allclose(im1.array, im2.array, rtol=0, atol=5e-8) - - -@timer -def test_wrap_jax_simple_real(): - """Test the image.wrap() function.""" - # Start with a fairly simple test where the image is 4 copies of the same data: - im_orig = galsim.Image( - [ - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - ] - ) - im = im_orig.copy() - b = galsim.BoundsI(1, 4, 1, 4) - im_quad = im_orig[b] - im_wrap = im.wrap(b) - np.testing.assert_allclose(im_wrap.array, 4.0 * im_quad.array) - - # The same thing should work no matter where the lower left corner is: - for xmin, ymin in ((1, 5), (5, 1), (5, 5), (2, 3), (4, 1)): - b = galsim.BoundsI(xmin, xmin + 3, ymin, ymin + 3) - im_quad = im_orig[b] - im = im_orig.copy() - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - 4.0 * im_quad.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return the right subimage" % b, - ) - # this test passes even though we do not get a view - im[b].fill(0) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return a view of the original" % b, - ) - - -@timer -def test_wrap_jax_weird_real(): - # Now test where the subimage is not a simple fraction of the original, and all the - # sizes are different. - im = galsim.ImageD(17, 23, xmin=0, ymin=0) - b = galsim.BoundsI(7, 9, 11, 18) - im_test = galsim.ImageD(b, init_value=0) - for i in range(17): - for j in range(23): - val = np.exp(i / 7.3) + (j / 12.9) ** 3 # Something randomly complicated... - im[i, j] = val - # Find the location in the sub-image for this point. - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - -@timer -def test_wrap_jax_complex(): - # For complex images (in particular k-space images), we often want the image to be implicitly - # Hermitian, so we only need to keep around half of it. - M = 38 - N = 25 - K = 8 - L = 5 - im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian - im2 = galsim.ImageCD( - 2 * M + 1, N + 1, xmin=-M, ymin=0 - ) # Implicitly Hermitian across y axis - im3 = galsim.ImageCD( - M + 1, 2 * N + 1, xmin=0, ymin=-N - ) # Implicitly Hermitian across x axis - # print('im = ',im) - # print('im2 = ',im2) - # print('im3 = ',im3) - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - im_test = galsim.ImageCD(b, init_value=0) - for i in range(-M, M + 1): - for j in range(-N, N + 1): - # An arbitrary, complicated Hermitian function. - val = ( - np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) - + ((2 + 3j * j) / (1.9 * N)) ** 3 - ) - # val = 2*(i-j)**2 + 3j*(i+j) - - im[i, j] = val - if j >= 0: - im2[i, j] = val - if i >= 0: - im3[i, j] = val - - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - # print("im = ",im.array) - - # Confirm that the image is Hermitian. - for i in range(-M, M + 1): - for j in range(-N, N + 1): - assert im(i, j) == im(-i, -j).conjugate() - - im_wrap = im.wrap(b) - # print("im_wrap = ",im_wrap.array) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - # FIXME: turn on when wrapping works for hermitian images - if False: - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_allclose( - im2_wrap.array, - im_test[b2].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) - - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_allclose( - im3_wrap.array, - im_test[b3].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - - # FIXME: turn on when wrapping works for hermitian images - if False: - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py deleted file mode 100644 index 0093ad61..00000000 --- a/tests/jax/galsim/test_wcs_jax.py +++ /dev/null @@ -1,4090 +0,0 @@ -from __future__ import print_function - -import os -import sys -import time -import warnings - -import numpy as np -from galsim_test_helpers import ( - Profile, - assert_raises, - assert_warns, - check_pickle, - gsobject_compare, - timer, -) - -import jax_galsim as galsim - -# These positions will be used a few times below, so define them here. -# One of the tests requires that the last pair are integers, so don't change that. -near_x_list = [0, 0.242, -1.342, -5] -near_y_list = [0, -0.173, 2.003, -7] - -far_x_list = [10, -31.7, -183.6, -700] -far_y_list = [10, 12.5, 103.3, 500] - -# Make a few different profiles to check. Make sure to include ones that -# aren't symmetrical so we don't get fooled by symmetries. -prof1 = galsim.Gaussian(sigma=1.7, flux=100) -prof2 = prof1.shear(g1=0.3, g2=-0.12) -prof3 = prof2 + galsim.Exponential(scale_radius=1.3, flux=20).shift(-0.1, -0.4) -profiles = [prof1, prof2, prof3] - -if __name__ != "__main__": - # Some of the classes we test here are not terribly fast. WcsToolsWCS in particular. - # So reduce the number of tests. Keep the hardest ones, since the easier ones are mostly - # useful as diagnostics when there are problems. So they will get run when doing - # python test_wcs.py. But not during a pytest run. - near_x_list = near_x_list[-2:] - near_y_list = near_y_list[-2:] - far_x_list = far_x_list[-2:] - far_y_list = far_y_list[-2:] - profiles = [prof3] - -all_x_list = near_x_list + far_x_list -all_y_list = near_y_list + far_y_list - -# How many digits of accuracy should we demand? -# We test everything in units of arcsec, so this corresponds to 1.e-3 arcsec. -# 1 mas should be plenty accurate for our purposes. (And note that most classes do much -# better than this. Just a few things that require iterative solutions for the world->image -# transformation or things that output to a fits file at slightly less than full precision -# do worse than 6 digits.) -digits = 3 - -# The HPX, TAN, TSC, STG, ZEA, and ZPN files were downloaded from the web site: -# -# http://www.atnf.csiro.au/people/mcalabre/WCS/example_data.html -# -# From that page: "The example maps and spectra are offered mainly as test material for -# software that deals with FITS WCS." -# -# I picked the ones that GSFitsWCS can do plus a couple others that struck me as interstingly -# different, but there are a bunch more on that web page as well. In particular, I included ZPN, -# since that uses PV values, which the others don't, and HPX, since it is not implemented in -# wcstools. -# -# The SIP, TPV, ZPX, REGION, and TNX are either new or "non-standard" that are not implemented -# by (at least older versions of) wcslib. They were downloaded from the web site: -# -# http://fits.gsfc.nasa.gov/fits_registry.html -# -# For each file, I use ds9 to pick out two reference points. I generally try to pick two -# points on opposite sides of the image so any non-linearities in the WCS are maximized. -# For most of them, I then use wcstools to get the ra and dec to 6 digits of accuracy. -# (Unfortunately, ds9's ra and dec information is only accurate to 2 or 3 digits.) -# The exception is HPX, for which I used the PyAst library to compute accurate values, -# since wcstools can't understand it. - -references = { - # Note: the four 1904-66 files use the brightest pixels in the same two stars. - # The ra, dec are thus essentially the same (modulo the large pixel size of 3 arcmin). - # However, the image positions are quite different. - "HPX": ( - "1904-66_HPX.fits", - [ - ("19:39:16.551671", "-63:42:47.346862", 114, 180, 13.59960), - ("18:19:35.761589", "-63:46:08.860203", 144, 30, 11.49591), - ], - ), - "TAN": ( - "1904-66_TAN.fits", - [ - ("19:39:30.753119", "-63:42:59.217527", 117, 178, 13.43628), - ("18:19:18.652839", "-63:49:03.833411", 153, 35, 11.44438), - ], - ), - "TSC": ( - "1904-66_TSC.fits", - [ - ("19:39:39.996553", "-63:41:14.585586", 113, 161, 12.48409), - ("18:19:05.985494", "-63:49:05.781036", 141, 48, 11.65945), - ], - ), - "STG": ( - "1904-66_STG.fits", - [ - ("19:39:14.752140", "-63:44:20.882465", 112, 172, 13.1618), - ("18:19:37.824461", "-63:46:24.483497", 147, 38, 11.6091), - ], - ), - "ZEA": ( - "1904-66_ZEA.fits", - [ - ("19:39:26.871566", "-63:43:26.059526", 110, 170, 13.253), - ("18:19:34.480902", "-63:46:40.038427", 144, 39, 11.62), - ], - ), - "ARC": ( - "1904-66_ARC.fits", - [ - ("19:39:28.622018", "-63:41:53.658982", 111, 171, 13.7654), - ("18:19:47.020701", "-63:46:22.381334", 145, 39, 11.2099), - ], - ), - "ZPN": ( - "1904-66_ZPN.fits", - [ - ("19:39:24.948254", "-63:46:43.636138", 95, 151, 12.84769), - ("18:19:24.149409", "-63:49:37.453404", 122, 48, 11.01434), - ], - ), - "SIP": ( - "sipsample.fits", - [ - ("13:30:01.474154", "47:12:51.794474", 242, 75, 12.24437), - ("13:29:43.747626", "47:09:13.879660", 12, 106, 5.30282), - ], - ), - "TPV": ( - "tpv.fits", - [ - ("03:30:09.340034", "-28:43:50.811107", 418, 78, 2859.53882), - ("03:30:15.728999", "-28:45:01.488629", 148, 393, 2957.98584), - ], - ), - # Strangely, zpx.fits is the same image as tpv.fits, but the WCS-computed RA, Dec - # values are not anywhere close to TELRA, TELDEC in the header. It's a bit - # unfortunate, since my understanding is that ZPX can encode the same function as - # TPV, so they could have produced the equivalent function. But instead they just - # inserted some totally off-the-wall different WCS transformation. - "ZPX": ( - "zpx.fits", - [ - ("21:24:12.094326", "37:10:34.575917", 418, 78, 2859.53882), - ("21:24:05.350816", "37:11:44.596579", 148, 393, 2957.98584), - ], - ), - # Older versions of the new TPV standard just used the TAN wcs name and expected - # the code to notice the PV values and use them correctly. This did not become a - # FITS standard (or even a registered non-standard), but some old FITS files use - # this, so we want to support it. I just edited the tpv.fits to change the - # CTYPE values from TPV to TAN. - "TAN-PV": ( - "tanpv.fits", - [ - ("03:30:09.340034", "-28:43:50.811107", 418, 78, 2859.53882), - ("03:30:15.728999", "-28:45:01.488629", 148, 393, 2957.98584), - ], - ), - # It is apparently valid FITS format to have Dec as the first axis and RA as the second. - # This is in fact the output of PyAst when writing the file tpv.fits in FITS encoding. - # It seems worth testing that all the WCS types get this input correct. - "TAN-FLIP": ( - "tanflip.fits", - [ - ("03:30:09.262392", "-28:43:48.697347", 418, 78, 2859.53882), - ("03:30:15.718834", "-28:44:59.073468", 148, 393, 2957.98584), - ], - ), - "REGION": ( - "region.fits", - [ - ("14:02:11.202432", "54:30:07.702200", 80, 80, 2241), - ("14:04:17.341523", "54:16:28.554326", 45, 54, 1227), - ], - ), - # Strangely, ds9 seems to get this one wrong. It differs by about 6 arcsec in dec. - # But PyAst and wcstools agree on these values, so I'm taking them to be accurate. - "TNX": ( - "tnx.fits", - [ - ("17:46:53.214511", "-30:08:47.895372", 32, 91, 7140), - ("17:46:58.100741", "-30:07:50.121787", 246, 326, 15022), - ], - ), - # Zwicky Transient Facility uses TPV with higher order polynomial than what is in tpv.fits. - "ZTF": ( - "ztf_20180525484722_000600_zg_c05_o_q1_sciimg.fits", - [ - ("0:18:42.475844", "+25:45:00.858971", 1040, 1029, 59344.3), - ("0:17:03.568150", "+25:32:40.484235", 2340, 1796, 2171.35), - ], - ), -} -all_tags = references.keys() - - -def do_wcs_pos(wcs, ufunc, vfunc, name, x0=0, y0=0, color=None): - # I would call this do_wcs_pos_tests, but pytest takes any function with test - # _anywhere_ in the name an tries to run it. So make sure the name doesn't - # have 'test' in it. There are a bunch of other do* functions that work similarly. - - # Check that (x,y) -> (u,v) and converse work correctly - if "local" in name or "jacobian" in name or "affine" in name: - # If the "local" is really a non-local WCS which has been localized, then we cannot - # count on the far positions to be sufficiently accurate. Just use near positions. - x_list = near_x_list - y_list = near_y_list - # And even then, it sometimes fails at our normal 3 digits because of the 2nd derivative - # coming into play. - digits2 = 1 - else: - x_list = all_x_list - y_list = all_y_list - digits2 = digits - u_list = [ufunc(x + x0, y + y0) for x, y in zip(x_list, y_list)] - v_list = [vfunc(x + x0, y + y0) for x, y in zip(x_list, y_list)] - - for x, y, u, v in zip(x_list, y_list, u_list, v_list): - image_pos = galsim.PositionD(x + x0, y + y0) - world_pos = galsim.PositionD(u, v) - world_pos2 = wcs.toWorld(image_pos, color=color) - world_pos3 = wcs.posToWorld(image_pos, color=color) - np.testing.assert_almost_equal( - world_pos.x, - world_pos2.x, - digits2, - "wcs.toWorld returned wrong world position for " + name, - ) - np.testing.assert_almost_equal( - world_pos.y, - world_pos2.y, - digits2, - "wcs.toWorld returned wrong world position for " + name, - ) - np.testing.assert_almost_equal( - world_pos.x, - world_pos3.x, - digits2, - "wcs.posToWorld returned wrong world position for " + name, - ) - np.testing.assert_almost_equal( - world_pos.y, - world_pos3.y, - digits2, - "wcs.posToWorld returned wrong world position for " + name, - ) - - u1, v1 = wcs.toWorld(x + x0, y + y0, color=color) - u2, v2 = wcs.xyTouv(x + x0, y + y0, color=color) - np.testing.assert_almost_equal( - u1, u, digits2, "wcs.toWorld(x,y) returned wrong u position for " + name - ) - np.testing.assert_almost_equal( - v1, v, digits2, "wcs.toWorld(x,y) returned wrong v position for " + name - ) - np.testing.assert_almost_equal( - u2, u, digits2, "wcs.xyTouv(x,y) returned wrong u position for " + name - ) - np.testing.assert_almost_equal( - v2, v, digits2, "wcs.xyTouv(x,y) returned wrong v position for " + name - ) - - scale = wcs.maxLinearScale(image_pos, color=color) - try: - # The reverse transformation is not guaranteed to be implemented, - # so guard against NotImplementedError being raised: - image_pos2 = wcs.toImage(world_pos, color=color) - image_pos3 = wcs.posToImage(world_pos, color=color) - test_reverse = True - except NotImplementedError: - assert_raises( - NotImplementedError, wcs._x, world_pos.x, world_pos.y, color=color - ) - assert_raises( - NotImplementedError, wcs._y, world_pos.x, world_pos.y, color=color - ) - test_reverse = False - else: - np.testing.assert_almost_equal( - image_pos.x * scale, - image_pos2.x * scale, - digits2, - "wcs.toImage returned wrong image position for " + name, - ) - np.testing.assert_almost_equal( - image_pos.y * scale, - image_pos2.y * scale, - digits2, - "wcs.toImage returned wrong image position for " + name, - ) - np.testing.assert_almost_equal( - image_pos.x * scale, - image_pos3.x * scale, - digits2, - "wcs.posToImage returned wrong image position for " + name, - ) - np.testing.assert_almost_equal( - image_pos.y * scale, - image_pos3.y * scale, - digits2, - "wcs.posToImage returned wrong image position for " + name, - ) - - x1, y1 = wcs.toImage(u, v, color=color) - x2, y2 = wcs.uvToxy(u, v, color=color) - np.testing.assert_almost_equal( - x1, - x + x0, - digits2, - "wcs.toImage(u,v) returned wrong x position for " + name, - ) - np.testing.assert_almost_equal( - y1, - y + y0, - digits2, - "wcs.toImage(u,v) returned wrong y position for " + name, - ) - np.testing.assert_almost_equal( - x2, - x + x0, - digits2, - "wcs.uvToxy(u,v) returned wrong x position for " + name, - ) - np.testing.assert_almost_equal( - y2, - y + y0, - digits2, - "wcs.uvToxy(u,v) returned wrong y position for " + name, - ) - - # Test xyTouv with arrays - u3, v3 = wcs.toWorld(np.array(x_list) + x0, np.array(y_list) + y0, color=color) - u4, v4 = wcs.xyTouv(np.array(x_list) + x0, np.array(y_list) + y0, color=color) - np.testing.assert_almost_equal( - u3, - u_list, - digits2, - "wcs.toWorld(x,y) with arrays returned wrong u positions for " + name, - ) - np.testing.assert_almost_equal( - v3, - v_list, - digits2, - "wcs.toWorld(x,y) with arrays returned wrong v positions for " + name, - ) - np.testing.assert_almost_equal( - u4, - u_list, - digits2, - "wcs.xyTouv(x,y) with arrays returned wrong u positions for " + name, - ) - np.testing.assert_almost_equal( - v4, - v_list, - digits2, - "wcs.xyTouv(x,y) with arrays returned wrong v positions for " + name, - ) - - if test_reverse: - # Test uvToxy with arrays - x3, y3 = wcs.toImage(np.array(u_list), np.array(v_list), color=color) - x4, y4 = wcs.uvToxy(np.array(u_list), np.array(v_list), color=color) - np.testing.assert_almost_equal( - x3 - x0, - x_list, - digits2, - "wcs.toImage(u,v) with arrays returned wrong x positions for " + name, - ) - np.testing.assert_almost_equal( - y3 - y0, - y_list, - digits2, - "wcs.toImage(u,v) with arrays returned wrong y positions for " + name, - ) - np.testing.assert_almost_equal( - x4 - x0, - x_list, - digits2, - "wcs.uvToxy(u,v) with arrays returned wrong x positions for " + name, - ) - np.testing.assert_almost_equal( - y4 - y0, - y_list, - digits2, - "wcs.uvToxy(u,v) with arrays returned wrong y positions for " + name, - ) - - if x0 == 0 and y0 == 0: - # The last item in list should also work as a PositionI - image_pos = galsim.PositionI(x, y) - np.testing.assert_almost_equal( - world_pos.x, - wcs.toWorld(image_pos, color=color).x, - digits2, - "wcs.toWorld gave different value with PositionI image_pos for " + name, - ) - np.testing.assert_almost_equal( - world_pos.y, - wcs.toWorld(image_pos, color=color).y, - digits2, - "wcs.toWorld gave different value with PositionI image_pos for " + name, - ) - - # Note that this function is only called for EuclideanWCS types. There is a different - # set of tests relevant for CelestialWCS done in do_celestial_wcs(). - assert not wcs.isCelestial() - assert_raises(TypeError, wcs.toWorld) - assert_raises(TypeError, wcs.toWorld, (3, 4)) - assert_raises(TypeError, wcs.toWorld, 3, 4, units=galsim.degrees) - assert_raises(TypeError, wcs.toWorld, 3, 4, 5) - assert_raises( - TypeError, - wcs.toWorld, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - assert_raises(TypeError, wcs.posToWorld) - assert_raises(TypeError, wcs.posToWorld, (3, 4)) - assert_raises(TypeError, wcs.posToWorld, 3, 4) - assert_raises(TypeError, wcs.posToWorld, 3, 4, 5) - assert_raises( - TypeError, - wcs.posToWorld, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - assert_raises(TypeError, wcs.xyTouv) - assert_raises(TypeError, wcs.xyTouv, 3) - assert_raises(TypeError, wcs.xyTouv, 3, 4, units=galsim.degrees) - assert_raises(TypeError, wcs.xyTouv, galsim.PositionD(3, 4)) - - assert_raises(TypeError, wcs.toImage) - assert_raises(TypeError, wcs.toImage, (3, 4)) - assert_raises(TypeError, wcs.toImage, 3, 4, units=galsim.degrees) - assert_raises( - TypeError, - wcs.toImage, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises(TypeError, wcs.toImage, 3, 4, 5) - - assert_raises(TypeError, wcs.posToImage) - assert_raises(TypeError, wcs.posToImage, (3, 4)) - assert_raises( - TypeError, - wcs.posToImage, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises(TypeError, wcs.posToImage, 3, 4) - assert_raises(TypeError, wcs.posToImage, 3, 4, 5) - - assert_raises(TypeError, wcs.uvToxy, 3, 4, units=galsim.degrees) - assert_raises(TypeError, wcs.uvToxy) - assert_raises(TypeError, wcs.uvToxy, 3) - assert_raises(TypeError, wcs.uvToxy, world_pos) - - -def check_world(pos1, pos2, digits, err_msg): - if isinstance(pos1, galsim.CelestialCoord): - np.testing.assert_almost_equal( - pos1.distanceTo(pos2) / galsim.arcsec, 0, digits, err_msg - ) - else: - np.testing.assert_almost_equal(pos1.x, pos2.x, digits, err_msg) - np.testing.assert_almost_equal(pos1.y, pos2.y, digits, err_msg) - - -def do_wcs_image(wcs, name, approx=False): - print("Start image tests for WCS " + name) - - # Use the "blank" image as our test image. It's not blank in the sense of having all - # zeros. Rather, there are basically random values that we can use to test that - # the shifted values are correct. And it is a conveniently small-ish, non-square image. - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" - ) - file_name = "blankimg.fits" - im = galsim.fits.read(file_name, dir=dir) - np.testing.assert_equal(im.origin.x, 1, "initial origin is not 1,1 as expected") - np.testing.assert_equal(im.origin.y, 1, "initial origin is not 1,1 as expected") - im.wcs = wcs - world1 = im.wcs.toWorld(im.origin) - value1 = im(im.origin) - world2 = im.wcs.toWorld(im.center) - value2 = im(im.center) - offset = galsim.PositionI(11, 13) - image_pos = im.origin + offset - world3 = im.wcs.toWorld(image_pos) - value3 = im(image_pos) - - # Test writing the image to a fits file and reading it back in. - # The new image doesn't have to have the same wcs type. But it does have to produce - # consistent values of the world coordinates. - test_name = "test_" + name + ".fits" - im.write(test_name, dir=dir) - im2 = galsim.fits.read(test_name, dir=dir) - if approx: - # Sometimes, the round trip doesn't preserve accuracy completely. - # In these cases, only test the positions after write/read to 1 digit. - digits2 = 1 - else: - digits2 = digits - np.testing.assert_equal( - im2.origin.x, im.origin.x, "origin changed after write/read" - ) - np.testing.assert_equal( - im2.origin.y, im.origin.y, "origin changed after write/read" - ) - check_world( - im2.wcs.toWorld(im.origin), - world1, - digits2, - "World position of origin is wrong after write/read.", - ) - np.testing.assert_almost_equal( - im2(im.origin), - value1, - digits, - "Image value at origin is wrong after write/read.", - ) - check_world( - im2.wcs.toWorld(im.center), - world2, - digits2, - "World position of center is wrong after write/read.", - ) - np.testing.assert_almost_equal( - im2(im.center), - value2, - digits, - "Image value at center is wrong after write/read.", - ) - check_world( - im2.wcs.toWorld(image_pos), - world3, - digits2, - "World position of image_pos is wrong after write/read.", - ) - np.testing.assert_almost_equal( - im2(image_pos), - value3, - digits, - "Image value at center is wrong after write/read.", - ) - - if wcs.isUniform(): - # Test that the regular CD, CRPIX, CRVAL items that are written to the header - # describe an equivalent WCS as this one. - affine = galsim.FitsWCS(test_name, dir=dir) - check_world( - affine.toWorld(im.origin), - world1, - digits2, - "World position of origin is wrong after write/read.", - ) - check_world( - affine.toWorld(im.center), - world2, - digits2, - "World position of center is wrong after write/read.", - ) - check_world( - affine.toWorld(image_pos), - world3, - digits2, - "World position of image_pos is wrong after write/read.", - ) - - # Test that im.shift does the right thing to the wcs - # Also test parsing a position as x,y args. - dx = 3 - dy = 9 - im.shift(dx, dy) - image_pos = im.origin + offset - np.testing.assert_equal(im.origin.x, 1 + dx, "shift set origin to wrong value") - np.testing.assert_equal(im.origin.y, 1 + dy, "shift set origin to wrong value") - check_world( - im.wcs.toWorld(im.origin), - world1, - digits, - "World position of origin after shift is wrong.", - ) - np.testing.assert_almost_equal( - im(im.origin), value1, digits, "Image value at origin after shift is wrong." - ) - check_world( - im.wcs.toWorld(im.center), - world2, - digits, - "World position of center after shift is wrong.", - ) - np.testing.assert_almost_equal( - im(im.center), value2, digits, "Image value at center after shift is wrong." - ) - check_world( - im.wcs.toWorld(image_pos), - world3, - digits, - "World position of image_pos after shift is wrong.", - ) - np.testing.assert_almost_equal( - im(image_pos), value3, digits, "image value at center after shift is wrong." - ) - - # Test that im.setOrigin does the right thing to the wcs - # Also test parsing a position as a tuple. - new_origin = (-3432, 1907) - im.setOrigin(new_origin) - image_pos = im.origin + offset - np.testing.assert_equal( - im.origin.x, new_origin[0], "setOrigin set origin to wrong value" - ) - np.testing.assert_equal( - im.origin.y, new_origin[1], "setOrigin set origin to wrong value" - ) - check_world( - im.wcs.toWorld(im.origin), - world1, - digits, - "World position of origin after setOrigin is wrong.", - ) - np.testing.assert_almost_equal( - im(im.origin), value1, digits, "Image value at origin after setOrigin is wrong." - ) - check_world( - im.wcs.toWorld(im.center), - world2, - digits, - "World position of center after setOrigin is wrong.", - ) - np.testing.assert_almost_equal( - im(im.center), value2, digits, "Image value at center after setOrigin is wrong." - ) - check_world( - im.wcs.toWorld(image_pos), - world3, - digits, - "World position of image_pos after setOrigin is wrong.", - ) - np.testing.assert_almost_equal( - im(image_pos), value3, digits, "Image value at center after setOrigin is wrong." - ) - - # Test that im.setCenter does the right thing to the wcs. - # Also test parsing a position as a PositionI object. - new_center = galsim.PositionI(0, 0) - im.setCenter(new_center) - image_pos = im.origin + offset - np.testing.assert_equal( - im.center.x, new_center.x, "setCenter set center to wrong value" - ) - np.testing.assert_equal( - im.center.y, new_center.y, "setCenter set center to wrong value" - ) - check_world( - im.wcs.toWorld(im.origin), - world1, - digits, - "World position of origin after setCenter is wrong.", - ) - np.testing.assert_almost_equal( - im(im.origin), value1, digits, "Image value at origin after setCenter is wrong." - ) - check_world( - im.wcs.toWorld(im.center), - world2, - digits, - "World position of center after setCenter is wrong.", - ) - np.testing.assert_almost_equal( - im(im.center), value2, digits, "Image value at center after setCenter is wrong." - ) - check_world( - im.wcs.toWorld(image_pos), - world3, - digits, - "World position of image_pos after setCenter is wrong.", - ) - np.testing.assert_almost_equal( - im(image_pos), value3, digits, "Image value at center after setCenter is wrong." - ) - - # Test makeSkyImage - if __name__ != "__main__": - # Use a smaller image to speed things up. - im = im[galsim.BoundsI(im.xmin, im.xmin + 5, im.ymin, im.ymin + 5)] - new_origin = (-134, 128) - im.setOrigin(new_origin) - sky_level = 177 - wcs.makeSkyImage(im, sky_level) - for x, y in [ - (im.bounds.xmin, im.bounds.ymin), - (im.bounds.xmax, im.bounds.ymin), - (im.bounds.xmin, im.bounds.ymax), - (im.bounds.xmax, im.bounds.ymax), - (im.center.x, im.center.y), - ]: - val = im(x, y) - area = wcs.pixelArea(galsim.PositionD(x, y)) - np.testing.assert_almost_equal( - val / (area * sky_level), 1.0, digits, "SkyImage at %d,%d is wrong" % (x, y) - ) - # Check that all values are near the same value. In particular if ra crosses 0, then this - # used to be a problem if the wcs returned ra values that jump from 360 to 0. - # Not very stringent test, since we're just checking that we don't have some pixels - # that are orders of magnitude different from the average. So rtol=2 is good. - # (1 is fine for the nosetests runs, but one of the edge cases fails in main runs. It's not - # a problem I think, just one of the weird wcs types has a more variable pixel area.) - np.testing.assert_allclose(im.array, area * sky_level, rtol=2) - - -def do_local_wcs(wcs, ufunc, vfunc, name): - print("Start testing local WCS " + name) - - # Check that local and withOrigin work correctly: - wcs2 = wcs.local() - assert wcs == wcs2, name + " local() is not == the original" - new_origin = galsim.PositionI(123, 321) - wcs3 = wcs.withOrigin(new_origin) - assert wcs != wcs3, name + " is not != wcs.withOrigin(pos)" - assert wcs3 != wcs, name + " is not != wcs.withOrigin(pos) (reverse)" - wcs2 = wcs3.local() - assert wcs == wcs2, name + " is not equal after wcs.withOrigin(pos).local()" - world_pos1 = wcs.toWorld(galsim.PositionD(0, 0)) - world_pos2 = wcs3.toWorld(new_origin) - np.testing.assert_almost_equal( - world_pos2.x, - world_pos1.x, - digits, - "withOrigin(new_origin) returned wrong world position", - ) - np.testing.assert_almost_equal( - world_pos2.y, - world_pos1.y, - digits, - "withOrigin(new_origin) returned wrong world position", - ) - new_world_origin = galsim.PositionD(5352.7, 9234.3) - wcs4 = wcs.withOrigin(new_origin, new_world_origin) - world_pos3 = wcs4.toWorld(new_origin) - np.testing.assert_almost_equal( - world_pos3.x, - new_world_origin.x, - digits, - "withOrigin(new_origin, new_world_origin) returned wrong position", - ) - np.testing.assert_almost_equal( - world_pos3.y, - new_world_origin.y, - digits, - "withOrigin(new_origin, new_world_origin) returned wrong position", - ) - wcs5 = wcs.shiftOrigin(new_origin, new_world_origin) - assert wcs4 == wcs5 # For LocalWCS, shiftOrigin is equivalent to withOrigin - - # Check inverse: - image_pos = wcs.inverse().toWorld(world_pos1) - np.testing.assert_almost_equal( - image_pos.x, - 0, - digits, - "wcs.inverse().toWorld(world_pos) returned wrong image position", - ) - np.testing.assert_almost_equal( - image_pos.y, - 0, - digits, - "wcs.inverse().toWorld(world_pos) returned wrong image position", - ) - image_pos = wcs4.toImage(new_world_origin) - np.testing.assert_almost_equal( - image_pos.x, - new_origin.x, - digits, - "wcs4.toImage(new_world_origin) returned wrong image position", - ) - np.testing.assert_almost_equal( - image_pos.y, - new_origin.y, - digits, - "wcs4.toImage(new_world_origin) returned wrong image position", - ) - image_pos = wcs4.inverse().toWorld(new_world_origin) - np.testing.assert_almost_equal( - image_pos.x, - new_origin.x, - digits, - "wcs4.inverse().toWorld(new_world_origin) returned wrong image position", - ) - np.testing.assert_almost_equal( - image_pos.y, - new_origin.y, - digits, - "wcs4.inverse().toWorld(new_world_origin) returned wrong image position", - ) - - # Check that (x,y) -> (u,v) and converse work correctly - do_wcs_pos(wcs, ufunc, vfunc, name) - - # Check picklability - check_pickle(wcs) - - # Test the transformation of a GSObject - # These only work for local WCS projections! - - near_u_list = [ufunc(x, y) for x, y in zip(near_x_list, near_y_list)] - near_v_list = [vfunc(x, y) for x, y in zip(near_x_list, near_y_list)] - - im1 = galsim.Image(64, 64, wcs=wcs) - im2 = galsim.Image(64, 64, scale=1.0) - - for world_profile in profiles: - # The profiles build above are in world coordinates (as usual) - - # Convert to image coordinates - image_profile = wcs.toImage(world_profile) - - assert wcs.profileToImage(world_profile) == image_profile - - # Also check round trip (starting with either one) - world_profile2 = wcs.toWorld(image_profile) - image_profile2 = wcs.toImage(world_profile2) - - assert wcs.profileToWorld(image_profile) == world_profile2 - - for x, y, u, v in zip(near_x_list, near_y_list, near_u_list, near_v_list): - image_pos = galsim.PositionD(x, y) - world_pos = galsim.PositionD(u, v) - pixel_area = wcs.pixelArea(image_pos=image_pos) - - np.testing.assert_almost_equal( - image_profile.xValue(image_pos) / pixel_area, - world_profile.xValue(world_pos), - digits, - "xValue for image_profile and world_profile differ for " + name, - ) - np.testing.assert_almost_equal( - image_profile.xValue(image_pos), - image_profile2.xValue(image_pos), - digits, - "image_profile not equivalent after round trip through world for " - + name, - ) - np.testing.assert_almost_equal( - world_profile.xValue(world_pos), - world_profile2.xValue(world_pos), - digits, - "world_profile not equivalent after round trip through image for " - + name, - ) - - # The last item in list should also work as a PositionI - image_pos = galsim.PositionI(x, y) - np.testing.assert_almost_equal( - pixel_area, - wcs.pixelArea(image_pos=image_pos), - digits, - "pixelArea gave different result for PositionI image_pos for " + name, - ) - np.testing.assert_almost_equal( - image_profile.xValue(image_pos) / pixel_area, - world_profile.xValue(world_pos), - digits, - "xValue for image_profile gave different result for PositionI for " + name, - ) - np.testing.assert_almost_equal( - image_profile.xValue(image_pos), - image_profile2.xValue(image_pos), - digits, - "round trip xValue gave different result for PositionI for " + name, - ) - - # Test drawing the profile on an image with the given wcs - world_profile.drawImage(im1, method="no_pixel") - image_profile.drawImage(im2, method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile and image_profile were different when drawn for " + name, - ) - - # Test drawing the profile at a different center - world_profile.drawImage(im1, method="no_pixel", center=(30.9, 34.1)) - image_profile.drawImage(im2, method="no_pixel", center=(30.9, 34.1)) - np.testing.assert_array_almost_equal(im1.array, im2.array, digits) - - assert_raises(TypeError, wcs.withOrigin) - assert_raises(TypeError, wcs.withOrigin, origin=(3, 4), color=0.3) - assert_raises( - TypeError, wcs.withOrigin, origin=image_pos, world_origin=(3, 4), color=0.3 - ) - - -def do_jac_decomp(wcs, name): - scale, shear, theta, flip = wcs.getDecomposition() - - # First see if we can recreate the right matrix from this: - S = np.array([[1.0 + shear.g1, shear.g2], [shear.g2, 1.0 - shear.g1]]) / np.sqrt( - 1.0 - shear.g1**2 - shear.g2**2 - ) - R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - if flip: - F = np.array([[0, 1], [1, 0]]) - else: - F = np.array([[1, 0], [0, 1]]) - - M = scale * S.dot(R).dot(F) - J = wcs.getMatrix() - np.testing.assert_array_almost_equal( - M, J, 8, "Decomposition was inconsistent with jacobian for " + name - ) - - # The minLinearScale is scale * (1-g) / sqrt(1-g^2) - import math - - g = shear.g - min_scale = scale * (1.0 - g) / math.sqrt(1.0 - g**2) - np.testing.assert_almost_equal(wcs.minLinearScale(), min_scale, 6, "minLinearScale") - # The maxLinearScale is scale * (1+g) / sqrt(1-g^2) - max_scale = scale * (1.0 + g) / math.sqrt(1.0 - g**2) - np.testing.assert_almost_equal(wcs.maxLinearScale(), max_scale, 6, "minLinearScale") - - # There are some relations between the decomposition and the inverse decomposition that should - # be true: - scale2, shear2, theta2, flip2 = wcs.inverse().getDecomposition() - np.testing.assert_equal(flip, flip2, "inverse flip") - np.testing.assert_almost_equal(scale, 1.0 / scale2, 6, "inverse scale") - if flip: - np.testing.assert_almost_equal(theta.rad, theta2.rad, 6, "inverse theta") - else: - np.testing.assert_almost_equal(theta.rad, -theta2.rad, 6, "inverse theta") - np.testing.assert_almost_equal(shear.g, shear2.g, 6, "inverse shear") - # There is no simple relation between the directions of the shear in the two cases. - # The shear direction gets mixed up by the rotation if that is non-zero. - - # Also check that the profile is transformed equivalently as advertised in the docstring - # for getDecomposition. - base_obj = galsim.Gaussian(sigma=2) - # Make sure it doesn't have any initial symmetry! - base_obj = base_obj.shear(g1=0.1, g2=0.23).shift(0.17, -0.37) - - obj1 = base_obj.transform(wcs.dudx, wcs.dudy, wcs.dvdx, wcs.dvdy) - - if flip: - obj2 = base_obj.transform(0, 1, 1, 0) - else: - obj2 = base_obj - obj2 = obj2.rotate(theta).shear(shear).expand(scale) - - gsobject_compare(obj1, obj2) - - -def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): - print("Start testing non-local WCS " + name) - - # Check that shiftOrigin and local work correctly: - new_origin = galsim.PositionI(123, 321) - wcs3 = wcs.shiftOrigin(new_origin) - assert wcs != wcs3, name + " is not != wcs.shiftOrigin(pos)" - wcs4 = wcs.local(wcs.origin, color=color) - assert wcs != wcs4, name + " is not != wcs.local()" - assert wcs4 != wcs, name + " is not != wcs.local() (reverse)" - if wcs.isUniform(): - if wcs.world_origin == galsim.PositionD(0, 0): - wcs2 = wcs.local(wcs.origin, color=color).withOrigin(wcs.origin) - assert wcs == wcs2, ( - name + " is not equal after wcs.local().withOrigin(origin)" - ) - wcs2 = wcs.local(wcs.origin, color=color).withOrigin( - wcs.origin, wcs.world_origin - ) - assert wcs == wcs2, ( - name + " not equal after wcs.local().withOrigin(origin,world_origin)" - ) - world_pos1 = wcs.toWorld(galsim.PositionD(0, 0), color=color) - wcs3 = wcs.shiftOrigin(new_origin) - world_pos2 = wcs3.toWorld(new_origin, color=color) - np.testing.assert_almost_equal( - world_pos2.x, - world_pos1.x, - digits, - "shiftOrigin(new_origin) returned wrong world position", - ) - np.testing.assert_almost_equal( - world_pos2.y, - world_pos1.y, - digits, - "shiftOrigin(new_origin) returned wrong world position", - ) - new_world_origin = galsim.PositionD(5352.7, 9234.3) - wcs5 = wcs.shiftOrigin(new_origin, new_world_origin, color=color) - world_pos3 = wcs5.toWorld(new_origin, color=color) - np.testing.assert_almost_equal( - world_pos3.x, - new_world_origin.x, - digits, - "shiftOrigin(new_origin, new_world_origin) returned wrong position", - ) - np.testing.assert_almost_equal( - world_pos3.y, - new_world_origin.y, - digits, - "shiftOrigin(new_origin, new_world_origin) returned wrong position", - ) - - # Check that (x,y) -> (u,v) and converse work correctly - # These tests work regardless of whether the WCS is local or not. - do_wcs_pos(wcs, ufunc, vfunc, name, color=color) - - # Check picklability - if test_pickle: - check_pickle(wcs) - - # The GSObject transformation tests are only valid for a local WCS. - # But it should work for wcs.local() - - far_u_list = [ufunc(x, y) for x, y in zip(far_x_list, far_y_list)] - far_v_list = [vfunc(x, y) for x, y in zip(far_x_list, far_y_list)] - - full_im1 = galsim.Image( - galsim.BoundsI(-1023, 1024, -1023, 1024), wcs=wcs.fixColor(color) - ) - full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) - - for x0, y0, u0, v0 in zip(far_x_list, far_y_list, far_u_list, far_v_list): - local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 # noqa: E731 - local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 # noqa: E731 - image_pos = galsim.PositionD(x0, y0) - world_pos = galsim.PositionD(u0, v0) - do_wcs_pos( - wcs.local(image_pos, color=color), - local_ufunc, - local_vfunc, - name + ".local(image_pos)", - ) - do_wcs_pos( - wcs.jacobian(image_pos, color=color), - local_ufunc, - local_vfunc, - name + ".jacobian(image_pos)", - ) - do_wcs_pos( - wcs.affine(image_pos, color=color), - ufunc, - vfunc, - name + ".affine(image_pos)", - x0, - y0, - ) - - try: - # The local call is not guaranteed to be implemented for world_pos. - # So guard against NotImplementedError. - do_wcs_pos( - wcs.local(world_pos=world_pos, color=color), - local_ufunc, - local_vfunc, - name + ".local(world_pos)", - ) - do_wcs_pos( - wcs.jacobian(world_pos=world_pos, color=color), - local_ufunc, - local_vfunc, - name + ".jacobian(world_pos)", - ) - do_wcs_pos( - wcs.affine(world_pos=world_pos, color=color), - ufunc, - vfunc, - name + ".affine(world_pos)", - x0, - y0, - ) - except NotImplementedError: - pass - - # Test drawing the profile on an image with the given wcs - ix0 = int(x0) - iy0 = int(y0) - dx = x0 - ix0 - dy = y0 - iy0 - b = galsim.BoundsI(ix0 - 31, ix0 + 31, iy0 - 31, iy0 + 31) - im1 = full_im1[b] - im2 = full_im2[b] - - for world_profile in profiles: - image_profile = wcs.toImage(world_profile, image_pos=image_pos, color=color) - - world_profile.drawImage(im1, offset=(dx, dy), method="no_pixel") - image_profile.drawImage(im2, offset=(dx, dy), method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile and image_profile differed when drawn for " + name, - ) - - # Equivalent call using center rather than offset - im1b = world_profile.drawImage( - im1.copy(), center=(x0, y0), method="no_pixel" - ) - np.testing.assert_array_almost_equal(im1b.array, im1.array, digits) - im2b = image_profile.drawImage( - im2.copy(), center=(x0, y0), method="no_pixel" - ) - np.testing.assert_array_almost_equal(im2b.array, im2.array, digits) - - try: - # The toImage call is not guaranteed to be implemented for world_pos. - # So guard against NotImplementedError. - image_profile = wcs.toImage( - world_profile, world_pos=world_pos, color=color - ) - - world_profile.drawImage(im1, offset=(dx, dy), method="no_pixel") - image_profile.drawImage(im2, offset=(dx, dy), method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile and image_profile differed when drawn for " + name, - ) - except NotImplementedError: - pass - - # Since these postage stamps are odd, should get the same answer if we draw - # using the true center or not. - world_profile.drawImage(im1, method="no_pixel", use_true_center=False) - image_profile.drawImage(im2, method="no_pixel", use_true_center=True) - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile at center and image_profile differed when drawn for " - + name, - ) - - # Can also pass in wcs as a parameter to drawImage. - world_profile.drawImage(im1, method="no_pixel", wcs=wcs.fixColor(color)) - image_profile.drawImage(im2, method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile with given wcs and image_profile differed when drawn for " - + name, - ) - - # Check some properties that should be the same for the wcs and its local jacobian. - np.testing.assert_allclose( - wcs.minLinearScale(image_pos=image_pos, color=color), - wcs.jacobian(image_pos=image_pos, color=color).minLinearScale(color=color), - ) - np.testing.assert_allclose( - wcs.maxLinearScale(image_pos=image_pos, color=color), - wcs.jacobian(image_pos=image_pos, color=color).maxLinearScale(color=color), - ) - np.testing.assert_allclose( - wcs.pixelArea(image_pos=image_pos, color=color), - wcs.jacobian(image_pos=image_pos, color=color).pixelArea(color=color), - ) - - if not wcs.isUniform(): - assert_raises(TypeError, wcs.local) - assert_raises( - TypeError, wcs.local, image_pos=image_pos, world_pos=world_pos, color=color - ) - assert_raises(TypeError, wcs.local, image_pos=(3, 4), color=color) - assert_raises(TypeError, wcs.local, world_pos=(3, 4), color=color) - - assert_raises(TypeError, wcs.shiftOrigin) - assert_raises(TypeError, wcs.shiftOrigin, origin=(3, 4), color=color) - assert_raises( - TypeError, wcs.shiftOrigin, origin=image_pos, world_origin=(3, 4), color=color - ) - - -def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): - # It's a bit harder to test WCS functions that return a CelestialCoord, since - # (usually) we don't have an exact formula to compare with. So the tests here - # are a bit sparer. - - print("Start testing celestial WCS " + name) - - # Check that shiftOrigin and local work correctly: - new_origin = galsim.PositionI(123, 321) - wcs3 = wcs.shiftOrigin(new_origin) - assert wcs != wcs3, name + " is not != wcs.shiftOrigin(pos)" - wcs4 = wcs.local(wcs.origin) - assert wcs != wcs4, name + " is not != wcs.local()" - assert wcs4 != wcs, name + " is not != wcs.local() (reverse)" - world_pos1 = wcs.toWorld(galsim.PositionD(0, 0)) - wcs3 = wcs.shiftOrigin(new_origin) - world_pos2 = wcs3.toWorld(new_origin) - np.testing.assert_almost_equal( - world_pos2.distanceTo(world_pos1) / galsim.arcsec, - 0, - digits, - "shiftOrigin(new_origin) returned wrong world position", - ) - - full_im1 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), wcs=wcs) - full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) - - # Some of the FITS images have really huge pixel scales. Lower the accuracy requirement - # for them. 2 digits in arcsec corresponds to 4 digits in pixels. - max_scale = wcs.maxLinearScale(wcs.origin) - if max_scale > 100: # arcsec - digits2 = 2 - else: - digits2 = digits - if approx: - atol = 0.01 - else: - atol = 1.0e-6 - - # Check picklability - if test_pickle: - check_pickle(wcs) - - near_ra_list = [] - near_dec_list = [] - for x0, y0 in zip(near_x_list, near_y_list): - image_pos = galsim.PositionD(x0, y0) - world_pos = wcs.toWorld(image_pos) - near_ra_list.append(world_pos.ra.rad) - near_dec_list.append(world_pos.dec.rad) - - # posToWorld is equivalent - world_pos2 = wcs.posToWorld(image_pos) - assert world_pos2 == world_pos - - # xyToradec also works - ra1, dec1 = wcs.toWorld(x0, y0, units=galsim.radians) - ra2, dec2 = wcs.xyToradec(x0, y0, galsim.degrees) - ra3, dec3 = wcs.toWorld(x0, y0, units="arcmin") - assert np.isclose(ra1, world_pos.ra.rad) - assert np.isclose(dec1, world_pos.dec.rad) - assert np.isclose(ra2, world_pos.ra / galsim.degrees) - assert np.isclose(dec2, world_pos.dec / galsim.degrees) - assert np.isclose(ra3, world_pos.ra / galsim.arcmin) - assert np.isclose(dec3, world_pos.dec / galsim.arcmin) - - # toWorld on scalars should not return arrays - assert_raises(TypeError, len, ra1) - assert_raises(TypeError, len, dec1) - assert_raises(TypeError, len, ra2) - assert_raises(TypeError, len, dec2) - assert_raises(TypeError, len, ra3) - assert_raises(TypeError, len, dec3) - - try: - # The toImage call is not guaranteed to be implemented for world_pos. - # So guard against NotImplementedError. - image_pos1 = wcs.toImage(world_pos) - image_pos2 = wcs.posToImage(world_pos) - test_reverse = True - except NotImplementedError: - test_reverse = False - else: - assert np.isclose(image_pos1.x, x0, rtol=1.0e-3, atol=atol) - assert np.isclose(image_pos1.y, y0, rtol=1.0e-3, atol=atol) - assert np.isclose(image_pos2.x, x0, rtol=1.0e-3, atol=atol) - assert np.isclose(image_pos2.y, y0, rtol=1.0e-3, atol=atol) - - x1, y1 = wcs.toImage(ra1, dec1, units=galsim.radians) - x2, y2 = wcs.radecToxy(ra2, dec2, units="deg") - x3, y3 = wcs.radecToxy(ra3, dec3, units=galsim.arcmin) - assert_raises(TypeError, len, x1) - assert_raises(TypeError, len, y1) - assert_raises(TypeError, len, x2) - assert_raises(TypeError, len, y2) - assert_raises(TypeError, len, x3) - assert_raises(TypeError, len, y3) - assert np.isclose(x1, x0, rtol=1.0e-3, atol=atol) - assert np.isclose(y1, y0, rtol=1.0e-3, atol=atol) - assert np.isclose(x2, x0, rtol=1.0e-3, atol=atol) - assert np.isclose(y2, y0, rtol=1.0e-3, atol=atol) - assert np.isclose(x3, x0, rtol=1.0e-3, atol=atol) - assert np.isclose(y3, y0, rtol=1.0e-3, atol=atol) - - # Check the calculation of the jacobian - w1 = wcs.toWorld(galsim.PositionD(x0 + 0.5, y0)) - w2 = wcs.toWorld(galsim.PositionD(x0 - 0.5, y0)) - w3 = wcs.toWorld(galsim.PositionD(x0, y0 + 0.5)) - w4 = wcs.toWorld(galsim.PositionD(x0, y0 - 0.5)) - cosdec = np.cos(world_pos.dec) - jac = wcs.jacobian(image_pos) - np.testing.assert_array_almost_equal( - jac.dudx, - (w2.ra - w1.ra) / galsim.arcsec * cosdec, - digits2, - "jacobian dudx incorrect for " + name, - ) - np.testing.assert_array_almost_equal( - jac.dudy, - (w4.ra - w3.ra) / galsim.arcsec * cosdec, - digits2, - "jacobian dudy incorrect for " + name, - ) - np.testing.assert_array_almost_equal( - jac.dvdx, - (w1.dec - w2.dec) / galsim.arcsec, - digits2, - "jacobian dvdx incorrect for " + name, - ) - np.testing.assert_array_almost_equal( - jac.dvdy, - (w3.dec - w4.dec) / galsim.arcsec, - digits2, - "jacobian dvdy incorrect for " + name, - ) - - # toWorld with projection should be (roughly) equivalent to the local around the - # projection point. - origin = galsim.PositionD(0, 0) - uv_pos1 = wcs.toWorld(image_pos, project_center=wcs.toWorld(origin)) - uv_pos2 = wcs.local(origin).toWorld(image_pos) - uv_pos3 = wcs.posToWorld(image_pos, project_center=wcs.toWorld(origin)) - u3, v3 = wcs.toWorld(origin).project(world_pos, "gnomonic") - np.testing.assert_allclose(uv_pos1.x, uv_pos2.x, rtol=1.0e-1, atol=1.0e-8) - np.testing.assert_allclose(uv_pos1.y, uv_pos2.y, rtol=1.0e-1, atol=1.0e-8) - np.testing.assert_allclose( - uv_pos1.x, u3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos1.y, v3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos3.x, u3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos3.y, v3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - - origin = galsim.PositionD(x0 + 0.5, y0 - 0.5) - uv_pos1 = wcs.toWorld(image_pos, project_center=wcs.toWorld(origin)) - uv_pos2 = wcs.local(origin).toWorld(image_pos - origin) - uv_pos3 = wcs.posToWorld(image_pos, project_center=wcs.toWorld(origin)) - u3, v3 = wcs.toWorld(origin).project(world_pos, "gnomonic") - np.testing.assert_allclose(uv_pos1.x, uv_pos2.x, rtol=1.0e-2, atol=1.0e-8) - np.testing.assert_allclose(uv_pos1.y, uv_pos2.y, rtol=1.0e-2, atol=1.0e-8) - np.testing.assert_allclose( - uv_pos1.x, u3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos1.y, v3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos3.x, u3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - np.testing.assert_allclose( - uv_pos3.y, v3 / galsim.arcsec, rtol=1.0e-6, atol=1.0e-8 - ) - - # Test drawing the profile on an image with the given wcs - ix0 = int(x0) - iy0 = int(y0) - dx = x0 - ix0 - dy = y0 - iy0 - b = galsim.BoundsI(ix0 - 31, ix0 + 31, iy0 - 31, iy0 + 31) - im1 = full_im1[b] - im2 = full_im2[b] - - for world_profile in profiles: - image_profile = wcs.toImage(world_profile, image_pos=image_pos) - - world_profile.drawImage(im1, offset=(dx, dy), method="no_pixel") - image_profile.drawImage(im2, offset=(dx, dy), method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile and image_profile differed when drawn for " + name, - ) - - # Equivalent call using center rather than offset - im1b = world_profile.drawImage( - im1.copy(), center=(x0, y0), method="no_pixel" - ) - np.testing.assert_array_almost_equal(im1b.array, im1.array, digits) - im2b = image_profile.drawImage( - im2.copy(), center=(x0, y0), method="no_pixel" - ) - np.testing.assert_array_almost_equal(im2b.array, im2.array, digits) - - if test_reverse: - image_profile = wcs.toImage(world_profile, world_pos=world_pos) - - world_profile.drawImage(im1, offset=(dx, dy), method="no_pixel") - image_profile.drawImage(im2, offset=(dx, dy), method="no_pixel") - np.testing.assert_array_almost_equal( - im1.array, - im2.array, - digits, - "world_profile and image_profile differed when drawn for " + name, - ) - - # Test xyToradec with array - xar = np.array(near_x_list) - yar = np.array(near_y_list) - raar = np.array(near_ra_list) - decar = np.array(near_dec_list) - ra1, dec1 = wcs.toWorld(xar, yar, units=galsim.radians) - ra2, dec2 = wcs.xyToradec(xar, yar, galsim.degrees) - ra3, dec3 = wcs.toWorld(xar, yar, units="arcmin") - np.testing.assert_allclose(ra1, raar) - np.testing.assert_allclose(dec1, decar) - np.testing.assert_allclose(ra2, raar * (galsim.radians / galsim.degrees)) - np.testing.assert_allclose(dec2, decar * (galsim.radians / galsim.degrees)) - np.testing.assert_allclose(ra3, raar * (galsim.radians / galsim.arcmin)) - np.testing.assert_allclose(dec3, decar * (galsim.radians / galsim.arcmin)) - - if test_reverse: - x1, y1 = wcs.toImage(raar, decar, units=galsim.radians) - x2, y2 = wcs.radecToxy( - raar * (galsim.radians / galsim.degrees), - decar * (galsim.radians / galsim.degrees), - galsim.degrees, - ) - x3, y3 = wcs.toImage( - raar * (galsim.radians / galsim.arcmin), - decar * (galsim.radians / galsim.arcmin), - units="arcmin", - ) - np.testing.assert_allclose(x1, xar, rtol=1.0e-3, atol=atol) - np.testing.assert_allclose(y1, yar, rtol=1.0e-3, atol=atol) - np.testing.assert_allclose(x2, xar, rtol=1.0e-3, atol=atol) - np.testing.assert_allclose(y2, yar, rtol=1.0e-3, atol=atol) - np.testing.assert_allclose(x3, xar, rtol=1.0e-3, atol=atol) - np.testing.assert_allclose(y3, yar, rtol=1.0e-3, atol=atol) - - assert_raises(TypeError, wcs.toWorld) - assert_raises(TypeError, wcs.toWorld, (3, 4)) - assert_raises(TypeError, wcs.toWorld, 3, 4) # no units - assert_raises(TypeError, wcs.toWorld, 3, 4, 5) - assert_raises( - TypeError, - wcs.toWorld, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - assert_raises(TypeError, wcs.posToWorld) - assert_raises(TypeError, wcs.posToWorld, (3, 4)) - assert_raises(TypeError, wcs.posToWorld, 3, 4) - assert_raises(TypeError, wcs.posToWorld, 3, 4, 5) - assert_raises( - TypeError, - wcs.posToWorld, - galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - assert_raises(TypeError, wcs.xyToradec) - assert_raises(TypeError, wcs.xyToradec, 3) - assert_raises(TypeError, wcs.xyToradec, 3, 4) # no units - assert_raises(TypeError, wcs.xyToradec, galsim.PositionD(3, 4)) - assert_raises(ValueError, wcs.xyToradec, 3, 4, 5) # 5 interpreted as a unit - - assert_raises(TypeError, wcs.toImage) - assert_raises(TypeError, wcs.toImage, (3, 4)) - assert_raises(TypeError, wcs.toImage, 3, 4) # no units - assert_raises(TypeError, wcs.toImage, galsim.PositionD(3, 4)) - assert_raises(TypeError, wcs.toImage, 3, 4, 5) - - assert_raises(TypeError, wcs.posToImage) - assert_raises(TypeError, wcs.posToImage, (3, 4)) - assert_raises(TypeError, wcs.posToImage, galsim.PositionD(3, 4)) - assert_raises(TypeError, wcs.posToImage, 3, 4) - assert_raises(TypeError, wcs.posToImage, 3, 4, 5) - - assert_raises(TypeError, wcs.radecToxy) - assert_raises(TypeError, wcs.radecToxy, 3, 4) - assert_raises(TypeError, wcs.radecToxy, units=galsim.degrees) - assert_raises(TypeError, wcs.radecToxy, 3, units=galsim.degrees) - assert_raises(TypeError, wcs.radecToxy, 3, 4, 5, units=galsim.degrees) - assert_raises(TypeError, wcs.radecToxy, world_pos, units=galsim.degrees) - assert_raises(ValueError, wcs.radecToxy, 3, 4, 5) - - assert_raises(TypeError, wcs.local) - assert_raises(TypeError, wcs.local, image_pos=image_pos, world_pos=world_pos) - assert_raises(TypeError, wcs.local, image_pos=(3, 4)) - assert_raises(TypeError, wcs.local, world_pos=(3, 4)) - - assert_raises(TypeError, wcs.shiftOrigin) - assert_raises(TypeError, wcs.shiftOrigin, origin=(3, 4)) - assert_raises(TypeError, wcs.shiftOrigin, world_origin=(3, 4)) - assert_raises(TypeError, wcs.shiftOrigin, origin=image_pos, world_origin=world_pos) - - -@timer -def test_pixelscale(): - """Test the PixelScale class""" - scale = 0.23 - wcs = galsim.PixelScale(scale) - assert wcs.isPixelScale() - assert wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - # Check basic copy and == , !=: - wcs2 = wcs.copy() - assert wcs == wcs2, "PixelScale copy is not == the original" - wcs3 = galsim.PixelScale(scale + 0.1234) - assert wcs != wcs3, "PixelScale is not != a different one" - assert wcs.scale == scale - assert wcs.origin == galsim.PositionD(0, 0) - assert wcs.world_origin == galsim.PositionD(0, 0) - - # JAX specific modification - # ------------------------- - # PixelScale does not check for correct type to enable jitting/vmapping - # assert_raises(TypeError, galsim.PixelScale) - # assert_raises(TypeError, galsim.PixelScale, scale=galsim.PixelScale(scale)) - # assert_raises(TypeError, galsim.PixelScale, scale=scale, origin=galsim.PositionD(0, 0)) - # assert_raises(TypeError, galsim.PixelScale, scale=scale, world_origin=galsim.PositionD(0, 0)) - - ufunc = lambda x, y: x * scale # noqa: E731 - vfunc = lambda x, y: y * scale # noqa: E731 - - # Do generic tests that apply to all WCS types - do_local_wcs(wcs, ufunc, vfunc, "PixelScale") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "PixelScale") - - # Check jacobian() - jac = wcs.jacobian() - np.testing.assert_almost_equal( - jac.dudx, scale, digits, "PixelScale dudx does not match expected value." - ) - np.testing.assert_almost_equal( - jac.dudy, 0.0, digits, "PixelScale dudy does not match expected value." - ) - np.testing.assert_almost_equal( - jac.dvdx, 0.0, digits, "PixelScale dvdx does not match expected value." - ) - np.testing.assert_almost_equal( - jac.dvdy, scale, digits, "PixelScale dvdy does not match expected value." - ) - - # Check the decomposition: - do_jac_decomp(jac, "PixelScale") - - # Add an image origin offset - x0 = 1 - y0 = 1 - origin = galsim.PositionI(x0, y0) - wcs = galsim.OffsetWCS(scale, origin) - wcs2 = galsim.PixelScale(scale).withOrigin(origin) - assert wcs == wcs2, "OffsetWCS is not == PixelScale.withOrigin(origin)" - assert wcs.origin == origin - assert wcs.scale == scale - assert wcs.isPixelScale() - assert not wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - # Default origin is (0,0) - wcs3 = galsim.OffsetWCS(scale) - assert wcs3.origin == galsim.PositionD(0, 0) - assert wcs3.world_origin == galsim.PositionD(0, 0) - - assert_raises(TypeError, galsim.OffsetWCS) - assert_raises(TypeError, galsim.OffsetWCS, scale=galsim.PixelScale(scale)) - assert_raises(TypeError, galsim.OffsetWCS, scale=scale, origin=5) - assert_raises( - TypeError, - galsim.OffsetWCS, - scale=scale, - origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises( - TypeError, - galsim.OffsetWCS, - scale=scale, - world_origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - # Check basic copy and == , != for OffsetWCS: - wcs2 = wcs.copy() - assert wcs == wcs2, "OffsetWCS copy is not == the original" - wcs3a = galsim.OffsetWCS(scale + 0.123, origin) - wcs3b = galsim.OffsetWCS(scale, origin * 2) - wcs3c = galsim.OffsetWCS(scale, origin, origin) - assert wcs != wcs3a, "OffsetWCS is not != a different one (scale)" - assert wcs != wcs3b, "OffsetWCS is not != a different one (origin)" - assert wcs != wcs3c, "OffsetWCS is not != a different one (world_origin)" - - ufunc = lambda x, y: scale * (x - x0) # noqa: E731 - vfunc = lambda x, y: scale * (y - y0) # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 1") - - # Add a world origin offset - u0 = 124.3 - v0 = -141.9 - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.OffsetWCS(scale, world_origin=world_origin) - ufunc = lambda x, y: scale * x + u0 # noqa: E731 - vfunc = lambda x, y: scale * y + v0 # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 2") - - # Add both kinds of offsets - x0 = -3 - y0 = 104 - u0 = 1423.9 - v0 = 8242.7 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.OffsetWCS(scale, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: scale * (x - x0) + u0 # noqa: E731 - vfunc = lambda x, y: scale * (y - y0) + v0 # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 3") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "OffsetWCS") - - -@timer -def test_shearwcs(): - """Test the ShearWCS class""" - scale = 0.23 - g1 = 0.14 - g2 = -0.37 - shear = galsim.Shear(g1=g1, g2=g2) - wcs = galsim.ShearWCS(scale, shear) - assert wcs.shear == shear - assert wcs.origin == galsim.PositionD(0, 0) - assert wcs.world_origin == galsim.PositionD(0, 0) - assert not wcs.isPixelScale() - assert wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - assert_raises(TypeError, galsim.ShearWCS) - assert_raises(TypeError, galsim.ShearWCS, shear=0.3) - assert_raises( - TypeError, galsim.ShearWCS, shear=shear, origin=galsim.PositionD(0, 0) - ) - assert_raises( - TypeError, galsim.ShearWCS, shear=shear, world_origin=galsim.PositionD(0, 0) - ) - assert_raises(TypeError, galsim.ShearWCS, g1=g1, g2=g2) - - # Check basic copy and == , !=: - wcs2 = wcs.copy() - assert wcs == wcs2, "ShearWCS copy is not == the original" - wcs3a = galsim.ShearWCS(scale + 0.1234, shear) - wcs3b = galsim.ShearWCS(scale, -shear) - assert wcs != wcs3a, "ShearWCS is not != a different one (scale)" - assert wcs != wcs3b, "ShearWCS is not != a different one (shear)" - - factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 - - # Do generic tests that apply to all WCS types - do_local_wcs(wcs, ufunc, vfunc, "ShearWCS") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "ShearWCS") - - # Check jacobian() - jac = wcs.jacobian() - np.testing.assert_almost_equal( - jac.dudx, - (1.0 - g1) * scale * factor, - digits, - "ShearWCS dudx does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dudy, - -g2 * scale * factor, - digits, - "ShearWCS dudy does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dvdx, - -g2 * scale * factor, - digits, - "ShearWCS dvdx does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dvdy, - (1.0 + g1) * scale * factor, - digits, - "ShearWCS dvdy does not match expected value.", - ) - - # Check the decomposition: - do_jac_decomp(jac, "ShearWCS") - - # Add an image origin offset - x0 = 1 - y0 = 1 - origin = galsim.PositionD(x0, y0) - wcs = galsim.OffsetShearWCS(scale, shear, origin) - wcs2 = galsim.ShearWCS(scale, shear).withOrigin(origin) - assert wcs == wcs2, "OffsetShearWCS is not == ShearWCS.withOrigin(origin)" - assert wcs.shear == shear - assert wcs.origin == origin - assert wcs.world_origin == galsim.PositionD(0, 0) - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - wcs3 = galsim.OffsetShearWCS(scale, shear) - assert wcs3.origin == galsim.PositionD(0, 0) - assert wcs3.world_origin == galsim.PositionD(0, 0) - - assert_raises(TypeError, galsim.OffsetShearWCS) - assert_raises(TypeError, galsim.OffsetShearWCS, shear=0.3) - assert_raises(TypeError, galsim.OffsetShearWCS, shear=shear, origin=5) - assert_raises( - TypeError, - galsim.OffsetShearWCS, - shear=shear, - origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises( - TypeError, - galsim.OffsetShearWCS, - shear=shear, - world_origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - # Check basic copy and == , != for OffsetShearWCS: - wcs2 = wcs.copy() - assert wcs == wcs2, "OffsetShearWCS copy is not == the original" - wcs3a = galsim.OffsetShearWCS(scale + 0.123, shear, origin) - wcs3b = galsim.OffsetShearWCS(scale, -shear, origin) - wcs3c = galsim.OffsetShearWCS(scale, shear, origin * 2) - wcs3d = galsim.OffsetShearWCS(scale, shear, origin, origin) - assert wcs != wcs3a, "OffsetShearWCS is not != a different one (scale)" - assert wcs != wcs3b, "OffsetShearWCS is not != a different one (shear)" - assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" - assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - - ufunc = ( # noqa: E731 - lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - ) - vfunc = ( # noqa: E731 - lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor - ) - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") - - # Add a world origin offset - u0 = 124.3 - v0 = -141.9 - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.OffsetShearWCS(scale, shear, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 # noqa: E731 - vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 2") - - # Add both kinds of offsets - x0 = -3 - y0 = 104 - u0 = 1423.9 - v0 = 8242.7 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = ( # noqa: E731 - lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - ) - vfunc = ( # noqa: E731 - lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 - ) - do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "OffsetShearWCS") - - -@timer -def test_affinetransform(): - """Test the AffineTransform class""" - # First a slight tweak on a simple scale factor - dudx = 0.2342 - dudy = 0.0023 - dvdx = 0.0019 - dvdy = 0.2391 - - wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - assert not wcs.isPixelScale() - assert wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - assert wcs.dudx == dudx - assert wcs.dudy == dudy - assert wcs.dvdx == dvdx - assert wcs.dvdy == dvdy - - assert_raises(TypeError, galsim.JacobianWCS) - assert_raises(TypeError, galsim.JacobianWCS, dudx, dudy, dvdx) - assert_raises( - TypeError, - galsim.JacobianWCS, - dudx, - dudy, - dvdx, - dvdy, - origin=galsim.PositionD(0, 0), - ) - assert_raises( - TypeError, - galsim.JacobianWCS, - dudx, - dudy, - dvdx, - dvdy, - world_origin=galsim.PositionD(0, 0), - ) - - # Check basic copy and == , !=: - wcs2 = wcs.copy() - assert wcs == wcs2, "JacobianWCS copy is not == the original" - wcs3a = galsim.JacobianWCS(dudx + 0.123, dudy, dvdx, dvdy) - wcs3b = galsim.JacobianWCS(dudx, dudy + 0.123, dvdx, dvdy) - wcs3c = galsim.JacobianWCS(dudx, dudy, dvdx + 0.123, dvdy) - wcs3d = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy + 0.123) - assert wcs != wcs3a, "JacobianWCS is not != a different one (dudx)" - assert wcs != wcs3b, "JacobianWCS is not != a different one (dudy)" - assert wcs != wcs3c, "JacobianWCS is not != a different one (dvdx)" - assert wcs != wcs3d, "JacobianWCS is not != a different one (dvdy)" - - ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 - do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 1") - - # Check the decomposition: - do_jac_decomp(wcs, "JacobianWCS 1") - - # Add an image origin offset - x0 = 1 - y0 = 1 - origin = galsim.PositionD(x0, y0) - wcs = galsim.AffineTransform(dudx, dudy, dvdx, dvdy, origin) - wcs2 = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy).withOrigin(origin) - assert wcs == wcs2, "AffineTransform is not == JacobianWCS.withOrigin(origin)" - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert wcs.isUniform() - assert not wcs.isCelestial() - - assert_raises(TypeError, galsim.AffineTransform) - assert_raises(TypeError, galsim.AffineTransform, dudx, dudy, dvdx) - assert_raises(TypeError, galsim.AffineTransform, dudx, dudy, dvdx, dvdy, origin=3) - assert_raises( - TypeError, - galsim.AffineTransform, - dudx, - dudy, - dvdx, - dvdy, - origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises( - TypeError, - galsim.AffineTransform, - dudx, - dudy, - dvdx, - dvdy, - world_origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - # Check basic copy and == , != for AffineTransform: - wcs2 = wcs.copy() - assert wcs == wcs2, "AffineTransform copy is not == the original" - wcs3a = galsim.AffineTransform(dudx + 0.123, dudy, dvdx, dvdy, origin) - wcs3b = galsim.AffineTransform(dudx, dudy + 0.123, dvdx, dvdy, origin) - wcs3c = galsim.AffineTransform(dudx, dudy, dvdx + 0.123, dvdy, origin) - wcs3d = galsim.AffineTransform(dudx, dudy, dvdx, dvdy + 0.123, origin) - wcs3e = galsim.AffineTransform(dudx, dudy, dvdx, dvdy, origin * 2) - wcs3f = galsim.AffineTransform(dudx, dudy, dvdx, dvdy, origin, origin) - assert wcs != wcs3a, "AffineTransform is not != a different one (dudx)" - assert wcs != wcs3b, "AffineTransform is not != a different one (dudy)" - assert wcs != wcs3c, "AffineTransform is not != a different one (dvdx)" - assert wcs != wcs3d, "AffineTransform is not != a different one (dvdy)" - assert wcs != wcs3e, "AffineTransform is not != a different one (origin)" - assert wcs != wcs3f, "AffineTransform is not != a different one (world_origin)" - - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) # noqa: E731 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 1") - - # Next one with a flip and significant rotation and a large (u,v) offset - dudx = 0.1432 - dudy = 0.2342 - dvdx = 0.2391 - dvdy = 0.1409 - - wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 - do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 2") - - # Check the decomposition: - do_jac_decomp(wcs, "JacobianWCS 2") - - # Add a world origin offset - u0 = 124.3 - v0 = -141.9 - wcs = galsim.AffineTransform( - dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) - ) - ufunc = lambda x, y: dudx * x + dudy * y + u0 # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y + v0 # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 2") - - # Finally a really crazy one that isn't remotely regular - dudx = 0.2342 - dudy = -0.1432 - dvdx = 0.0924 - dvdy = -0.3013 - - wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 - do_local_wcs(wcs, ufunc, vfunc, "Jacobian 3") - - # Check the decomposition: - do_jac_decomp(wcs, "JacobianWCS 3") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "JacobianWCS") - - # Add both kinds of offsets - x0 = -3 - y0 = 104 - u0 = 1423.9 - v0 = 8242.7 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.AffineTransform( - dudx, dudy, dvdx, dvdy, origin=origin, world_origin=world_origin - ) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 3") - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "AffineTransform") - - # Degenerate transformation should raise some errors - degen_wcs = galsim.JacobianWCS(0.2, 0.1, 0.2, 0.1) - assert_raises(galsim.GalSimError, degen_wcs.getDecomposition) - - image_pos = galsim.PositionD(0, 0) - world_pos = degen_wcs.toWorld(image_pos) # This direction is ok. - assert_raises(galsim.GalSimError, degen_wcs.toImage, world_pos) # This is not. - assert_raises(galsim.GalSimError, degen_wcs._x, 0, 0) - assert_raises(galsim.GalSimError, degen_wcs._y, 0, 0) - assert_raises(galsim.GalSimError, degen_wcs.inverse) - assert_raises(galsim.GalSimError, degen_wcs.toImage, galsim.Gaussian(sigma=2)) - - -def radial_u(x, y): - """A cubic radial function used for a u(x,y) function""" - # Note: This is designed to be smooth enough that the local approximates are accurate - # to 5 decimal places when we do the local tests. - # - # We will use a functional form of rho/r0 = r/r0 + a (r/r0)^3 - # To be accurate to < 1.e-6 arcsec at an offset of 7 pixels at r = 700, we need: - # | rho(r+dr) - rho(r) - drho/dr(r) * dr | < 1.e-6 - # 1/2 |d2(rho)/dr^2| * dr^2 < 1.e-6 - # rho = r + a/r0^2 r^3 - # rho' = 1 + 3a/r0^2 r^2 - # rho'' = 6a/r0^2 r - # 1/2 6|a| / 2000^2 * 700 * 7^2 < 1.e-6 - # |a| < 3.8e-5 - - r0 = 2000.0 # scale factor for function - a = 2.3e-5 - rho_over_r = 1 + a * (x * x + y * y) / (r0 * r0) - return x * rho_over_r - - -def radial_v(x, y): - """A radial function used for a u(x,y) function""" - r0 = 2000.0 - a = 2.3e-5 - rho_over_r = 1 + a * (x * x + y * y) / (r0 * r0) - return y * rho_over_r - - -class Cubic(object): - """A class that can act as a function, implementing a cubic radial function.""" - - def __init__(self, a, r0, whichuv): - self._a = a - self._r0 = r0 - self._uv = whichuv - - def __call__(self, x, y): - rho_over_r = 1 + self._a * (x * x + y * y) / (self._r0 * self._r0) - if self._uv == "u": - return x * rho_over_r - else: - return y * rho_over_r - - -@timer -def test_uvfunction(): - """Test the UVFunction class""" - # First make some that are identical to simpler WCS classes: - # 1. Like PixelScale - scale = 0.17 - ufunc = lambda x, y: x * scale # noqa: E731 - vfunc = lambda x, y: y * scale # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc) - do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like PixelScale", test_pickle=False) - assert wcs.ufunc(2.9, 3.7) == ufunc(2.9, 3.7) - assert wcs.vfunc(2.9, 3.7) == vfunc(2.9, 3.7) - assert wcs.xfunc is None - assert wcs.yfunc is None - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert not wcs.isUniform() - assert not wcs.isCelestial() - - # Also check with inverse functions. - xfunc = lambda u, v: u / scale # noqa: E731 - yfunc = lambda u, v: v / scale # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) - do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction like PixelScale with inverse", test_pickle=False - ) - assert wcs.ufunc(2.9, 3.7) == ufunc(2.9, 3.7) - assert wcs.vfunc(2.9, 3.7) == vfunc(2.9, 3.7) - assert wcs.xfunc(2.9, 3.7) == xfunc(2.9, 3.7) - assert wcs.yfunc(2.9, 3.7) == yfunc(2.9, 3.7) - - assert_raises(TypeError, galsim.UVFunction) - assert_raises(TypeError, galsim.UVFunction, ufunc=ufunc) - assert_raises(TypeError, galsim.UVFunction, vfunc=vfunc) - assert_raises(TypeError, galsim.UVFunction, ufunc, vfunc, origin=5) - assert_raises( - TypeError, - galsim.UVFunction, - ufunc, - vfunc, - origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises( - TypeError, - galsim.UVFunction, - ufunc, - vfunc, - world_origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - - # 2. Like ShearWCS - scale = 0.23 - g1 = 0.14 - g2 = -0.37 - factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc) - do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like ShearWCS", test_pickle=False) - - # Also check with inverse functions. - xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor # noqa: E731 - yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) - do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction like ShearWCS with inverse", test_pickle=False - ) - - # 3. Like an AffineTransform - dudx = 0.2342 - dudy = 0.1432 - dvdx = 0.1409 - dvdy = 0.2391 - - ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc) - do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction like AffineTransform", test_pickle=False - ) - - # Check that passing functions as strings works correctly. - wcs = galsim.UVFunction( - ufunc="%r*x + %r*y" % (dudx, dudy), vfunc="%r*x + %r*y" % (dvdx, dvdy) - ) - do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with string funcs", test_pickle=True) - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "UVFunction_string") - - # Also check with inverse functions. - det = dudx * dvdy - dudy * dvdx - wcs = galsim.UVFunction( - ufunc="%r*x + %r*y" % (dudx, dudy), - vfunc="%r*x + %r*y" % (dvdx, dvdy), - xfunc="(%r*u + %r*v)/(%r)" % (dvdy, -dudy, det), - yfunc="(%r*u + %r*v)/(%r)" % (-dvdx, dudx, det), - ) - do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction with string inverse funcs", test_pickle=True - ) - - # The same thing in fact, but nominally takes color as an argument. - wcsc = galsim.UVFunction( - ufunc="%r*x + %r*y" % (dudx, dudy), - vfunc="%r*x + %r*y" % (dvdx, dvdy), - xfunc="(%r*u + %r*v)/(%r)" % (dvdy, -dudy, det), - yfunc="(%r*u + %r*v)/(%r)" % (-dvdx, dudx, det), - uses_color=True, - ) - do_nonlocal_wcs( - wcsc, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True - ) - - # 4. Next some UVFunctions with non-trivial offsets - x0 = 1.3 - y0 = -0.9 - u0 = 124.3 - v0 = -141.9 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 - vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 - wcs = galsim.UVFunction(ufunc2, vfunc2) - do_nonlocal_wcs( - wcs, ufunc2, vfunc2, "UVFunction with origins in funcs", test_pickle=False - ) - wcs = galsim.UVFunction(ufunc, vfunc, origin=origin, world_origin=world_origin) - do_nonlocal_wcs( - wcs, ufunc2, vfunc2, "UVFunction with origin arguments", test_pickle=False - ) - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "UVFunction_lambda") - - # Check basic copy and == , != for UVFunction - wcs2 = wcs.copy() - assert wcs == wcs2, "UVFunction copy is not == the original" - wcs3a = galsim.UVFunction(vfunc, vfunc, origin=origin, world_origin=world_origin) - wcs3b = galsim.UVFunction(ufunc, ufunc, origin=origin, world_origin=world_origin) - wcs3c = galsim.UVFunction( - ufunc, vfunc, origin=origin * 2, world_origin=world_origin - ) - wcs3d = galsim.UVFunction( - ufunc, vfunc, origin=origin, world_origin=world_origin * 2 - ) - assert wcs != wcs3a, "UVFunction is not != a different one (ufunc)" - assert wcs != wcs3b, "UVFunction is not != a different one (vfunc)" - assert wcs != wcs3c, "UVFunction is not != a different one (origin)" - assert wcs != wcs3d, "UVFunction is not != a different one (world_origin)" - - # 5. Now some non-trivial 3rd order radial function. - origin = galsim.PositionD(x0, y0) - wcs = galsim.UVFunction(radial_u, radial_v, origin=origin) - - # Check jacobian() - for x, y in zip(far_x_list, far_y_list): - image_pos = galsim.PositionD(x, y) - jac = wcs.jacobian(image_pos) - # u = x * rho_over_r - # v = y * rho_over_r - # For simplicity of notation, let rho_over_r = w(r) = 1 + a r^2/r0^2 - # dudx = w + x dwdr drdx = w + x (2ar/r0^2) (x/r) = w + 2a x^2/r0^2 - # dudy = x dwdr drdy = x (2ar/r0^2) (y/r) = 2a xy/r0^2 - # dvdx = y dwdr drdx = y (2ar/r0^2) (x/r) = 2a xy/r0^2 - # dvdy = w + y dwdr drdy = w + y (2ar/r0^2) (y/r) = w + 2a y^2/r0^2 - r0 = 2000.0 - a = 2.3e-5 - factor = a / (r0 * r0) - w = 1.0 + factor * (x * x + y * y) - np.testing.assert_almost_equal( - jac.dudx, - w + 2 * factor * x * x, - digits, - "UVFunction dudx does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dudy, - 2 * factor * x * y, - digits, - "UVFunction dudy does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dvdx, - 2 * factor * x * y, - digits, - "UVFunction dvdx does not match expected value.", - ) - np.testing.assert_almost_equal( - jac.dvdy, - w + 2 * factor * y * y, - digits, - "UVFunction dvdy does not match expected value.", - ) - - ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 - vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic radial UVFunction", test_pickle=False) - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "UVFunction_func") - - # 6. Repeat with a function object rather than a regular function. - # Use a different `a` parameter for u and v to make things more interesting. - cubic_u = Cubic(2.9e-5, 2000.0, "u") - cubic_v = Cubic(-3.7e-5, 2000.0, "v") - wcs = galsim.UVFunction(cubic_u, cubic_v, origin=galsim.PositionD(x0, y0)) - ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 - vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 - do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic object UVFunction", test_pickle=False) - - # Check that using a wcs in the context of an image works correctly - do_wcs_image(wcs, "UVFunction_object") - - # 7. Test the UVFunction that is used in demo9 to confirm that I got the - # inverse function correct! - ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 - vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 - # w = 0.05 (r + 2.e-6 r^3) - # 0 = r^3 + 5e5 r - 1e7 w - # - # Cardano's formula gives - # (http://en.wikipedia.org/wiki/Cubic_function#Cardano.27s_method) - # r = ( sqrt( (5e6 w)^2 + (5e5)^3/27 ) + (5e6 w) )^1/3 - - # ( sqrt( (5e6 w)^2 + (5e5)^3/27 ) - (5e6 w) )^1/3 - # = 100 ( ( 5 sqrt( w^2 + 5.e3/27 ) + 5 w )^1/3 - - # ( 5 sqrt( w^2 + 5.e3/27 ) - 5 w )^1/3 ) - import math - - xfunc = lambda u, v: ( # noqa: E731 - lambda w: ( - 0.0 - if w == 0.0 - else 100.0 - * u - / w - * ( - (5 * math.sqrt(w**2 + 5.0e3 / 27.0) + 5 * w) ** (1.0 / 3.0) - - (5 * math.sqrt(w**2 + 5.0e3 / 27.0) - 5 * w) ** (1.0 / 3.0) - ) - ) - )(math.sqrt(u**2 + v**2)) - yfunc = lambda u, v: ( # noqa: E731 - lambda w: ( - 0.0 - if w == 0.0 - else 100.0 - * v - / w - * ( - (5 * math.sqrt(w**2 + 5.0e3 / 27.0) + 5 * w) ** (1.0 / 3.0) - - (5 * math.sqrt(w**2 + 5.0e3 / 27.0) - 5 * w) ** (1.0 / 3.0) - ) - ) - )(math.sqrt(u**2 + v**2)) - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) - do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction from demo9", test_pickle=False) - - # Check that passing really long strings works correctly. - ufuncs = "0.05 * x * (1. + 2.e-6 * (x**2 + y**2))" - vfuncs = "0.05 * y * (1. + 2.e-6 * (x**2 + y**2))" - xfuncs = ( - "(lambda w: ( 0. if w==0. else " - " 100.*u/w*(( 5*math.sqrt(w**2+5.e3/27.)+5*w )**(1./3.) - " - " ( 5*math.sqrt(w**2+5.e3/27.)-5*w )**(1./3.))) )(math.sqrt(u**2+v**2))" - ) - yfuncs = ( - "(lambda w: ( 0. if w==0. else " - " 100.*v/w*(( 5*math.sqrt(w**2+5.e3/27.)+5*w )**(1./3.) - " - " ( 5*math.sqrt(w**2+5.e3/27.)-5*w )**(1./3.))) )(math.sqrt(u**2+v**2))" - ) - wcs = galsim.UVFunction(ufuncs, vfuncs, xfuncs, yfuncs) - do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction from demo9, string", test_pickle=True - ) - do_wcs_image(wcs, "UVFunction from demo9, string") - - # This version doesn't work with numpy arrays because of the math functions. - # This provides a test of that branch of the makeSkyImage function. - ufunc = ( # noqa: E731 - lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) - vfunc = ( # noqa: E731 - lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) - wcs = galsim.UVFunction(ufunc, vfunc) - do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) - do_wcs_image(wcs, "UVFunction_math") - - # 8. A non-trivial color example - ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y # noqa: E731 - vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y # noqa: E731 - xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( # noqa: E731 - (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx - ) - yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( # noqa: E731 - (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx - ) - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) - do_nonlocal_wcs( - wcs, - lambda x, y: ufunc(x, y, -0.3), - lambda x, y: vfunc(x, y, -0.3), - "UVFunction with color-dependence", - test_pickle=False, - color=-0.3, - ) - - # Also, check this one as a string - wcs = galsim.UVFunction( - ufunc="(%r+0.1*c)*x + %r*y" % (dudx, dudy), - vfunc="%r*x + (%r-0.2*c)*y" % (dvdx, dvdy), - xfunc="((%r-0.2*c)*u - %r*v)/((%r+0.1*c)*(%r-0.2*c)-%r)" - % (dvdy, dudy, dudx, dvdy, dudy * dvdx), - yfunc="(-%r*u + (%r+0.1*c)*v)/((%r+0.1*c)*(%r-0.2*c)-%r)" - % (dvdx, dudx, dudx, dvdy, dudy * dvdx), - uses_color=True, - ) - do_nonlocal_wcs( - wcs, - lambda x, y: ufunc(x, y, 1.7), - lambda x, y: vfunc(x, y, 1.7), - "UVFunction with color-dependence, string", - test_pickle=True, - color=1.7, - ) - - # 9. A non-trivial color example that fails for arrays - ufunc = lambda x, y, c: math.exp(c * x) # noqa: E731 - vfunc = lambda x, y, c: math.exp(c * y / 2) # noqa: E731 - xfunc = lambda u, v, c: math.log(u) / c # noqa: E731 - yfunc = lambda u, v, c: math.log(v) * 2 / c # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) - do_nonlocal_wcs( - wcs, - lambda x, y: ufunc(x, y, 0.01), - lambda x, y: vfunc(x, y, 0.01), - "UVFunction with math and color-dependence", - test_pickle=False, - color=0.01, - ) - - # 10. One with invalid functions, which raise errors. (Just for coverage really.) - ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) # noqa: E731 - vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) # noqa: E731 - xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 - yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6)) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2)) - assert_raises(ValueError, wcs.toImage, galsim.PositionD(3, 3)) - assert_raises(ValueError, wcs.toImage, galsim.PositionD(6, 0)) - # Repeat with color - ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) # noqa: E731 - vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) # noqa: E731 - xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 - yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) # noqa: E731 - wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6), color=0.2) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2), color=0.2) - assert_raises(ValueError, wcs.toImage, galsim.PositionD(3, 3), color=0.2) - assert_raises(ValueError, wcs.toImage, galsim.PositionD(6, 0), color=0.2) - - -@timer -def test_radecfunction(): - """Test the RaDecFunction class""" - # Do a sterographic projection of the above UV functions around a given reference point. - funcs = [] - - scale = 0.17 - ufunc = lambda x, y: x * scale # noqa: E731 - vfunc = lambda x, y: y * scale # noqa: E731 - funcs.append((ufunc, vfunc, "like PixelScale")) - - scale = 0.23 - g1 = 0.14 - g2 = -0.37 - factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 - funcs.append((ufunc, vfunc, "like ShearWCS")) - - dudx = 0.2342 - dudy = 0.1432 - dvdx = 0.1409 - dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 - vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 - funcs.append((ufunc, vfunc, "like JacobianWCS")) - - x0 = 1.3 - y0 = -0.9 - u0 = 124.3 - v0 = -141.9 - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 - funcs.append((ufunc, vfunc, "like AffineTransform")) - - funcs.append((radial_u, radial_v, "Cubic radial")) - - ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 - vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 - funcs.append((ufunc, vfunc, "offset Cubic radial")) - - cubic_u = Cubic(2.9e-5, 2000.0, "u") - cubic_v = Cubic(-3.7e-5, 2000.0, "v") - - ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 - vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 - funcs.append((ufunc, vfunc, "offset Cubic object")) - - # The last one needs to not have a lambda, since we use it for the image test, which - # requires the function to be picklable. - funcs.append((cubic_u, cubic_v, "Cubic object")) - - # ra, dec of projection centers (in degrees) - # Note: This won't work directly at a pole, but it should work rather close to the pole. - # Using dec = 89.5 (in the last one) failed some of our tests, but 89.1 succeeds. - centers = [(0.0, 0.0), (34.0, 12.0), (-190.4, -79.8), (234.56, 89.1)] - - # We need this below. - north_pole = galsim.CelestialCoord(0 * galsim.degrees, 90 * galsim.degrees) - - for ufunc, vfunc, name in funcs: - u0 = ufunc(0.0, 0.0) - v0 = vfunc(0.0, 0.0) - wcs1 = galsim.UVFunction(ufunc, vfunc) - for cenra, cendec in centers: - center = galsim.CelestialCoord( - cenra * galsim.degrees, cendec * galsim.degrees - ) - - scale = galsim.arcsec / galsim.radians - radec_func = lambda x, y: center.deproject_rad( # noqa: E731 - ufunc(x, y) * scale, vfunc(x, y) * scale, projection="lambert" - ) - wcs2 = galsim.RaDecFunction(radec_func) - assert not wcs2.isPixelScale() - assert not wcs2.isLocal() - assert not wcs2.isUniform() - assert wcs2.isCelestial() - - # Also test with one that doesn't work with numpy arrays to test that the - # code does the right thing in that case too, since local and makeSkyImage - # try the numpy option first and do something else if it fails. - # This also tests the alternate initialization using separate ra_func, dec_fun. - ra_func = lambda x, y: center.deproject( # noqa: E731 - ufunc(x, y) * galsim.arcsec, - vfunc(x, y) * galsim.arcsec, - projection="lambert", - ).ra.rad - dec_func = lambda x, y: center.deproject( # noqa: E731 - ufunc(x, y) * galsim.arcsec, - vfunc(x, y) * galsim.arcsec, - projection="lambert", - ).dec.rad - wcs3 = galsim.RaDecFunction(ra_func, dec_func) - - # The pickle tests need to have a string for ra_func, dec_func, which is - # a bit tough with the ufunc,vfunc stuff. So do something simpler for that. - radec_str = '%r.deproject_rad(x*%f,y*%f,projection="lambert")' % ( - center, - scale, - scale, - ) - wcs4 = galsim.RaDecFunction(radec_str, origin=galsim.PositionD(17.0, 34.0)) - ra_str = ( - '%r.deproject(x*galsim.arcsec,y*galsim.arcsec,projection="lambert").ra.rad' - % center - ) - dec_str = ( - '%r.deproject(x*galsim.arcsec,y*galsim.arcsec,projection="lambert").dec.rad' - % center - ) - wcs5 = galsim.RaDecFunction( - ra_str, dec_str, origin=galsim.PositionD(-9.0, -8.0) - ) - - wcs6 = wcs2.copy() - assert wcs2 == wcs6, "RaDecFunction copy is not == the original" - assert wcs6.radec_func(3, 4) == radec_func(3, 4) - - # Check that distance, jacobian for some x,y positions match the UV values. - for x, y in zip(far_x_list, far_y_list): - # First do some basic checks of project, deproject for the given (u,v) - u = ufunc(x, y) - v = vfunc(x, y) - coord = center.deproject( - u * galsim.arcsec, v * galsim.arcsec, projection="lambert" - ) - ra, dec = radec_func(x, y) - np.testing.assert_almost_equal( - ra, coord.ra.rad, 8, "rafunc produced wrong value" - ) - np.testing.assert_almost_equal( - dec, coord.dec.rad, 8, "decfunc produced wrong value" - ) - pos = center.project(coord, projection="lambert") - np.testing.assert_almost_equal( - pos[0] / galsim.arcsec, u, digits, "project x was inconsistent" - ) - np.testing.assert_almost_equal( - pos[1] / galsim.arcsec, v, digits, "project y was inconsistent" - ) - d1 = np.sqrt(u * u + v * v) - d2 = center.distanceTo(coord) - # The distances aren't expected to match. Instead, for a Lambert projection, - # d1 should match the straight line distance through the sphere. - import math - - d2 = 2.0 * np.sin(d2 / 2.0) * galsim.radians / galsim.arcsec - np.testing.assert_almost_equal( - d2, d1, digits, "deprojected dist does not match expected value." - ) - - # Now test the two initializations of RaDecFunction. - for test_wcs in [wcs2, wcs3]: - image_pos = galsim.PositionD(x, y) - world_pos1 = wcs1.toWorld(image_pos) - world_pos2 = test_wcs.toWorld(image_pos) - d3 = np.sqrt(world_pos1.x**2 + world_pos1.y**2) - d4 = center.distanceTo(world_pos2) - d4 = 2.0 * np.sin(d4 / 2) * galsim.radians / galsim.arcsec - np.testing.assert_almost_equal( - d3, - d1, - digits, - "UV " + name + " dist does not match expected value.", - ) - np.testing.assert_almost_equal( - d4, - d1, - digits, - "RaDec " + name + " dist does not match expected value.", - ) - - # Calculate the Jacobians for each wcs - jac1 = wcs1.jacobian(image_pos) - jac2 = test_wcs.jacobian(image_pos) - - # The pixel area should match pretty much exactly. The Lambert projection - # is an area preserving projection. - np.testing.assert_almost_equal( - jac2.pixelArea(), - jac1.pixelArea(), - digits, - "RaDecFunction " - + name - + " pixelArea() does not match expected value.", - ) - np.testing.assert_almost_equal( - test_wcs.pixelArea(image_pos), - jac1.pixelArea(), - digits, - "RaDecFunction " - + name - + " pixelArea(pos) does not match expected value.", - ) - - # The distortion should be pretty small, so the min/max linear scale should - # match pretty well. - np.testing.assert_almost_equal( - jac2.minLinearScale(), - jac1.minLinearScale(), - digits, - "RaDecFunction " - + name - + " minLinearScale() does not match expected value.", - ) - np.testing.assert_almost_equal( - test_wcs.minLinearScale(image_pos), - jac1.minLinearScale(), - digits, - "RaDecFunction " - + name - + " minLinearScale(pos) does not match expected value.", - ) - np.testing.assert_almost_equal( - jac2.maxLinearScale(), - jac1.maxLinearScale(), - digits, - "RaDecFunction " - + name - + " maxLinearScale() does not match expected value.", - ) - np.testing.assert_almost_equal( - test_wcs.maxLinearScale(image_pos), - jac1.maxLinearScale(), - digits, - "RaDecFunction " - + name - + " maxLinearScale(pos) does not match expected value.", - ) - - # The main discrepancy between the jacobians is a rotation term. - # The pixels in the projected coordinates do not necessarily point north, - # since the direction to north changes over the field. However, we can - # calculate this expected discrepancy and correct for it to get a comparison - # of the full jacobian that should be accurate to 5 digits. - # If A = coord, B = center, and C = the north pole, then the rotation angle is - # 180 deg - A - B. - A = coord.angleBetween(center, north_pole) - B = center.angleBetween(north_pole, coord) - C = north_pole.angleBetween(coord, center) - # The angle C should equal coord.ra - center.ra, so use this as a unit test of - # the angleBetween function: - np.testing.assert_almost_equal( - C / galsim.degrees, - (coord.ra - center.ra) / galsim.degrees, - digits, - "CelestialCoord calculated the wrong angle between center and coord", - ) - angle = 180 * galsim.degrees - A - B - - # Now we can use this angle to correct the jacobian from test_wcs. - s, c = angle.sincos() - rot_dudx = c * jac2.dudx + s * jac2.dvdx - rot_dudy = c * jac2.dudy + s * jac2.dvdy - rot_dvdx = -s * jac2.dudx + c * jac2.dvdx - rot_dvdy = -s * jac2.dudy + c * jac2.dvdy - - np.testing.assert_almost_equal( - rot_dudx, - jac1.dudx, - digits, - "RaDecFunction " - + name - + " dudx (rotated) does not match expected value.", - ) - np.testing.assert_almost_equal( - rot_dudy, - jac1.dudy, - digits, - "RaDecFunction " - + name - + " dudy (rotated) does not match expected value.", - ) - np.testing.assert_almost_equal( - rot_dvdx, - jac1.dvdx, - digits, - "RaDecFunction " - + name - + " dvdx (rotated) does not match expected value.", - ) - np.testing.assert_almost_equal( - rot_dvdy, - jac1.dvdy, - digits, - "RaDecFunction " - + name - + " dvdy (rotated) does not match expected value.", - ) - - if abs(center.dec / galsim.degrees) < 45: - # The projections far to the north or the south don't pass all the tests in - # do_celestial because of the high non-linearities in the projection, so just - # skip them. - do_celestial_wcs( - wcs2, - "RaDecFunc 1 centered at " - + str(center.ra / galsim.degrees) - + ", " - + str(center.dec / galsim.degrees), - test_pickle=False, - ) - do_celestial_wcs( - wcs3, - "RaDecFunc 2 centered at " - + str(center.ra / galsim.degrees) - + ", " - + str(center.dec / galsim.degrees), - test_pickle=False, - ) - - do_celestial_wcs( - wcs4, - "RaDecFunc 3 centered at " - + str(center.ra / galsim.degrees) - + ", " - + str(center.dec / galsim.degrees), - test_pickle=True, - ) - do_celestial_wcs( - wcs5, - "RaDecFunc 4 centered at " - + str(center.ra / galsim.degrees) - + ", " - + str(center.dec / galsim.degrees), - test_pickle=True, - ) - - assert_raises(TypeError, galsim.RaDecFunction) - assert_raises(TypeError, galsim.RaDecFunction, radec_func, origin=5) - assert_raises( - TypeError, - galsim.RaDecFunction, - radec_func, - origin=galsim.CelestialCoord(0 * galsim.degrees, 0 * galsim.degrees), - ) - assert_raises( - TypeError, galsim.RaDecFunction, radec_func, world_origin=galsim.PositionD(0, 0) - ) - - # Check that using a wcs in the context of an image works correctly - # (Uses the last wcs2, wcs3 set in the above loops.) - do_wcs_image(wcs2, "RaDecFunction") - do_wcs_image(wcs3, "RaDecFunction") - - # One with invalid functions, which raise errors. (Just for coverage really.) - radec_func = lambda x, y: center.deproject_rad( # noqa: E731 - math.sqrt(x), math.sqrt(y), projection="lambert" - ) - wcs = galsim.RaDecFunction(radec_func) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(-5, 6)) - assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, -2)) - - -def do_ref(wcs, ref_list, name, approx=False, image=None): - # Test that the given wcs object correctly converts the reference positions - - # Normally, we check the agreement to 1.e-3 arcsec. - # However, we allow the caller to indicate the that inverse transform is only approximate. - # In this case, we only check to 1 digit. Originally, this was just for the reverse - # transformation from world back to image coordinates, since some of the transformations - # are not analytic, so some routines don't iterate to a very high accuracy. But older - # versions of wcstools are slightly (~0.01 arcsec) inaccurate even for the forward - # transformation for TNX and ZPX. So now we use digits2 for both toWorld and toImage checks. - if approx: - digits2 = 1 - else: - digits2 = digits - - print("Start reference testing for " + name) - for ref in ref_list: - ra = galsim.Angle.from_hms(ref[0]) - dec = galsim.Angle.from_dms(ref[1]) - x = ref[2] - y = ref[3] - val = ref[4] - - # Check image -> world - ref_coord = galsim.CelestialCoord(ra, dec) - coord = wcs.toWorld(galsim.PositionD(x, y)) - dist = ref_coord.distanceTo(coord) / galsim.arcsec - print("x, y = ", x, y) - print("ref_coord = ", ref_coord.ra.hms(), ref_coord.dec.dms()) - print("coord = ", coord.ra.hms(), coord.dec.dms()) - np.testing.assert_almost_equal( - dist, 0, digits2, "wcs.toWorld differed from expected value" - ) - - # Check world -> image - pixel_scale = wcs.minLinearScale(galsim.PositionD(x, y)) - pos = wcs.toImage(galsim.CelestialCoord(ra, dec)) - np.testing.assert_almost_equal( - (x - pos.x) * pixel_scale, - 0, - digits2, - "wcs.toImage differed from expected value", - ) - np.testing.assert_almost_equal( - (y - pos.y) * pixel_scale, - 0, - digits2, - "wcs.toImage differed from expected value", - ) - if image: - np.testing.assert_almost_equal( - image(x, y), val, digits, "image(x,y) differed from reference value" - ) - - -@timer -def test_astropywcs(): - """Test the AstropyWCS class""" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs # noqa: F401 - import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. - - # These all work, but it is quite slow, so only test a few of them for the regular unit tests. - # (1.8 seconds for 4 tags.) - # Test all of them when running python test_wcs.py. - if __name__ == "__main__": - test_tags = [ - "HPX", - "TAN", - "TSC", - "STG", - "ZEA", - "ARC", - "ZPN", - "SIP", - "TPV", - "TAN-PV", - "TAN-FLIP", - "REGION", - "ZTF", - ] - # The ones that still don't work are TNX and ZPX. - # In both cases, astropy thinks it reads them successfully, but fails the tests and/or - # bombs out with a malloc error: - # incorrect checksum for freed object - object was probably modified after being freed - else: - test_tags = ["TAN", "SIP", "ZTF", "TAN-PV"] - - dir = "fits_files" - for tag in test_tags: - file_name, ref_list = references[tag] - print(tag, " file_name = ", file_name) - if tag == "TAN": - wcs = galsim.AstropyWCS(file_name, dir=dir, compression="none", hdu=0) - else: - wcs = galsim.AstropyWCS(file_name, dir=dir) - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert not wcs.isUniform() - assert wcs.isCelestial() - - do_ref(wcs, ref_list, "AstropyWCS " + tag) - - if tag == "TAN": - # Also check origin. (Now that reference checks are done.) - wcs = galsim.AstropyWCS( - file_name, - dir=dir, - compression="none", - hdu=0, - origin=galsim.PositionD(3, 4), - ) - - do_celestial_wcs(wcs, "Astropy file " + file_name) - - do_wcs_image(wcs, "AstropyWCS_" + tag) - - # Can also use an existing astropy.wcs.WCS instance to construct. - # This is probably a rare use case, but could aid efficiency if you already build the - # astropy WCS for other purposes. - astropy_wcs = wcs.wcs # Just steal such an object from the last wcs above. - assert isinstance(astropy_wcs, astropy.wcs.WCS) - wcs1 = galsim.AstropyWCS(wcs=astropy_wcs) - do_celestial_wcs(wcs1, "AstropyWCS from wcs", test_pickle=False) - repr(wcs1) - - # Can also use a header. Again steal it from the wcs above. - wcs2 = galsim.AstropyWCS(header=wcs.header) - do_celestial_wcs(wcs2, "AstropyWCS from header", test_pickle=True) - - # Doesn't support LINEAR WCS types. - with assert_raises(galsim.GalSimError): - galsim.AstropyWCS("SBProfile_comparison_images/kolmogorov.fits") - - # This file does not have any WCS information in it. - with assert_raises(galsim.GalSimError): - galsim.AstropyWCS("fits_files/blankimg.fits") - - assert_raises(TypeError, galsim.AstropyWCS) - assert_raises(TypeError, galsim.AstropyWCS, file_name, header="dummy") - assert_raises(TypeError, galsim.AstropyWCS, file_name, wcs=wcs) - assert_raises(TypeError, galsim.AstropyWCS, wcs=wcs, header="dummy") - - # Astropy thinks it can handle ZPX files, but as of version 2.0.4, they don't work right. - # It reads it in ok, and even works with it fine. But it doesn't round trip through - # its own write and read. Even worse, it natively gives a fairly obscure error, which - # we convert into an OSError by hand. - # This test will let us know when they finally fix it. If it fails, we can remove this - # test and add 'ZPX' to the list of working astropy.wcs types above. - with assert_raises(OSError): - wcs = galsim.AstropyWCS(references["ZPX"][0], dir=dir) - do_wcs_image(wcs, "AstropyWCS_ZPX") - - -@timer -def test_pyastwcs(): - """Test the PyAstWCS class""" - try: - import starlink.Ast - except ImportError: - print("Unable to import starlink.Ast. Skipping PyAstWCS tests.") - return - - # These all work, but it is quite slow, so only test a few of them for the regular unit tests. - # (2.4 seconds for 6 tags.) - # Test all of them when running python test_wcs.py. - if __name__ == "__main__": - test_tags = [ - "HPX", - "TAN", - "TSC", - "STG", - "ZEA", - "ARC", - "ZPN", - "SIP", - "TPV", - "ZPX", - "TAN-PV", - "TAN-FLIP", - "REGION", - "TNX", - "ZTF", - ] - else: - test_tags = ["TAN", "ZPX", "SIP", "TAN-PV", "TNX", "ZTF"] - - dir = "fits_files" - for tag in test_tags: - file_name, ref_list = references[tag] - print(tag, " file_name = ", file_name) - if tag == "TAN": - wcs = galsim.PyAstWCS(file_name, dir=dir, compression="none", hdu=0) - else: - wcs = galsim.PyAstWCS(file_name, dir=dir) - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert not wcs.isUniform() - assert wcs.isCelestial() - - # The PyAst implementation of the SIP type only gets the inverse transformation - # approximately correct. So we need to be a bit looser in that check. - approx = tag in ["SIP"] - do_ref(wcs, ref_list, "PyAstWCS " + tag, approx) - - if tag == "TAN": - # Also check origin. (Now that reference checks are done.) - wcs = galsim.PyAstWCS( - file_name, - dir=dir, - compression="none", - hdu=0, - origin=galsim.PositionD(3, 4), - ) - - do_celestial_wcs(wcs, "PyAst file " + file_name, approx=approx) - - # TAN-FLIP has an error of 4mas after write and read here, which I don't really understand. - # but it's small enough an error that I don't think it's worth worrying about further. - approx = tag in ["ZPX", "TAN-FLIP", "ZTF"] - do_wcs_image(wcs, "PyAstWCS_" + tag, approx) - - # Can also use an existing startlink.Ast.FrameSet instance to construct. - # This is probably a rare use case, but could aid efficiency if you already open the - # fits file with starlink for other purposes. - wcs = galsim.PyAstWCS(references["TAN"][0], dir=dir) - wcsinfo = wcs.wcsinfo - assert isinstance(wcsinfo, starlink.Ast.FrameSet) - wcs1 = galsim.PyAstWCS(wcsinfo=wcsinfo) - do_celestial_wcs(wcs1, "PyAstWCS from wcsinfo", test_pickle=False) - repr(wcs1) - - # Can also use a header. Again steal it from the wcs above. - wcs2 = galsim.PyAstWCS(header=wcs.header) - do_celestial_wcs(wcs2, "PyAstWCS from header", test_pickle=True) - - # Doesn't support LINEAR WCS types. - with assert_raises(galsim.GalSimError): - galsim.PyAstWCS("SBProfile_comparison_images/kolmogorov.fits") - - # This file does not have any WCS information in it. - with assert_raises(OSError): - galsim.PyAstWCS("fits_files/blankimg.fits") - - assert_raises(TypeError, galsim.PyAstWCS) - assert_raises(TypeError, galsim.PyAstWCS, file_name, header="dummy") - assert_raises(TypeError, galsim.PyAstWCS, file_name, wcsinfo=wcsinfo) - assert_raises(TypeError, galsim.PyAstWCS, wcsinfo=wcsinfo, header="dummy") - - -@timer -def test_wcstools(): - """Test the WcsToolsWCS class""" - # These all work, but it is very slow, so only test one of them for the regular unit tests. - # (1.5 seconds for just the one tag.) - # Test all of them when running python test_wcs.py. - if __name__ == "__main__": - # Note: TPV seems to work, but on one machine, repeated calls to xy2sky with the same - # x,y values vary between two distinct ra,dec outputs. I have no idea what's going on, - # since I thought the calculation ought to be deterministic, but it clearly something - # isn't working right. So just skip that test. - test_tags = [ - "TAN", - "TSC", - "STG", - "ZEA", - "ARC", - "ZPN", - "SIP", - "ZPX", - "TAN-FLIP", - "REGION", - "TNX", - ] - else: - test_tags = ["TNX"] - - dir = "fits_files" - try: - galsim.WcsToolsWCS(references["TAN"][0], dir=dir) - except OSError: - print("Unable to execute xy2sky. Skipping WcsToolsWCS tests.") - return - - for tag in test_tags: - file_name, ref_list = references[tag] - print(tag, " file_name = ", file_name) - wcs = galsim.WcsToolsWCS(file_name, dir=dir) - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert not wcs.isUniform() - assert wcs.isCelestial() - - # The wcstools implementation of the SIP and TPV types only gets the inverse - # transformations approximately correct. So we need to be a bit looser in those checks. - approx = tag in ["SIP", "TPV", "ZPX", "TNX"] - do_ref(wcs, ref_list, "WcsToolsWCS " + tag, approx) - - # Recenter (x,y) = (0,0) at the image center to avoid wcstools warnings about going - # off the image. - im = galsim.fits.read(file_name, dir=dir) - wcs = wcs.shiftOrigin(origin=-im.center) - - do_celestial_wcs(wcs, "WcsToolsWCS " + file_name, approx=approx) - - do_wcs_image(wcs, "WcsToolsWCS_" + tag) - - # HPX is one of the ones that WcsToolsWCS doesn't support. - with assert_raises(galsim.GalSimError): - galsim.WcsToolsWCS(references["HPX"][0], dir=dir) - - # This file does not have any WCS information in it. - with assert_raises(OSError): - galsim.WcsToolsWCS("fits_files/blankimg.fits") - - # Doesn't support LINEAR WCS types. - with assert_raises(galsim.GalSimError): - galsim.WcsToolsWCS("SBProfile_comparison_images/kolmogorov.fits") - - assert_raises(TypeError, galsim.WcsToolsWCS) - assert_raises(TypeError, galsim.WcsToolsWCS, file_name, header="dummy") - - -@timer -def test_gsfitswcs(): - """Test the GSFitsWCS class""" - # These are all relatively fast (total time for all 10 is about 1.1 seconds), - # so unlike some of the other WCS types, it's not that slow to do all of them. - # And it's required to get (relatively) complete test coverage. - test_tags = [ - "TAN", - "STG", - "ZEA", - "ARC", - "TPV", - "TAN-PV", - "TAN-FLIP", - "TNX", - "SIP", - "ZTF", - ] - - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" - ) - - for tag in test_tags: - file_name, ref_list = references[tag] - print(tag, " file_name = ", file_name) - if tag == "TAN": - # For this one, check compression and hdu options. - wcs = galsim.GSFitsWCS(file_name, dir=dir, compression="none", hdu=0) - else: - wcs = galsim.GSFitsWCS(file_name, dir=dir) - assert not wcs.isPixelScale() - assert not wcs.isLocal() - assert not wcs.isUniform() - assert wcs.isCelestial() - - do_ref(wcs, ref_list, "GSFitsWCS " + tag) - - if tag == "TAN": - # Also check origin. (Now that reference checks are done.) - wcs = galsim.GSFitsWCS( - file_name, - dir=dir, - compression="none", - hdu=0, - origin=galsim.PositionD(3, 4), - ) - - do_celestial_wcs(wcs, "GSFitsWCS " + file_name) - - do_wcs_image(wcs, "GSFitsWCS_" + tag) - - # tpv_odd.fits is a modified version to have (unsupported) odd powers of r. - dr = os.path.join( - os.path.dirname(__file__), - "..", - "..", - "GalSim", - "tests", - ) - with assert_raises(galsim.GalSimNotImplementedError): - galsim.GSFitsWCS(dr + "/fits_files/tpv_odd.fits", dir=dir) - - # TSC is one of the ones that GSFitsWCS doesn't support. - with assert_raises(galsim.GalSimValueError): - galsim.GSFitsWCS(references["TSC"][0], dir=dir) - - # Doesn't support LINEAR WCS types. - with assert_raises(galsim.GalSimError): - galsim.GSFitsWCS(dr + "/SBProfile_comparison_images/kolmogorov.fits") - - # This file does not have any WCS information in it. - with assert_raises(galsim.GalSimError): - galsim.GSFitsWCS(dr + "/fits_files/blankimg.fits") - - assert_raises(TypeError, galsim.GSFitsWCS) - assert_raises(TypeError, galsim.GSFitsWCS, file_name, header="dummy") - assert_raises(ValueError, galsim.GSFitsWCS, _data=("invalid", 0, 0, 0, 0, 0, 0)) - - -@timer -def test_inverseab_convergence(): - """Test a SIP file that used to fail to converge. - cf. #1105, fixed in #1106. - """ - wcs = galsim.GSFitsWCS( - _data=[ - "TAN-SIP", - np.array([6.643813544313604e-05, 2.167251511173191e-06]), - np.array( - [ - [-0.00019733136932200036, 0.00037674127418905063], - [0.0003767412741890354, 0.00019733136932198898], - ] - ), - galsim.CelestialCoord( - galsim.Angle(2.171481673601117, galsim.radians), - galsim.Angle(-0.47508762601580773, galsim.radians), - ), - None, - np.array( - [ - [ - [0.0, 0.0, 1.7492391970878894e-11, 5.689175560623694e-12], - [1.0, -2.960663840648593e-14, -6.305126324859291e-09, 0.0], - [1.6656211233894224e-11, -1.7067526686225035e-11, 0.0, 0.0], - [-6.296581950910473e-09, 0.0, 0.0, 0.0], - ], - [ - [0.0, 1.0, 5.444700521713077e-13, -6.296581950917086e-09], - [0.0, -8.377345670352661e-13, 1.7067526700203848e-11, 0.0], - [5.69481360769684e-13, -6.305126324862401e-09, 0.0, 0.0], - [-5.6891755511980266e-12, 0.0, 0.0, 0.0], - ], - ] - ), - np.array( - [ - [ - [ - 0.0, - 2.0757753968394382e-12, - -1.901883278189908e-11, - 1.8853670343455113e-11, - ], - [ - 0.9988286261206877, - 3.8197440624200676e-17, - 7.132351432025443e-09, - 0.0, - ], - [-1.9018950513878596e-11, -5.6561012619600154e-11, 0.0, 0.0], - [7.160667102377515e-09, 0.0, 0.0, 0.0], - ], - [ - [ - 0.0, - 0.998828626120979, - -6.203305031319767e-13, - 7.1606671023315485e-09, - ], - [ - -1.442341346591337e-12, - -7.782798728336095e-17, - 5.6561012640257335e-11, - 0.0, - ], - [-6.20409043536687e-13, 7.132351431975042e-09, 0.0, 0.0], - [-1.885367062073805e-11, 0.0, 0.0, 0.0], - ], - ] - ), - ] - ) - ra = 2.1859428247518253 - dec = -0.5001963313433293 - x, y = wcs.radecToxy(ra, dec, units="radians") - ra1, dec1 = wcs.xyToradec(x, y, units="radians") - assert np.isclose(ra1, ra) - assert np.isclose(dec1, dec) - - ra = 2.186179858413897 - dec = -0.49995036220334654 - x, y = wcs.radecToxy(ra, dec, units="radians") - ra1, dec1 = wcs.xyToradec(x, y, units="radians") - assert np.isclose(ra1, ra) - assert np.isclose(dec1, dec) - - # Now one that should fail, since it's well outside the applicable area for the SIP polynomials. - ra = 2.1 - dec = -0.45 - x, y = wcs.radecToxy(ra, dec, units="radians") - assert np.all(np.isnan(x)) - assert np.all(np.isnan(y)) - - # Check as part of a longer list (longer than 256 is important) - rng = np.random.RandomState(1234) - ra = rng.uniform(2.185, 2.186, 1000) - dec = rng.uniform(-0.501, -0.499, 1000) - ra = np.append(ra, [2.1, 2.9]) - dec = np.append(dec, [-0.45, 0.2]) - print("ra = ", ra) - print("dec = ", dec) - x, y = wcs.radecToxy(ra, dec, units="radians") - assert np.sum(np.isnan(x)) >= 2 - assert np.sum(np.isnan(y)) >= 2 - - -@timer -def test_tanwcs(): - """Test the TanWCS function, which returns a GSFitsWCS instance.""" - - # Use TanWCS function to create TAN GSFitsWCS objects from scratch. - # First a slight tweak on a simple scale factor - dudx = 0.2342 - dudy = 0.0023 - dvdx = 0.0019 - dvdy = 0.2391 - x0 = 1 - y0 = 1 - origin = galsim.PositionD(x0, y0) - affine = galsim.AffineTransform(dudx, dudy, dvdx, dvdy, origin) - center = galsim.CelestialCoord(0.0 * galsim.radians, 0.0 * galsim.radians) - wcs = galsim.TanWCS(affine, center) - do_celestial_wcs(wcs, "TanWCS 1") - do_wcs_image(wcs, "TanWCS 1") - - # Next one with a flip and significant rotation and a large (u,v) offset - dudx = 0.1432 - dudy = 0.2342 - dvdx = 0.2391 - dvdy = 0.1409 - u0 = 124.3 - v0 = -141.9 - wcs = galsim.AffineTransform( - dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) - ) - center = galsim.CelestialCoord(3.4 * galsim.hours, -17.9 * galsim.degrees) - wcs = galsim.TanWCS(affine, center) - do_celestial_wcs(wcs, "TanWCS 2") - do_wcs_image(wcs, "TanWCS 2") - - # Check crossing ra=0. - # Note: this worked properly even before fixing issue #1030, since GSFitsWCS keeps all - # the ra values close to the value at the center of the image, so the ra values here - # span from below 360 to above 360 without wrapping to 0. - # cf. test_razero below for a test with a wcs that does wrap back to 0. - dudx = 1.4 - dudy = 0.2 - dvdx = -0.3 - dvdy = 1.5 - u0 = 0.0 - v0 = 0.0 - wcs = galsim.AffineTransform( - dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) - ) - center = galsim.CelestialCoord(359.99 * galsim.degrees, -37.9 * galsim.degrees) - wcs = galsim.TanWCS(affine, center) - do_celestial_wcs(wcs, "TanWCS 3") - do_wcs_image(wcs, "TanWCS 3") - - # Finally a really crazy one that isn't remotely regular - dudx = 0.2342 - dudy = -0.1432 - dvdx = 0.0924 - dvdy = -0.3013 - x0 = -3 - y0 = 104 - u0 = 1423.9 - v0 = 8242.7 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - wcs = galsim.AffineTransform( - dudx, dudy, dvdx, dvdy, origin=origin, world_origin=world_origin - ) - center = galsim.CelestialCoord(-241.4 * galsim.hours, 87.9 * galsim.degrees) - wcs = galsim.TanWCS(affine, center) - do_celestial_wcs(wcs, "TanWCS 4") - do_wcs_image(wcs, "TanWCS 4") - - -@timer -def test_fitswcs(): - """Test the FitsWCS factory function""" - if __name__ == "__main__": - # For more thorough unit tests (when running python test_wcs.py explicitly), this - # will test everything. If you don't have everything installed (especially - # PyAst, then this may fail. - test_tags = all_tags - else: - # These should always work, since GSFitsWCS will work on them. So this - # mostly just tests the basic interface of the FitsWCS function. - test_tags = ["TAN", "TPV"] - try: - import starlink.Ast # noqa: F401 - - # Useful also to test one that GSFitsWCS doesn't work on. This works on Travis at - # least, and helps to cover some of the FitsWCS functionality where the first try - # isn't successful. - test_tags.append("HPX") - except Exception: - pass - - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" - ) - - for tag in test_tags: - file_name, ref_list = references[tag] - print(tag, " file_name = ", file_name) - if tag == "TAN": - wcs = galsim.FitsWCS(file_name, dir=dir, compression="none", hdu=0) - else: - wcs = galsim.FitsWCS(file_name, dir=dir, suppress_warning=True) - print("FitsWCS is really ", type(wcs)) - - if isinstance(wcs, galsim.AffineTransform): - import warnings - - warnings.warn( - "None of the existing WCS classes were able to read " + file_name - ) - else: - approx = tag in ["ZPX", "ZTF"] and isinstance(wcs, galsim.PyAstWCS) - do_ref(wcs, ref_list, "FitsWCS " + tag) - do_celestial_wcs(wcs, "FitsWCS " + file_name) - do_wcs_image(wcs, "FitsWCS_" + tag, approx) - - # Should also be able to build the file just from a fits.read() call, which - # uses FitsWCS behind the scenes. - im = galsim.fits.read(file_name, dir=dir) - do_ref(im.wcs, ref_list, "WCS from fits.read " + tag, im) - - # Finally, also check that AffineTransform can read the file. - # We don't really have any accuracy checks here. This really just checks that the - # read function doesn't raise an exception. - hdu, hdu_list, fin = galsim.fits.readFile(file_name, dir) - galsim.AffineTransform._readHeader(hdu.header) - galsim.fits.closeHDUList(hdu_list, fin) - - # This does support LINEAR WCS types. - dr = os.path.join( - os.path.dirname(__file__), - "..", - "..", - "GalSim", - "tests", - ) - linear = galsim.FitsWCS(dr + "/" + "SBProfile_comparison_images/kolmogorov.fits") - assert isinstance(linear, galsim.OffsetWCS) - - # This file does not have any WCS information in it. - with assert_warns(galsim.GalSimWarning): - pixel = galsim.FitsWCS(dr + "/" + "fits_files/blankimg.fits") - assert pixel == galsim.PixelScale(1.0) - - # Can suppress the warning if desired - pixel = galsim.FitsWCS(dr + "/" + "fits_files/blankimg.fits", suppress_warning=True) - assert pixel == galsim.PixelScale(1.0) - - assert_raises(TypeError, galsim.FitsWCS) - assert_raises(TypeError, galsim.FitsWCS, file_name, header="dummy") - - # If some format can't be handled by one of the installed modules, - # then FitsWCS can end up at the AffineTransform. - # The easiest way to mock this up is to adjust the fits_wcs_types list. - if sys.version_info < (3,): - return # mock only available on python 3 - from unittest import mock - - with mock.patch( - "galsim.fitswcs.fits_wcs_types", - [ - galsim.GSFitsWCS, - ], - ): - file_name, ref_list = references["ZPX"] - with assert_warns(galsim.GalSimWarning): - wcs = galsim.FitsWCS(file_name, dir=dir) - print("wcs = ", wcs) - assert isinstance(wcs, galsim.AffineTransform) - - # Can suppress the warning if desired - wcs = galsim.FitsWCS(file_name, dir=dir, suppress_warning=True) - assert isinstance(wcs, galsim.AffineTransform) - - -def check_sphere(ra1, dec1, ra2, dec2, atol=1): - # Vectorizing CelestialCoord.distanceTo() - # ra/dec in rad - # atol in arcsec - x1 = np.cos(dec1) * np.cos(ra1) - y1 = np.cos(dec1) * np.sin(ra1) - z1 = np.sin(dec1) - x2 = np.cos(dec2) * np.cos(ra2) - y2 = np.cos(dec2) * np.sin(ra2) - z2 = np.sin(dec2) - dsq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 - dist = 2 * np.arcsin(0.5 * np.sqrt(dsq)) - w = dsq >= 3.99 - if np.any(w): - cross = np.cross(np.array([x1, y1, z1])[w], np.array([x2, y2, z2])[w]) - crossq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 - dist[w] = np.pi - np.arcsin(np.sqrt(crossq)) - dist = np.rad2deg(dist) * 3600 - np.testing.assert_allclose(dist, 0.0, rtol=0.0, atol=atol) - - -@timer -def test_fittedsipwcs(): - """Test our ability to construct a WCS from x, y, ra, dec tuples""" - import astropy.io.fits as fits - - tol = { # (arcsec, pixels) - "HPX": (5.0, 0.02), - "TAN": (5.0, 0.02), - "TSC": (6.0, 0.02), - "STG": (7.0, 0.03), - "ZEA": (10.0, 0.1), - "ARC": (5.0, 0.02), - "ZPN": (3000.0, 6), - "SIP": (1e-6, 0.01), - "TPV": (1e-6, 1e-3), - "ZPX": (1e-5, 3e-5), - "TAN-PV": (1e-6, 2e-4), - "TAN-FLIP": (1e-6, 1e-6), - "REGION": (1e-6, 1e-6), - "TNX": (1e-6, 2e-5), - "ZTF": (0.1, 0.1), - } - - dir = os.path.join(os.path.dirname(__file__), "..", "..", "GalSim/tests/fits_files") - - if __name__ == "__main__": - test_tags = all_tags - else: - # For pytest runs, don't bother with the crazy ones. We really only care about - # testing that the code works correctly. That can be done with just a couple of these. - test_tags = ["TAN", "TPV", "SIP", "TNX"] - - nref = 300 # Use more than 256, since that's the block length in C++. - nfit = 50 - - rng = galsim.UniformDeviate(57721) - - for tag in test_tags: - file_name, ref_list = references[tag] - header = fits.getheader(os.path.join(dir, file_name)) - wcs = galsim.FitsWCS(header=header, suppress_warning=True) - if not wcs.isCelestial(): - continue - print(tag) - x = np.empty(nref, dtype=float) - y = np.empty(nref, dtype=float) - rng.generate(x) - rng.generate(y) - x *= header["NAXIS1"] # rough range of reference WCS image coords - y *= header["NAXIS2"] - ra, dec = wcs.xyToradec(x, y, units="rad") - print("pixel scale = ", wcs.pixelArea(galsim.PositionD(1, 1)) ** 0.5) - - # First check that we can get a fit without distorions, using appropriate - # Projections. Should get very high accuracy here. - if tag in ("STG", "ZEA", "ARC", "TAN"): - fittedWCS = galsim.FittedSIPWCS( - x[:nfit], - y[:nfit], - ra[:nfit], - dec[:nfit], - order=1, - wcs_type=tag, - center=wcs.center, - ) - ra_test, dec_test = fittedWCS.xyToradec(x, y, units="rad") - check_sphere(ra, dec, ra_test, dec_test, atol=1e-9) - x_test, y_test = fittedWCS.radecToxy(ra, dec, units="rad") - np.testing.assert_allclose( - np.hstack([x, y]), np.hstack([x_test, y_test]), rtol=0, atol=1e-9 - ) - header = {} - fittedWCS.writeToFitsHeader(header, galsim.BoundsI(0, 192, 0, 192)) - assert header["CTYPE1"] == "RA---" + tag - - # Now try adding in SIP distortions. We'll just use TAN projection from - # here forward. - fittedWCS = galsim.FittedSIPWCS(x[:nfit], y[:nfit], ra[:nfit], dec[:nfit]) - ra_test, dec_test = fittedWCS.xyToradec(x, y, units="rad") - - # import matplotlib.pyplot as plt - # fig, axes = plt.subplots(ncols=2, figsize=(10, 3)) - # dra = np.rad2deg(((ra-ra_test)+np.pi)%(2*np.pi)-np.pi)*3600 - # ddec = np.rad2deg(dec-dec_test)*3600 - # s0 = axes[0].scatter( - # np.rad2deg(ra), np.rad2deg(dec), - # c=dra, vmin=np.min(dra[:nfit]), vmax=np.max(dra[:nfit]) - # ) - # plt.colorbar(s0, ax=axes[0]) - # s1 = axes[1].scatter( - # np.rad2deg(ra), np.rad2deg(dec), - # c=ddec, vmin=np.min(ddec[:nfit]), vmax=np.max(ddec[:nfit]) - # ) - # plt.colorbar(s1, ax=axes[1]) - # axes[0].scatter( - # np.rad2deg(ra[:nfit]), np.rad2deg(dec[:nfit]), - # facecolors='none', edgecolors='k', s=100 - # ) - # axes[1].scatter( - # np.rad2deg(ra[:nfit]), np.rad2deg(dec[:nfit]), - # facecolors='none', edgecolors='k', s=100 - # ) - # plt.show() - - check_sphere(ra, dec, ra_test, dec_test, atol=tol[tag][0]) - - # Check reverse - x_test, y_test = fittedWCS.radecToxy(ra, dec, units="rad") - - # import matplotlib.pyplot as plt - # fig, axes = plt.subplots(ncols=2, figsize=(10, 3)) - # dx = x-x_test - # dy = y-y_test - # s0 = axes[0].scatter( - # x, y, c=dx, vmin=np.min(dx[:nfit]), vmax=np.max(dx[:nfit]) - # ) - # plt.colorbar(s0, ax=axes[0]) - # s1 = axes[1].scatter( - # x, y, c=dy, vmin=np.min(dy[:nfit]), vmax=np.max(dy[:nfit]) - # ) - # plt.colorbar(s1, ax=axes[1]) - # axes[0].scatter( - # x[:nfit], y[:nfit], facecolors='none', edgecolors='k', s=100 - # ) - # axes[1].scatter( - # x[:nfit], y[:nfit], facecolors='none', edgecolors='k', s=100 - # ) - # plt.show() - - np.testing.assert_allclose( - np.hstack([x, y]), np.hstack([x_test, y_test]), rtol=0, atol=tol[tag][1] - ) - - # Check x,y being higher dim than 1 - x_test_2d, y_test_2d = fittedWCS.radecToxy( - ra.reshape(30, 10), dec.reshape(30, 10), units="rad" - ) - assert x_test_2d.shape == (30, 10) - assert y_test_2d.shape == (30, 10) - np.testing.assert_array_equal(x_test_2d.ravel(), x_test) - np.testing.assert_array_equal(y_test_2d.ravel(), y_test) - - ra_test_3d, dec_test_3d = fittedWCS.xyToradec( - x.reshape(5, 5, 12), y.reshape(5, 5, 12), units="rad" - ) - assert ra_test_3d.shape == (5, 5, 12) - assert dec_test_3d.shape == (5, 5, 12) - np.testing.assert_array_equal(ra_test_3d.ravel(), ra_test) - np.testing.assert_array_equal(dec_test_3d.ravel(), dec_test) - - # Try again, but force a different center - # Increase tolerance since WCS no longer nicely centered on region of - # interest. - center = galsim.CelestialCoord( - galsim.Angle.from_hms(ref_list[0][0]), galsim.Angle.from_dms(ref_list[0][1]) - ) - fittedWCS = galsim.FittedSIPWCS( - x[:nfit], y[:nfit], ra[:nfit], dec[:nfit], order=3, center=center - ) - ra_test, dec_test = fittedWCS.xyToradec(x, y, units="rad") - check_sphere(ra, dec, ra_test, dec_test, atol=tol[tag][0]) - # Check reverse - x_test, y_test = fittedWCS.radecToxy(ra, dec, units="rad") - np.testing.assert_allclose( - np.hstack([x, y]), np.hstack([x_test, y_test]), rtol=0, atol=tol[tag][1] - ) - assert fittedWCS.center == center - - # Check illegal values - with np.testing.assert_raises(galsim.GalSimValueError): - galsim.FittedSIPWCS(x, y, ra, dec, order=0) - - # Check that we get a standard TAN WCS if order=1 - header = {} - wcs = galsim.FittedSIPWCS(x, y, ra, dec, order=1) - wcs.writeToFitsHeader(header, galsim.BoundsI(0, 192, 0, 192)) - assert "A_ORDER" not in header - assert header["CTYPE1"] == "RA---TAN" - # and a TAN-SIP WCS if order > 1 - header = {} - wcs = galsim.FittedSIPWCS(x, y, ra, dec, order=2) - wcs.writeToFitsHeader(header, galsim.BoundsI(0, 192, 0, 192)) - assert "A_ORDER" in header - assert header["CTYPE1"] == "RA---TAN-SIP" - - # Check that error is raised if not enough stars are supplied. - # 3 stars is enough for order=1 - wcs = galsim.FittedSIPWCS(x[:3], y[:3], ra[:3], dec[:3], order=1) - # but 2 is not. - with np.testing.assert_raises(galsim.GalSimError): - wcs = galsim.FittedSIPWCS(x[:2], y[:2], ra[:2], dec[:2], order=1) - # For order=3, there are 2*(3+4) ab coefficiens, 2 crpix, and 4 cd. - # 2 constraints per star means we need at least 10 stars - wcs = galsim.FittedSIPWCS(x[:10], y[:10], ra[:10], dec[:10], order=3) - with np.testing.assert_raises(galsim.GalSimError): - wcs = galsim.FittedSIPWCS(x[:9], y[:9], ra[:9], dec[:9], order=3) - - if __name__ != "__main__": - return - - # Finally, the ZPN fit isn't very good with a TAN projection. - # The native projection for ZPN is ARC (postel), and there are radial - # polynomial corrections up to 7th order in this particular WCS. Our 2D SIP - # corrections can't reproduce an odd-powered radial correction, (would need - # a sqrt to do so) so we never really get that close here. Best ARC-SIP I - # could find is order=4, which gives an error of ~0.5 degree over an image a - # few 10s of degrees across. Not great, but I think this is just an - # exceptional WCS, and not a problem with the code. - file_name, ref_list = references["ZPN"] - header = fits.getheader(os.path.join(dir, file_name)) - wcs = galsim.FitsWCS(header=header, suppress_warning=True) - x = np.empty(nref, dtype=float) - y = np.empty(nref, dtype=float) - rng.generate(x) - rng.generate(y) - x *= header["NAXIS1"] - y *= header["NAXIS2"] - ra, dec = wcs.xyToradec(x, y, units="rad") - # Use the same center of projection as the source wcs. - center = galsim.CelestialCoord( - ra=header["LONPOLE"] * galsim.degrees, dec=header["LATPOLE"] * galsim.degrees - ) - fittedWCS = galsim.FittedSIPWCS( - x[:nfit], - y[:nfit], - ra[:nfit], - dec[:nfit], - wcs_type="ARC", - order=4, - center=center, - ) - ra_test, dec_test = fittedWCS.xyToradec(x, y, units="rad") - check_sphere(ra, dec, ra_test, dec_test, atol=2400) - x_test, y_test = fittedWCS.radecToxy(ra, dec, units="rad") - np.testing.assert_allclose( - np.hstack([x, y]), np.hstack([x_test, y_test]), rtol=0, atol=10.0 - ) - # We can at least confirm we made an ARC-SIP WCS - header = {} - fittedWCS.writeToFitsHeader(header, galsim.BoundsI(0, 192, 0, 192)) - assert header["CTYPE1"] == "RA---ARC-SIP" - - -@timer -def test_scamp(): - """Test that we can read in a SCamp .head file correctly""" - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" - ) - file_name = "scamp.head" - - wcs = galsim.FitsWCS(file_name, dir=dir, text_file=True) - print("SCamp FitsWCS is really ", type(wcs)) - - # These are just random points that I checked on one machine with this file. - # For this test, we don't care much about an independent accuracy test, since that should - # be covered by the other tests. We are mostly testing that the above syntax works - # correctly, and that different machines (with different pyfits versions perhaps) end - # up reading in the same GSFitsWCS object. - ref_list = [ - ("01:04:44.197307", "-03:39:07.588000", 123, 567, 0), - ("01:04:36.022067", "-03:39:33.900586", 789, 432, 0), - ] - # This also checks that the dms parser works with : separators, which I'm not sure if - # I test anywhere else... - - do_ref(wcs, ref_list, "Scamp FitsWCS") - - -@timer -def test_compateq(): - """Test that WCS equality vs. compatibility work as physically expected.""" - # First check that compatible works properly for two WCS that are actually equal - assert galsim.wcs.compatible(galsim.PixelScale(0.23), galsim.PixelScale(0.23)) - # Now for a simple offset: check they are compatible but not equal - assert galsim.wcs.compatible( - galsim.PixelScale(0.23), galsim.OffsetWCS(0.23, galsim.PositionD(12, 34)) - ) - assert galsim.PixelScale(0.23) != galsim.OffsetWCS(0.23, galsim.PositionD(12, 34)) - # Further examples of compatible but != below. - assert galsim.wcs.compatible( - galsim.JacobianWCS(0.2, 0.01, -0.02, 0.23), - galsim.AffineTransform( - 0.2, 0.01, -0.02, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ), - ) - assert galsim.JacobianWCS(0.2, 0.01, -0.02, 0.23) != galsim.AffineTransform( - 0.2, 0.01, -0.02, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ) - assert galsim.wcs.compatible( - galsim.PixelScale(0.23), - galsim.AffineTransform( - 0.23, 0.0, 0.0, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ), - ) - assert galsim.PixelScale(0.23) != galsim.AffineTransform( - 0.23, 0.0, 0.0, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ) - - # Finally, some that are truly incompatible. - assert not galsim.wcs.compatible(galsim.PixelScale(0.23), galsim.PixelScale(0.27)) - assert not galsim.wcs.compatible( - galsim.PixelScale(0.23), galsim.JacobianWCS(0.23, 0.01, -0.02, 0.27) - ) - assert not galsim.wcs.compatible( - galsim.JacobianWCS(0.2, -0.01, 0.02, 0.23), - galsim.AffineTransform( - 0.2, 0.01, -0.02, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ), - ) - - # Non-uniform WCSs are considered compatible if their jacobians are everywhere the same. - # It (obviously) doesn't actually check this -- it relies on the functional part being - # the same, and maybe just resetting the origin(s). - uv1 = galsim.UVFunction( - "0.2*x + 0.01*x*y - 0.03*y**2", - "0.2*y - 0.01*x*y + 0.04*x**2", - origin=galsim.PositionD(12, 34), - world_origin=galsim.PositionD(45, 54), - ) - uv2 = galsim.UVFunction( - "0.2*x + 0.01*x*y - 0.03*y**2", - "0.2*y - 0.01*x*y + 0.04*x**2", - origin=galsim.PositionD(23, 56), - world_origin=galsim.PositionD(11, 22), - ) - uv3 = galsim.UVFunction( - "0.2*x - 0.01*x*y + 0.03*y**2", - "0.2*y + 0.01*x*y - 0.04*x**2", - origin=galsim.PositionD(23, 56), - world_origin=galsim.PositionD(11, 22), - ) - affine = galsim.AffineTransform( - 0.2, 0.01, -0.02, 0.23, galsim.PositionD(12, 34), galsim.PositionD(45, 54) - ) - assert galsim.wcs.compatible(uv1, uv2) - assert galsim.wcs.compatible(uv2, uv1) - assert not galsim.wcs.compatible(uv1, uv3) - assert not galsim.wcs.compatible(uv2, uv3) - assert not galsim.wcs.compatible(uv3, uv1) - assert not galsim.wcs.compatible(uv3, uv2) - assert not galsim.wcs.compatible(uv1, affine) - assert not galsim.wcs.compatible(uv2, affine) - assert not galsim.wcs.compatible(uv3, affine) - assert not galsim.wcs.compatible(affine, uv1) - assert not galsim.wcs.compatible(affine, uv2) - assert not galsim.wcs.compatible(affine, uv3) - - -@timer -def test_coadd(): - """ - This mostly serves as an example of how to treat the WCSs properly when using - galsim.InterpolatedImages to make a coadd. Not exactly what this class was designed - for, but since people have used it that way, it's useful to have a working example. - """ - # Make three "observations" of an object on images with different WCSs. - - # Three different local jacobaians. (Even different relative flips to make the differences - # more obvious than just the relative rotations and distortions.) - jac = [ - (0.26, 0.05, -0.08, 0.24), # Normal orientation - (0.25, -0.02, 0.01, -0.24), # Flipped on y axis (e2 -> -e2) - (0.03, 0.27, 0.29, 0.07), # Flipped on x=y axis (e1 -> -e1) - ] - - # Three different centroid positions - pos = [(123.23, 743.12), (772.11, 444.61), (921.37, 382.82)] - - # All the same sky position - sky_pos = galsim.CelestialCoord(5 * galsim.hours, -25 * galsim.degrees) - - # Calculate the appropriate bounds to use - N = 32 - bounds = [ - galsim.BoundsI( - int(p[0]) - N / 2 + 1, - int(p[0]) + N / 2, - int(p[1]) - N / 2 + 1, - int(p[1]) + N / 2, - ) - for p in pos - ] - - # Calculate the offset from the center - offset = [galsim.PositionD(*p) - b.true_center for (p, b) in zip(pos, bounds)] - - # Construct the WCSs - wcs = [ - galsim.TanWCS( - affine=galsim.AffineTransform(*j, origin=galsim.PositionD(*p)), - world_origin=sky_pos, - ) - for (j, p) in zip(jac, pos) - ] - - # All the same galaxy profile. (NB: I'm ignoring the PSF here.) - gal = galsim.Exponential(half_light_radius=1.3, flux=456).shear(g1=0.4, g2=0.3) - - # Draw the images - # NB: no_pixel here just so it's easier to check the shear values at the end without having - # to account for the dilution by the pixel convolution. - images = [ - gal.drawImage(image=galsim.Image(b, wcs=w), offset=o, method="no_pixel") - for (b, w, o) in zip(bounds, wcs, offset) - ] - - # Measured moments should have very different shears, and accurate centers - mom0 = images[0].FindAdaptiveMom() - print( - "im0: observed_shape = ", - mom0.observed_shape, - " center = ", - mom0.moments_centroid, - ) - assert mom0.observed_shape.e1 > 0 - assert mom0.observed_shape.e2 > 0 - np.testing.assert_almost_equal(mom0.moments_centroid.x, pos[0][0], decimal=1) - np.testing.assert_almost_equal(mom0.moments_centroid.y, pos[0][1], decimal=1) - - mom1 = images[1].FindAdaptiveMom() - print( - "im1: observed_shape = ", - mom1.observed_shape, - " center = ", - mom1.moments_centroid, - ) - assert mom1.observed_shape.e1 > 0 - assert mom1.observed_shape.e2 < 0 - np.testing.assert_almost_equal(mom1.moments_centroid.x, pos[1][0], decimal=1) - np.testing.assert_almost_equal(mom1.moments_centroid.y, pos[1][1], decimal=1) - - mom2 = images[2].FindAdaptiveMom() - print( - "im2: observed_shape = ", - mom2.observed_shape, - " center = ", - mom2.moments_centroid, - ) - assert mom2.observed_shape.e1 < 0 - assert mom2.observed_shape.e2 > 0 - np.testing.assert_almost_equal(mom2.moments_centroid.x, pos[2][0], decimal=1) - np.testing.assert_almost_equal(mom2.moments_centroid.y, pos[2][1], decimal=1) - - # Make an empty image for the coadd - coadd_image = galsim.Image(48, 48, scale=0.2) - - for p, im in zip(pos, images): - # Make sure we tell the profile where we think the center of the object is on the image. - offset = galsim.PositionD(*p) - im.true_center - interp = galsim.InterpolatedImage(im, offset=offset) - # Here the no_pixel is required. The InterpolatedImage already has pixels so we - # don't want to convovle by a pixel response again. - interp.drawImage(coadd_image, add_to_image=True, method="no_pixel") - - mom = coadd_image.FindAdaptiveMom() - print( - "coadd: observed_shape = ", - mom.observed_shape, - " center = ", - mom.moments_centroid, - ) - np.testing.assert_almost_equal(mom.observed_shape.g1, 0.4, decimal=2) - np.testing.assert_almost_equal(mom.observed_shape.g2, 0.3, decimal=2) - np.testing.assert_almost_equal(mom.moments_centroid.x, 24.5, decimal=2) - np.testing.assert_almost_equal(mom.moments_centroid.y, 24.5, decimal=2) - - -@timer -def test_lowercase(): - # The WCS parsing should be insensitive to the case of the header key values. - # Matt Becker ran into a problem when his wcs dict had lowercase keys. - wcs_dict = { - "simple": True, - "bitpix": -32, - "naxis": 2, - "naxis1": 10000, - "naxis2": 10000, - "extend": True, - "gs_xmin": 1, - "gs_ymin": 1, - "gs_wcs": "GSFitsWCS", - "ctype1": "RA---TAN", - "ctype2": "DEC--TAN", - "crpix1": 5000.5, - "crpix2": 5000.5, - "cd1_1": -7.305555555556e-05, - "cd1_2": 0.0, - "cd2_1": 0.0, - "cd2_2": 7.305555555556e-05, - "cunit1": "deg ", - "cunit2": "deg ", - "crval1": 86.176841, - "crval2": -22.827778, - } - wcs = galsim.FitsWCS(header=wcs_dict) - print("wcs = ", wcs) - assert isinstance(wcs, galsim.GSFitsWCS) - print(wcs.local(galsim.PositionD(0, 0))) - np.testing.assert_allclose( - wcs.local(galsim.PositionD(0, 0)).getMatrix().ravel(), - [0.26298, 0.00071, -0.00072, 0.26298], - atol=1.0e-4, - ) - - -@timer -def test_int_args(): - """Test that integer arguments for various things work correctly.""" - # Some of these used to trigger - # TypeError: Cannot cast ufunc subtract output from dtype('float64') to dtype('int64') - # with casting rule 'same_kind' - # This started with numpy v1.10. - - test_tags = all_tags - - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" - ) - - for tag in test_tags: - file_name, ref_list = references[tag] - wcs = galsim.FitsWCS(file_name, dir=dir, suppress_warning=True) - - posi = galsim.PositionI(5, 6) - posd = galsim.PositionD(5, 6) - - local_wcs1 = wcs.local(posd) - local_wcs2 = wcs.local(posi) - assert local_wcs1 == local_wcs2 - - wposi = wcs.toWorld(posi) - posi_roundtrip = wcs.toImage(wposi) - print("posi_roundtrip = ", posi_roundtrip) - assert np.isclose(posi_roundtrip.x, posi.x) - assert np.isclose(posi_roundtrip.y, posi.y) - - # Also check a DES file. - # Along the way, check issue #1024 where Erin noticed that reading the WCS from the - # header of a compressed file was spending lots of time decompressing the data, which - # is unnecessary. - dir = os.path.join( - os.path.dirname(__file__), "..", "..", "GalSim", "tests", "des_data" - ) - file_name = "DECam_00158414_01.fits.fz" - with Profile(): - t0 = time.time() - wcs = galsim.FitsWCS(file_name, dir=dir) - t1 = time.time() - # Before fixing #1024, this took about 0.5 sec. - # Now it usually takes about 0.04 sec. Testing at 0.25 seems like a reasonable midpoint. - print("Time = ", t1 - t0) - if __name__ == "__main__": - # Don't include this in regular unit tests, since it's not really something we need - # to guarantee. This timing estimate is appropriate for my laptop, but maybe not - # all systems. It also fails for pypy on GHA for some reason. - assert t1 - t0 < 0.25 - - posi = galsim.PositionI(5, 6) - posd = galsim.PositionD(5, 6) - - local_wcs1 = wcs.local(posd) - local_wcs2 = wcs.local(posi) - assert local_wcs1 == local_wcs2 - - wposi = wcs.toWorld(posi) - posi_roundtrip = wcs.toImage(wposi) - print("posi_roundtrip = ", posi_roundtrip) - assert np.isclose(posi_roundtrip.x, posi.x) - assert np.isclose(posi_roundtrip.y, posi.y) - - -@timer -def test_razero(): - """Test the makeSkyImage function near ra=0.""" - # This test reproduces the problem Chris Walter found when using the LSST WCS backend - # (with imsim) near ra=0. The WCS radec function would return number slightly less than - # 360 and then cross over to just over 0, rather than vary smoothly across the image. - - # Note: GSFitsWCS was never a problem. It always returns ra values near the value of the - # center of the image. So if the center is just below 360, it will return values > 360 - # when crossing the ra=0 line. Or if the center is just above 0, then it will return - # negative values when crossing the other way. - - # However, astropy has this "feature" of always wrapping the angles to 0..2pi, so we can - # use that to test that our makeSkyImage function works properly for wcs functions that - # do this. - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs # noqa: F401 - import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. - - dir = "fits_files" - # This file is based in sipsample.fits, but with the CRVAL1 changed to 0.002322805429 - file_name = "razero.fits" - wcs = galsim.AstropyWCS(file_name, dir=dir) - - print("0,0 -> ", wcs.toWorld(galsim.PositionD(0, 0))) - # The makeSkyImage test uses an image with bounds (-135,9,127,288) - # So these tests just make sure that we'll see the jump in ra values when we test this. - print("-135,127 -> ", wcs.toWorld(galsim.PositionD(-135, 127))) - print("9,127 -> ", wcs.toWorld(galsim.PositionD(9, 127))) - print("-135,288 -> ", wcs.toWorld(galsim.PositionD(-135, 288))) - print("9,288 -> ", wcs.toWorld(galsim.PositionD(9, 288))) - assert 359 < wcs.toWorld(galsim.PositionD(-135, 127)).ra / galsim.degrees < 360 - assert 0 < wcs.toWorld(galsim.PositionD(9, 288)).ra / galsim.degrees < 1 - - do_celestial_wcs(wcs, "Astropy file " + file_name) - do_wcs_image(wcs, "Astropy near ra=0") - - # This file is similar, but adjusted to be located near the south pole. - file_name = "pole.fits" - wcs = galsim.AstropyWCS(file_name, dir=dir) - - print("0,0 -> ", wcs.toWorld(galsim.PositionD(0, 0))) - # This time we get a whole range of ra values going around the perimeter, so a simple - # wrapping wouldn't work. - print("-135,127 -> ", wcs.toWorld(galsim.PositionD(-135, 127))) - print("-63,127 -> ", wcs.toWorld(galsim.PositionD(-63, 127))) - print("9,127 -> ", wcs.toWorld(galsim.PositionD(9, 127))) - print("9,208 -> ", wcs.toWorld(galsim.PositionD(9, 208))) - print("9,288 -> ", wcs.toWorld(galsim.PositionD(9, 288))) - print("-63, 288 -> ", wcs.toWorld(galsim.PositionD(-63, 288))) - print("-135,288 -> ", wcs.toWorld(galsim.PositionD(-135, 288))) - print("-135, 208 -> ", wcs.toWorld(galsim.PositionD(-135, 208))) - # The center is pretty close to the south pole. - # If it gets any closer, then the precise test at the center of the image fails at 3 d.p. - # I think we just need to accept that at the pole itself, our finite difference calculation - # won't be super accurate. With this, the pole is only a few pixels from the center of the - # image, and the center pixel passes our test. - print("-63, 208 -> ", wcs.toWorld(galsim.PositionD(-63, 208))) - - do_celestial_wcs(wcs, "Astropy file " + file_name) - do_wcs_image(wcs, "Astropy near pole") - - -if __name__ == "__main__": - testfns = [v for k, v in vars().items() if k[:5] == "test_" and callable(v)] - for testfn in testfns: - testfn() diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index ee525920..77c7df07 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -913,3 +913,115 @@ def test_api_noise(): "VariableGaussianNoise", "CCDNoise", } <= tested + + +@pytest.mark.parametrize( + "obj1", + [ + jax_galsim.Gaussian(fwhm=1.0), + jax_galsim.Pixel(scale=1.0), + jax_galsim.Exponential(scale_radius=1.0), + jax_galsim.Exponential(half_light_radius=1.0), + jax_galsim.Moffat(fwhm=1.0, beta=3), + jax_galsim.Moffat(scale_radius=1.0, beta=3), + jax_galsim.Shear(g1=0.1, g2=0.2), + jax_galsim.PositionD(x=0.1, y=0.2), + jax_galsim.BoundsI(xmin=0, xmax=1, ymin=0, ymax=1), + jax_galsim.BoundsD(xmin=0, xmax=1, ymin=0, ymax=1), + jax_galsim.ShearWCS(0.2, jax_galsim.Shear(g1=0.1, g2=0.2)), + jax_galsim.Delta(), + jax_galsim.Nearest(), + jax_galsim.Lanczos(3), + jax_galsim.Lanczos(3, conserve_dc=False), + jax_galsim.Quintic(), + jax_galsim.Linear(), + jax_galsim.Cubic(), + jax_galsim.SincInterpolant(), + ], +) +def test_api_pickling_eval_repr_basic(obj1): + """This test is here until we run all of the galsim tests which cover this one.""" + # test copied from galsim + import copy + import pickle + from collections.abc import Hashable + from numbers import Complex, Integral, Real # noqa: F401 + + # In case the repr uses these: + from numpy import ( # noqa: F401 + array, + complex64, + complex128, + float32, + float64, + int16, + int32, + ndarray, + uint16, + uint32, + ) + + def func(x): + return x + + print("Try pickling ", str(obj1)) + + # print('pickled obj1 = ',pickle.dumps(obj1)) + obj2 = pickle.loads(pickle.dumps(obj1)) + assert obj2 is not obj1 + # print('obj1 = ',repr(obj1)) + # print('obj2 = ',repr(obj2)) + f1 = func(obj1) + f2 = func(obj2) + # print('func(obj1) = ',repr(f1)) + # print('func(obj2) = ',repr(f2)) + assert f1 == f2 + + # Check that == works properly if the other thing isn't the same type. + assert f1 != object() + assert object() != f1 + + # Test the hash values are equal for two equivalent objects. + if isinstance(obj1, Hashable): + # print('hash = ',hash(obj1),hash(obj2)) + assert hash(obj1) == hash(obj2) + + obj3 = copy.copy(obj1) + assert obj3 is not obj1 + random = hasattr(obj1, "rng") or "rng" in repr(obj1) + if not random: # Things with an rng attribute won't be identical on copy. + f3 = func(obj3) + assert f3 == f1 + + obj4 = copy.deepcopy(obj1) + assert obj4 is not obj1 + f4 = func(obj4) + if random: + f1 = func(obj1) + # print('func(obj1) = ',repr(f1)) + # print('func(obj4) = ',repr(f4)) + assert f4 == f1 # But everything should be identical with deepcopy. + + # Also test that the repr is an accurate representation of the object. + # The gold standard is that eval(repr(obj)) == obj. So check that here as well. + # A few objects we don't expect to work this way in GalSim; when testing these, we set the + # `irreprable` kwarg to true. Also, we skip anything with random deviates since these don't + # respect the eval/repr roundtrip. + + if not random: + # A further complication is that the default numpy print options do not lead to sufficient + # precision for the eval string to exactly reproduce the original object, and start + # truncating the output for relatively small size arrays. So we temporarily bump up the + # precision and truncation threshold for testing. + # print(repr(obj1)) + with _galsim.utilities.printoptions(precision=20, threshold=np.inf): + obj5 = eval(repr(obj1)) + # print('obj1 = ',repr(obj1)) + # print('obj5 = ',repr(obj5)) + f5 = func(obj5) + # print('f1 = ',f1) + # print('f5 = ',f5) + assert f5 == f1, "func(obj1) = %r\nfunc(obj5) = %r" % (f1, f5) + else: + # Even if we're not actually doing the test, still make the repr to check for syntax errors. + repr(obj1) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 14bfe100..10fc4067 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -1,6 +1,6 @@ import jax import numpy as np -from galsim_test_helpers import timer +from galsim_test_helpers import timer, assert_raises import jax_galsim as galsim from jax_galsim.core.wrap_image import ( @@ -142,3 +142,192 @@ def _wrapit(im): p, grad = jax.vjp(_wrapit, im3) grad(p) jax.jvp(_wrapit, (im3,), (im3 * 2,)) + + +@timer +def test_wrap_jax_simple_real(): + """Test the image.wrap() function.""" + # Start with a fairly simple test where the image is 4 copies of the same data: + im_orig = galsim.Image( + [ + [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], + [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], + [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], + [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], + [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], + [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], + [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], + [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], + ] + ) + im = im_orig.copy() + b = galsim.BoundsI(1, 4, 1, 4) + im_quad = im_orig[b] + im_wrap = im.wrap(b) + np.testing.assert_allclose(im_wrap.array, 4.0 * im_quad.array) + + # The same thing should work no matter where the lower left corner is: + for xmin, ymin in ((1, 5), (5, 1), (5, 5), (2, 3), (4, 1)): + b = galsim.BoundsI(xmin, xmin + 3, ymin, ymin + 3) + im_quad = im_orig[b] + im = im_orig.copy() + im_wrap = im.wrap(b) + np.testing.assert_allclose( + im_wrap.array, + 4.0 * im_quad.array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_allclose( + im_wrap.array, + im[b].array, + err_msg="image.wrap(%s) did not return the right subimage" % b, + ) + # this test passes even though we do not get a view + im[b].fill(0) + np.testing.assert_allclose( + im_wrap.array, + im[b].array, + err_msg="image.wrap(%s) did not return a view of the original" % b, + ) + + +@timer +def test_wrap_jax_weird_real(): + # Now test where the subimage is not a simple fraction of the original, and all the + # sizes are different. + im = galsim.ImageD(17, 23, xmin=0, ymin=0) + b = galsim.BoundsI(7, 9, 11, 18) + im_test = galsim.ImageD(b, init_value=0) + for i in range(17): + for j in range(23): + val = np.exp(i / 7.3) + (j / 12.9) ** 3 # Something randomly complicated... + im[i, j] = val + # Find the location in the sub-image for this point. + ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin + jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin + im_test.addValue(ii, jj, val) + im_wrap = im.wrap(b) + np.testing.assert_allclose( + im_wrap.array, + im_test.array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" + ) + np.testing.assert_equal( + im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" + ) + + +@timer +def test_wrap_jax_complex(): + # For complex images (in particular k-space images), we often want the image to be implicitly + # Hermitian, so we only need to keep around half of it. + M = 38 + N = 25 + K = 8 + L = 5 + im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian + im2 = galsim.ImageCD( + 2 * M + 1, N + 1, xmin=-M, ymin=0 + ) # Implicitly Hermitian across y axis + im3 = galsim.ImageCD( + M + 1, 2 * N + 1, xmin=0, ymin=-N + ) # Implicitly Hermitian across x axis + # print('im = ',im) + # print('im2 = ',im2) + # print('im3 = ',im3) + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + b2 = galsim.BoundsI(-K + 1, K, 0, L) + b3 = galsim.BoundsI(0, K, -L + 1, L) + im_test = galsim.ImageCD(b, init_value=0) + for i in range(-M, M + 1): + for j in range(-N, N + 1): + # An arbitrary, complicated Hermitian function. + val = ( + np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) + + ((2 + 3j * j) / (1.9 * N)) ** 3 + ) + # val = 2*(i-j)**2 + 3j*(i+j) + + im[i, j] = val + if j >= 0: + im2[i, j] = val + if i >= 0: + im3[i, j] = val + + ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin + jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin + im_test.addValue(ii, jj, val) + # print("im = ",im.array) + + # Confirm that the image is Hermitian. + for i in range(-M, M + 1): + for j in range(-N, N + 1): + assert im(i, j) == im(-i, -j).conjugate() + + im_wrap = im.wrap(b) + # print("im_wrap = ",im_wrap.array) + np.testing.assert_allclose( + im_wrap.array, + im_test.array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" + ) + np.testing.assert_equal( + im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" + ) + + im2_wrap = im2.wrap(b2, hermitian="y") + # print('im_test = ',im_test[b2].array) + # print('im2_wrap = ',im2_wrap.array) + # print('diff = ',im2_wrap.array-im_test[b2].array) + np.testing.assert_allclose( + im2_wrap.array, + im_test[b2].array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im2_wrap.array, + im2[b2].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" + ) + + im3_wrap = im3.wrap(b3, hermitian="x") + # print('im_test = ',im_test[b3].array) + # print('im3_wrap = ',im3_wrap.array) + # print('diff = ',im3_wrap.array-im_test[b3].array) + np.testing.assert_allclose( + im3_wrap.array, + im_test[b3].array, + err_msg="image.wrap(%s) did not match expectation" % b, + ) + np.testing.assert_array_equal( + im3_wrap.array, + im3[b3].array, + "image.wrap(%s) did not return the right subimage", + ) + np.testing.assert_equal( + im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" + ) + + b = galsim.BoundsI(-K + 1, K, -L + 1, L) + b2 = galsim.BoundsI(-K + 1, K, 0, L) + b3 = galsim.BoundsI(0, K, -L + 1, L) + assert_raises(TypeError, im.wrap, bounds=None) + assert_raises(ValueError, im.wrap, b2, hermitian="y") + assert_raises(ValueError, im.wrap, b, hermitian="invalid") + assert_raises(ValueError, im.wrap, b3, hermitian="x") + + assert_raises(ValueError, im3.wrap, b, hermitian="x") + assert_raises(ValueError, im3.wrap, b2, hermitian="x") + assert_raises(ValueError, im2.wrap, b, hermitian="y") + assert_raises(ValueError, im2.wrap, b3, hermitian="y") + assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") + assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") diff --git a/tests/jax/galsim/test_interpolant_jax.py b/tests/jax/test_interpolant_jax.py similarity index 100% rename from tests/jax/galsim/test_interpolant_jax.py rename to tests/jax/test_interpolant_jax.py diff --git a/tests/jax/test_metacal.py b/tests/jax/test_metacal_jax.py similarity index 100% rename from tests/jax/test_metacal.py rename to tests/jax/test_metacal_jax.py diff --git a/tests/jax/test_ref_impl.py b/tests/jax/test_ref_impl.py new file mode 100644 index 00000000..7af83b8a --- /dev/null +++ b/tests/jax/test_ref_impl.py @@ -0,0 +1,72 @@ +import galsim as ref_galsim +import numpy as np + +import jax_galsim + + +def test_ref_impl_convolve(): + """Validates convolutions against reference GalSim""" + dx = 0.2 + fwhm_backwards_compatible = 1.0927449310213702 + + # First way to do a convolution + def conv1(galsim): + psf = galsim.Gaussian(fwhm=fwhm_backwards_compatible, flux=1) + pixel = galsim.Pixel(scale=dx, flux=1.0) + conv = galsim.Convolve([psf, pixel], real_space=False) + return conv.drawImage( + nx=32, ny=32, scale=dx, method="sb", use_true_center=False + ).array + + np.testing.assert_array_almost_equal( + conv1(ref_galsim), + conv1(jax_galsim), + 5, + err_msg="Gaussian convolved with Pixel disagrees with expected result", + ) + + # Second way of doing a convolution + def conv2(galsim): + psf = galsim.Gaussian(fwhm=fwhm_backwards_compatible, flux=1) + pixel = galsim.Pixel(scale=dx, flux=1.0) + conv = galsim.Convolve(psf, pixel, real_space=False) + return conv.drawImage( + nx=32, ny=32, scale=dx, method="sb", use_true_center=False + ).array + + np.testing.assert_array_almost_equal( + conv2(ref_galsim), + conv2(jax_galsim), + 5, + err_msg=" GSObject Convolve(psf,pixel) disagrees with expected result", + ) + + +def test_ref_impl_shearconvolve(): + """Verifies that a convolution of a Sheared Gaussian and a Box Profile + return the expected results. + """ + e1 = 0.05 + e2 = 0.0 + dx = 0.2 + + def func(galsim): + psf = ( + galsim.Gaussian(flux=1, sigma=1) + .shear(e1=e1, e2=e2) + .rotate(33 * galsim.degrees) + .shift(0.1, 0.4) + * 1.1 + ) + pixel = galsim.Pixel(scale=dx, flux=1.0) + conv = galsim.Convolve([psf, pixel]) + return conv.drawImage( + nx=32, ny=32, scale=dx, method="sb", use_true_center=False + ).array + + np.testing.assert_array_almost_equal( + func(ref_galsim), + func(jax_galsim), + 5, + err_msg="Using GSObject Convolve([psf,pixel]) disagrees with expected result", + ) diff --git a/tests/jax/test_temporary_tests.py b/tests/jax/test_temporary_tests.py deleted file mode 100644 index 252b4803..00000000 --- a/tests/jax/test_temporary_tests.py +++ /dev/null @@ -1,188 +0,0 @@ -import galsim as ref_galsim -import numpy as np -import pytest - -import jax_galsim - - -def test_convolve_temp(): - """Validates convolutions against reference GalSim - This test is to be removed once we can execute test_convolve.py in - its entirety. - """ - dx = 0.2 - fwhm_backwards_compatible = 1.0927449310213702 - - # First way to do a convolution - def conv1(galsim): - psf = galsim.Gaussian(fwhm=fwhm_backwards_compatible, flux=1) - pixel = galsim.Pixel(scale=dx, flux=1.0) - conv = galsim.Convolve([psf, pixel], real_space=False) - return conv.drawImage( - nx=32, ny=32, scale=dx, method="sb", use_true_center=False - ).array - - np.testing.assert_array_almost_equal( - conv1(ref_galsim), - conv1(jax_galsim), - 5, - err_msg="Gaussian convolved with Pixel disagrees with expected result", - ) - - # Second way of doing a convolution - def conv2(galsim): - psf = galsim.Gaussian(fwhm=fwhm_backwards_compatible, flux=1) - pixel = galsim.Pixel(scale=dx, flux=1.0) - conv = galsim.Convolve(psf, pixel, real_space=False) - return conv.drawImage( - nx=32, ny=32, scale=dx, method="sb", use_true_center=False - ).array - - np.testing.assert_array_almost_equal( - conv2(ref_galsim), - conv2(jax_galsim), - 5, - err_msg=" GSObject Convolve(psf,pixel) disagrees with expected result", - ) - - -def test_shearconvolve_temp(): - """Verifies that a convolution of a Sheared Gaussian and a Box Profile - return the expected results. - """ - e1 = 0.05 - e2 = 0.0 - dx = 0.2 - - def func(galsim): - psf = ( - galsim.Gaussian(flux=1, sigma=1) - .shear(e1=e1, e2=e2) - .rotate(33 * galsim.degrees) - .shift(0.1, 0.4) - * 1.1 - ) - pixel = galsim.Pixel(scale=dx, flux=1.0) - conv = galsim.Convolve([psf, pixel]) - return conv.drawImage( - nx=32, ny=32, scale=dx, method="sb", use_true_center=False - ).array - - np.testing.assert_array_almost_equal( - func(ref_galsim), - func(jax_galsim), - 5, - err_msg="Using GSObject Convolve([psf,pixel]) disagrees with expected result", - ) - - -@pytest.mark.parametrize( - "obj1", - [ - jax_galsim.Gaussian(fwhm=1.0), - jax_galsim.Pixel(scale=1.0), - jax_galsim.Exponential(scale_radius=1.0), - jax_galsim.Exponential(half_light_radius=1.0), - jax_galsim.Moffat(fwhm=1.0, beta=3), - jax_galsim.Moffat(scale_radius=1.0, beta=3), - jax_galsim.Shear(g1=0.1, g2=0.2), - jax_galsim.PositionD(x=0.1, y=0.2), - jax_galsim.BoundsI(xmin=0, xmax=1, ymin=0, ymax=1), - jax_galsim.BoundsD(xmin=0, xmax=1, ymin=0, ymax=1), - jax_galsim.ShearWCS(0.2, jax_galsim.Shear(g1=0.1, g2=0.2)), - jax_galsim.Delta(), - jax_galsim.Nearest(), - jax_galsim.Lanczos(3), - jax_galsim.Lanczos(3, conserve_dc=False), - jax_galsim.Quintic(), - jax_galsim.Linear(), - jax_galsim.Cubic(), - jax_galsim.SincInterpolant(), - ], -) -def test_pickling_eval_repr(obj1): - """This test is here until we run all of the galsim tests which cover this one.""" - # test copied from galsim - import copy - import pickle - from collections.abc import Hashable - from numbers import Complex, Integral, Real # noqa: F401 - - # In case the repr uses these: - from numpy import ( # noqa: F401 - array, - complex64, - complex128, - float32, - float64, - int16, - int32, - ndarray, - uint16, - uint32, - ) - - def func(x): - return x - - print("Try pickling ", str(obj1)) - - # print('pickled obj1 = ',pickle.dumps(obj1)) - obj2 = pickle.loads(pickle.dumps(obj1)) - assert obj2 is not obj1 - # print('obj1 = ',repr(obj1)) - # print('obj2 = ',repr(obj2)) - f1 = func(obj1) - f2 = func(obj2) - # print('func(obj1) = ',repr(f1)) - # print('func(obj2) = ',repr(f2)) - assert f1 == f2 - - # Check that == works properly if the other thing isn't the same type. - assert f1 != object() - assert object() != f1 - - # Test the hash values are equal for two equivalent objects. - if isinstance(obj1, Hashable): - # print('hash = ',hash(obj1),hash(obj2)) - assert hash(obj1) == hash(obj2) - - obj3 = copy.copy(obj1) - assert obj3 is not obj1 - random = hasattr(obj1, "rng") or "rng" in repr(obj1) - if not random: # Things with an rng attribute won't be identical on copy. - f3 = func(obj3) - assert f3 == f1 - - obj4 = copy.deepcopy(obj1) - assert obj4 is not obj1 - f4 = func(obj4) - if random: - f1 = func(obj1) - # print('func(obj1) = ',repr(f1)) - # print('func(obj4) = ',repr(f4)) - assert f4 == f1 # But everything should be identical with deepcopy. - - # Also test that the repr is an accurate representation of the object. - # The gold standard is that eval(repr(obj)) == obj. So check that here as well. - # A few objects we don't expect to work this way in GalSim; when testing these, we set the - # `irreprable` kwarg to true. Also, we skip anything with random deviates since these don't - # respect the eval/repr roundtrip. - - if not random: - # A further complication is that the default numpy print options do not lead to sufficient - # precision for the eval string to exactly reproduce the original object, and start - # truncating the output for relatively small size arrays. So we temporarily bump up the - # precision and truncation threshold for testing. - # print(repr(obj1)) - with ref_galsim.utilities.printoptions(precision=20, threshold=np.inf): - obj5 = eval(repr(obj1)) - # print('obj1 = ',repr(obj1)) - # print('obj5 = ',repr(obj5)) - f5 = func(obj5) - # print('f1 = ',f1) - # print('f5 = ',f5) - assert f5 == f1, "func(obj1) = %r\nfunc(obj5) = %r" % (f1, f5) - else: - # Even if we're not actually doing the test, still make the repr to check for syntax errors. - repr(obj1) From 2ed69afb4fceb1968271dedffe46eb46dcc80cff Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 26 Nov 2023 07:34:15 -0500 Subject: [PATCH 36/85] update to latest test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 8a3440d7..7d6923e2 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 8a3440d72d739763514a2620e61c6e50668648b9 +Subproject commit 7d6923e2ade53506ce972dd336e25f777ca4ba3c From a8e669ad4e89c126d7b855d039f0f4db5ee08880 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 26 Nov 2023 07:45:28 -0500 Subject: [PATCH 37/85] STY blacken --- jax_galsim/angle.py | 4 +++- jax_galsim/bounds.py | 18 +++++++++++++----- jax_galsim/image.py | 2 +- jax_galsim/sensor.py | 2 +- jax_galsim/utilities.py | 2 +- jax_galsim/wcs.py | 2 +- tests/conftest.py | 17 ++++++----------- tests/jax/test_image_wrapping.py | 2 +- 8 files changed, 27 insertions(+), 22 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 3126a6da..8bfdc8af 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -36,7 +36,9 @@ def __init__(self, value): """ if isinstance(value, AngleUnit): raise TypeError("Cannot construct AngleUnit from another AngleUnit") - self._value = 1.0 * cast_to_float(value) # this will cause an exception if things are not numeric + self._value = 1.0 * cast_to_float( + value + ) # this will cause an exception if things are not numeric @property def value(self): diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 2c0658b4..f89a495f 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -11,8 +11,8 @@ ensure_hashable, has_tracers, ) -from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.errors import GalSimUndefinedBoundsError +from jax_galsim.position import Position, PositionD, PositionI # The reason for avoid these tests is that they are not easy to do for jitted code. @@ -97,8 +97,10 @@ def _parse_args(self, *args, **kwargs): @_wraps(_galsim.Bounds.isDefined) def isDefined(self, _static=False): if _static: - return self._isdefined and np.all(self.xmin <= self.xmax) and np.all( - self.ymin <= self.ymax + return ( + self._isdefined + and np.all(self.xmin <= self.xmax) + and np.all(self.ymin <= self.ymax) ) else: return ( @@ -371,7 +373,10 @@ def __eq__(self, other): and np.array_equal(self.ymin, other.ymin, equal_nan=True) and np.array_equal(self.ymax, other.ymax, equal_nan=True) ) - or ((not self.isDefined(_static=True)) and (not other.isDefined(_static=True))) + or ( + (not self.isDefined(_static=True)) + and (not other.isDefined(_static=True)) + ) ) ) @@ -518,7 +523,10 @@ def numpyShape(self, _static=False): return jax.lax.cond( jnp.any(self.isDefined()), lambda xmin, xmax, ymin, ymax: (ymax - ymin + 1, xmax - xmin + 1), - lambda xmin, xmax, ymin, ymax: (jnp.zeros_like(xmin), jnp.zeros_like(xmin)), + lambda xmin, xmax, ymin, ymax: ( + jnp.zeros_like(xmin), + jnp.zeros_like(xmin), + ), self.xmin, self.xmax, self.ymin, diff --git a/jax_galsim/image.py b/jax_galsim/image.py index a7a265d2..06a32a18 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -6,10 +6,10 @@ from jax_galsim.bounds import Bounds, BoundsD, BoundsI from jax_galsim.core.utils import ensure_hashable +from jax_galsim.errors import GalSimImmutableError from jax_galsim.position import PositionI from jax_galsim.utilities import parse_pos_args from jax_galsim.wcs import BaseWCS, PixelScale -from jax_galsim.errors import GalSimImmutableError @_wraps( diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 47d5a940..27f3552e 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -2,8 +2,8 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.position import PositionI from jax_galsim.errors import GalSimUndefinedBoundsError +from jax_galsim.position import PositionI @_wraps(_galsim.Sensor) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 2a187366..1c9fedb6 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -5,9 +5,9 @@ import jax.numpy as jnp from jax._src.numpy.util import _wraps +from jax_galsim.core.utils import has_tracers from jax_galsim.errors import GalSimIncompatibleValuesError, GalSimValueError from jax_galsim.position import PositionD, PositionI -from jax_galsim.core.utils import has_tracers printoptions = _galsim.utilities.printoptions diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index fc01bfea..b95d0ad9 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1631,8 +1631,8 @@ def readFromFitsHeader(header, suppress_warning=True): a tuple (wcs, origin) of the wcs from the header and the image origin. """ from . import fits - from .fitswcs import FitsWCS + if not isinstance(header, fits.FitsHeader): header = fits.FitsHeader(header) xmin = header.get("GS_XMIN", 1) diff --git a/tests/conftest.py b/tests/conftest.py index 50a3f527..a4edcff9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,19 +82,14 @@ def pytest_collection_modifyitems(config, items): # if this is a galsim test we check if it is requested or not if ( ( - ( - not any( - [t in item.nodeid for t in test_config["enabled_tests"]["galsim"]] - ) + not any( + [t in item.nodeid for t in test_config["enabled_tests"]["galsim"]] ) - and "*" not in test_config["enabled_tests"]["galsim"] - ) and ( - ( - not any( - [t in item.nodeid for t in test_config["enabled_tests"]["coord"]] - ) - ) and "*" not in test_config["enabled_tests"]["coord"] ) + and "*" not in test_config["enabled_tests"]["galsim"] + ) and ( + (not any([t in item.nodeid for t in test_config["enabled_tests"]["coord"]])) + and "*" not in test_config["enabled_tests"]["coord"] ): item.add_marker(skip) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 10fc4067..f1f7ada4 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -1,6 +1,6 @@ import jax import numpy as np -from galsim_test_helpers import timer, assert_raises +from galsim_test_helpers import assert_raises, timer import jax_galsim as galsim from jax_galsim.core.wrap_image import ( From 19357b3eefe193b6d1207194ecb8ac045032b816 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 12:28:01 -0600 Subject: [PATCH 38/85] TST next round of test fixes --- jax_galsim/image.py | 25 +++++++-- jax_galsim/photon_array.py | 104 +++++++++++++++++++++++++++++++++++-- tests/GalSim | 2 +- tests/conftest.py | 2 + 4 files changed, 125 insertions(+), 8 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 06a32a18..d249fbf6 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -345,6 +345,16 @@ def array(self): """The underlying numpy array.""" return self._array + @property + def nrow(self): + """The number of rows in the image""" + return self._array.shape[0] + + @property + def ncol(self): + """The number of columns in the image""" + return self._array.shape[1] + @property def isconst(self): """Whether the `Image` is constant. I.e. modifying its values is an error.""" @@ -438,7 +448,9 @@ def real(self): This works for real or complex. For real images, it acts the same as `view`. """ - return self.__class__(self.array.real, bounds=self.bounds, wcs=self.wcs) + return self.__class__( + self.array.real, bounds=self.bounds, wcs=self.wcs, make_const=self._is_const + ) @property def imag(self): @@ -449,7 +461,12 @@ def imag(self): This works for real or complex. For real images, the returned array is read-only and all elements are 0. """ - return self.__class__(self.array.imag, bounds=self.bounds, wcs=self.wcs) + return self.__class__( + self.array.imag, + bounds=self.bounds, + wcs=self.wcs, + make_const=self._is_const or (not self.iscomplex), + ) @property def conjugate(self): @@ -895,7 +912,7 @@ def view( # we use the static bounds check set at construction # since the dynamic one in JAX would change array shape if not self.bounds.isDefined(_static=True): - return Image(wcs=wcs, dtype=dtype) + return Image(wcs=wcs, dtype=dtype, make_const=make_const) # Recast the array type if necessary if dtype != self.array.dtype: @@ -906,7 +923,7 @@ def view( array = self.array # Make the return Image - ret = self.__class__(array, bounds=self.bounds, wcs=wcs) + ret = self.__class__(array, bounds=self.bounds, wcs=wcs, make_const=make_const) # Update the origin if requested if origin is not None: diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 4024f337..041c30b7 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -59,7 +59,6 @@ def __init__( time=None, _nokeep=None, ): - # self._N = N self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N # if ( # _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None @@ -155,7 +154,6 @@ def _fromArrays( ) ret = cls.__new__(cls) - # ret._N = x.shape[0] ret._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or x.shape[0] ret._x = x.copy() ret._y = y.copy() @@ -191,7 +189,7 @@ def _fromArrays( if time is not None else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) - ret._is_corr = jnp.array(is_corr) + ret.setCorrelated(is_corr) return ret def tree_flatten(self): @@ -400,10 +398,28 @@ def allocateTimes(self): def isCorrelated(self): """Returns whether the photons are correlated""" + from .deprecated import depr + + depr( + "isCorrelated", + 2.5, + "", + "We don't think this is necessary anymore. If you have a use case that " + "requires it, please open an issue.", + ) return self._is_corr def setCorrelated(self, is_corr=True): """Set whether the photons are correlated""" + from .deprecated import depr + + depr( + "setCorrelated", + 2.5, + "", + "We don't think this is necessary anymore. If you have a use case that " + "requires it, please open an issue.", + ) self._is_corr = jnp.array(is_corr, dtype=bool) def getTotalFlux(self): @@ -459,6 +475,13 @@ def _sort_by_nokeep(self): def assignAt(self, istart, rhs): """Assign the contents of another `PhotonArray` to this one starting at istart.""" + from .deprecated import depr + + depr( + "PhotonArray.assignAt", + 2.5, + "copyFrom(rhs, slice(istart, istart+rhs.size()))", + ) if istart + rhs.size() > self.size(): raise GalSimValueError( "The given rhs does not fit into this array starting at %d" % istart, @@ -477,6 +500,81 @@ def assignAt(self, istart, rhs): return self._sort_by_nokeep() + @_wraps( + _galsim.PhotonArray.copyFrom, + lax_description="The JAX version of PhotonArray.copyFrom does not raise for out of bounds indices.", + ) + def copyFrom( + self, + rhs, + target_indices=slice(None), + source_indices=slice(None), + do_xy=True, + do_flux=True, + do_other=True, + ): + return self._copyFrom( + rhs, target_indices, source_indices, do_xy, do_flux, do_other + ) + + def _copyFrom( + self, + rhs, + target_indices, + source_indices, + do_xy=True, + do_flux=True, + do_other=True, + ): + """Equivalent to self.copyFrom(rhs, target_indices, source_indices), but without any + checks that the indices are valid. + """ + # Aliases for notational convenience. + s1 = target_indices + s2 = source_indices + + @jax.jit + def _cond_set_indices(arr1, arr2, cond_val): + return jax.lax.cond( + cond_val, + lambda arr1, arr2: arr1.at[s1].set(arr2.at[s2].get()), + lambda arr1, arr2: arr1, + arr1, + arr2, + ) + + if do_xy: + self._x = self._x.at[s1].set(rhs.x.at[s2].get()) + self._y = self._y.at[s1].set(rhs.y.at[s2].get()) + + if do_flux: + self._flux = self._flux.at[s1].set(rhs.flux.at[s2].get()) + + if do_other: + self._dxdz = _cond_set_indices( + self._dxdz, rhs.dxdz, rhs.hasAllocatedAngles() + ) + self._dydz = _cond_set_indices( + self._dydz, rhs.dydz, rhs.hasAllocatedAngles() + ) + self._wave = _cond_set_indices( + self._wave, rhs.wavelength, rhs.hasAllocatedWavelengths() + ) + self._pupil_u = _cond_set_indices( + self._pupil_u, rhs.pupil_u, rhs.hasAllocatedPupil() + ) + self._pupil_v = _cond_set_indices( + self._pupil_v, rhs.pupil_v, rhs.hasAllocatedPupil() + ) + self._time = _cond_set_indices( + self._time, rhs.time, rhs.hasAllocatedTimes() + ) + + if do_xy or do_flux or do_other: + self._nokeep = self._nokeep.at[s1].set(rhs._nokeep.at[s2].get()) + + return self._sort_by_nokeep() + def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): """Assign the contents of another `PhotonArray` to this one at locations where cat_ind == cat_ind_to_assign. diff --git a/tests/GalSim b/tests/GalSim index 7d6923e2..eecef976 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 7d6923e2ade53506ce972dd336e25f777ca4ba3c +Subproject commit eecef976143704e0e017aa74b50abd7967e1c520 diff --git a/tests/conftest.py b/tests/conftest.py index a4edcff9..f12f9758 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ import jax_galsim # noqa: E402 +os.environ["JAX_GALSIM_TESTING"] = "1" + # Identify the path to this current file test_directory = os.path.dirname(os.path.abspath(__file__)) From 68b2f7ecf399cfe4b677afad47f79bf0fc7c315a Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 13:10:38 -0600 Subject: [PATCH 39/85] REF get rid of warning --- jax_galsim/deprecated.py | 22 +++++++++++++++++++++- jax_galsim/photon_array.py | 6 ++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/jax_galsim/deprecated.py b/jax_galsim/deprecated.py index d537605a..62345a3e 100644 --- a/jax_galsim/deprecated.py +++ b/jax_galsim/deprecated.py @@ -1 +1,21 @@ -from galsim.deprecated import depr # noqa: F401 +import warnings + +import galsim as _galsim +from jax._src.numpy.util import _wraps + +from jax_galsim.errors import GalSimDeprecationWarning + + +@_wraps( + _galsim.deprecated.depr, + lax_description="""\ +The JAX version of this function uses `stacklevel=3` to show where the +warning is generated.""", +) +def depr(f, v, s1, s2=None): + s = str(f) + " has been deprecated since GalSim version " + str(v) + "." + if s1: + s += " Use " + s1 + " instead." + if s2: + s += " " + s2 + warnings.warn(s, GalSimDeprecationWarning, stacklevel=3) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 041c30b7..1d6b136a 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -617,7 +617,7 @@ def convolve(self, rhs, rng=None): nrsinds = jnp.arange(self.size()) sinds = jax.lax.cond( - jnp.array(self.isCorrelated()) & jnp.array(rhs.isCorrelated()), + self._is_corr & rhs._is_corr, lambda nrsinds, rsinds: rsinds.at[ jnp.argsort(rhs._nokeep.at[rsinds].get()) ].get(), @@ -678,9 +678,7 @@ def convolve(self, rhs, rng=None): sinds, ) - self.setCorrelated( - jnp.array(self.isCorrelated()) | jnp.array(rhs.isCorrelated()) - ) + self._is_corr = self._is_corr | rhs._is_corr self._x = self._x + rhs._x.at[sinds].get() self._y = self._y + rhs._y.at[sinds].get() From 9b9f030d9f2096bd20514611cc653b2c456d2ae7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 13:26:44 -0600 Subject: [PATCH 40/85] TST new test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index eecef976..fc4e74fe 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit eecef976143704e0e017aa74b50abd7967e1c520 +Subproject commit fc4e74fe3beb18ec9912eefd28768464973aa369 From 40167b9544669c7f751e4fa595cfd14d75eeaa0c Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 13:35:40 -0600 Subject: [PATCH 41/85] TST new test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index fc4e74fe..75dfbcbd 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit fc4e74fe3beb18ec9912eefd28768464973aa369 +Subproject commit 75dfbcbd4b269e71fd25beb5195f6b84e8090e64 From b222992981ff3f16dbae08c66902fcae03d2706d Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 14:58:53 -0600 Subject: [PATCH 42/85] TST new test suite and better deprecation warnings --- jax_galsim/wcs.py | 11 +++++++++-- tests/GalSim | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index b95d0ad9..c55b3e2e 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -147,11 +147,11 @@ def affine(self, image_pos=None, world_pos=None, color=None): image_pos = PositionD(0, 0) if self._isCelestial: - return jac.withOrigin(image_pos) + return jac.shiftOrigin(image_pos) else: if world_pos is None: world_pos = self.toWorld(image_pos, color=color) - return jac.withOrigin(image_pos, world_pos, color=color) + return jac.shiftOrigin(image_pos, world_pos, color=color) @_wraps(_galsim.BaseWCS.shiftOrigin) def shiftOrigin(self, origin, world_origin=None, color=None): @@ -161,6 +161,13 @@ def shiftOrigin(self, origin, world_origin=None, color=None): raise TypeError("origin must be a PositionD or PositionI argument") return self._shiftOrigin(origin, world_origin, color) + @_wraps(_galsim.BaseWCS.withOrigin) + def withOrigin(self, origin, world_origin=None, color=None): + from .deprecated import depr + + depr("withOrigin", 2.3, "shiftOrigin") + return self.shiftOrigin(origin, world_origin, color) + # A lot of classes will need these checks, so consolidate them here def _set_origin(self, origin, world_origin=None): if origin is None: diff --git a/tests/GalSim b/tests/GalSim index 75dfbcbd..e003b5aa 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 75dfbcbd4b269e71fd25beb5195f6b84e8090e64 +Subproject commit e003b5aa2b00629cba7e6c48758475bcece872bb From a87846ed2d6eeb8990e12f953c8beded9f5770a4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 17:50:06 -0600 Subject: [PATCH 43/85] TST new test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index e003b5aa..5c716c8a 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit e003b5aa2b00629cba7e6c48758475bcece872bb +Subproject commit 5c716c8accd5172a52b50ce3ef170902a9078196 From fd00e94f59349ca5ed2188020344e66513dd4c48 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Nov 2023 17:55:14 -0600 Subject: [PATCH 44/85] TST update test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 5c716c8a..88103b53 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5c716c8accd5172a52b50ce3ef170902a9078196 +Subproject commit 88103b53dcdc60606775e7199da9ec9b1849f49a From 05ece3a92961f71ddd83d73914c3eef54b2c8489 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 00:18:55 -0600 Subject: [PATCH 45/85] REF put bounds back --- jax_galsim/bounds.py | 394 +++++++++----------------------- jax_galsim/gsobject.py | 22 +- jax_galsim/image.py | 53 ++--- jax_galsim/interpolatedimage.py | 2 +- jax_galsim/photon_array.py | 2 +- jax_galsim/sensor.py | 2 +- 6 files changed, 140 insertions(+), 335 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index f89a495f..3f99f34d 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,37 +1,24 @@ import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ( - cast_to_float, - cast_to_int, - ensure_hashable, - has_tracers, -) -from jax_galsim.errors import GalSimUndefinedBoundsError +from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable from jax_galsim.position import Position, PositionD, PositionI +BOUNDS_LAX_DESCR = """\ +The JAX implementation + + - will not always test whether the bounds are valid + - will not always test whether BoundsI is initialized with integers +""" + # The reason for avoid these tests is that they are not easy to do for jitted code. -@_wraps( - _galsim.Bounds, - lax_description="""\ -"The JAX implementation of galsim.Bounds - - - will not always test for properly defined bounds, especially in jitted code - - will not test whether BoundsI is indeed initialized with integers during vmap/jit/grad transforms -""", -) +@_wraps(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class -class Bounds(object): - def __init__(self): - raise NotImplementedError( - "Cannot instantiate the base class. " "Use either BoundsD or BoundsI." - ) - +class Bounds(_galsim.Bounds): def _parse_args(self, *args, **kwargs): if len(kwargs) == 0: if len(args) == 4: @@ -39,7 +26,7 @@ def _parse_args(self, *args, **kwargs): self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: self._isdefined = False - self.xmin = self.xmax = self.ymin = self.ymax = jnp.nan + self.xmin = self.xmax = self.ymin = self.ymax = 0 elif len(args) == 1: if isinstance(args[0], Bounds): self._isdefined = True @@ -94,109 +81,57 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - @_wraps(_galsim.Bounds.isDefined) - def isDefined(self, _static=False): - if _static: - return ( - self._isdefined - and np.all(self.xmin <= self.xmax) - and np.all(self.ymin <= self.ymax) - ) - else: - return ( - jnp.isfinite(self.xmin) - & jnp.isfinite(self.xmax) - & jnp.isfinite(self.ymin) - & jnp.isfinite(self.ymax) - & (self.xmin <= self.xmax) - & (self.ymin <= self.ymax) - ) - - def area(self): - """Return the area of the enclosed region. - - The area is a bit different for integer-type `BoundsI` and float-type `BoundsD` instances. - For floating point types, it is simply ``(xmax-xmin)*(ymax-ymin)``. However, for integer - types, we add 1 to each size to correctly count the number of pixels being described by the - bounding box. - """ - return self._area() - - def withBorder(self, dx, dy=None): - """Return a new `Bounds` object that expands the current bounds by the specified width. - - If two arguments are given, then these are separate dx and dy borders. - """ - self._check_scalar(dx, "dx") - if dy is None: - dy = dx - else: - self._check_scalar(dy, "dy") - return self.__class__( - self.xmin - dx, self.xmax + dx, self.ymin - dy, self.ymax + dy - ) + # for simple inputs, we can check if the bounds are valid + if ( + isinstance(self.xmin, (float, int)) + and isinstance(self.xmax, (float, int)) + and isinstance(self.ymin, (float, int)) + and isinstance(self.ymax, (float, int)) + and ((self.xmin > self.xmax) or (self.ymin > self.ymax)) + ): + self._isdefined = False @property - def origin(self): - "The lower left position of the `Bounds`." - return self._pos_class(self.xmin, self.ymin) - - @property - @_wraps( - _galsim.Bounds.center, - lax_description="The JAX implementation of galsim.Bounds.center does not raise for undefined bounds.", - ) - def center(self): - if not self.isDefined(_static=True): - raise GalSimUndefinedBoundsError( - "center is invalid for an undefined Bounds" - ) - return self._center - - @property - @_wraps( - _galsim.Bounds.true_center, - lax_description="The JAX implementation of galsim.Bounds.true_center does not raise for undefined bounds.", - ) def true_center(self): - if not self.isDefined(_static=True): - raise GalSimUndefinedBoundsError( + """The central position of the `Bounds` as a `PositionD`. + + This is always (xmax + xmin)/2., (ymax + ymin)/2., even for integer `BoundsI`, where + this may not necessarily be an integer `PositionI`. + """ + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( "true_center is invalid for an undefined Bounds" ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) @_wraps(_galsim.Bounds.includes) - def includes(self, *args, _static=False): + def includes(self, *args): if len(args) == 1: if isinstance(args[0], Bounds): b = args[0] return ( - self.isDefined(_static=_static) - & b.isDefined(_static=_static) - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) + self.isDefined() + and b.isDefined() + and self.xmin <= b.xmin + and self.xmax >= b.xmax + and self.ymin <= b.ymin + and self.ymax >= b.ymax ) elif isinstance(args[0], Position): p = args[0] return ( - self.isDefined(_static=_static) - & (self.xmin <= p.x) - & (p.x <= self.xmax) - & (self.ymin <= p.y) - & (p.y <= self.ymax) + self.isDefined() + and self.xmin <= p.x <= self.xmax + and self.ymin <= p.y <= self.ymax ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args return ( - self.isDefined(_static=_static) - & (self.xmin <= x) - & (x <= self.xmax) - & (self.ymin <= y) - & (y <= self.ymax) + self.isDefined() + and self.xmin <= float(x) <= self.xmax + and self.ymin <= float(y) <= self.ymax ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -214,108 +149,42 @@ def expand(self, factor_x, factor_y=None): dy = jnp.ceil(dy) return self.withBorder(dx, dy) - def getXMin(self): - "Get the value of xmin." - return self.xmin - - def getXMax(self): - "Get the value of xmax." - return self.xmax - - def getYMin(self): - "Get the value of ymin." - return self.ymin - - def getYMax(self): - "Get the value of ymax." - return self.ymax - - def shift(self, delta): - """Shift the `Bounds` instance by a supplied `Position`. - - Examples: - - The shift method takes either a `PositionI` or `PositionD` instance, which must match - the type of the `Bounds` instance:: - - >>> bounds = BoundsI(1,32,1,32) - >>> bounds = bounds.shift(galsim.PositionI(3, 2)) - >>> bounds = BoundsD(0, 37.4, 0, 49.9) - >>> bounds = bounds.shift(galsim.PositionD(3.9, 2.1)) - """ - if not isinstance(delta, self._pos_class): - raise TypeError("delta must be a %s instance" % self._pos_class) - return self.__class__( - self.xmin + delta.x, - self.xmax + delta.x, - self.ymin + delta.y, - self.ymax + delta.y, - ) - def __and__(self, other): if not isinstance(other, self.__class__): raise TypeError("other must be a %s instance" % self.__class__.__name__) - # NaNs always propagate, so if either is undefined, the result is undefined - return self.__class__( - jnp.maximum(self.xmin, other.xmin), - jnp.minimum(self.xmax, other.xmax), - jnp.maximum(self.ymin, other.ymin), - jnp.minimum(self.ymax, other.ymax), - ) + if not self.isDefined() or not other.isDefined(): + return self.__class__() + else: + xmin = jnp.maximum(self.xmin, other.xmin) + xmax = jnp.minimum(self.xmax, other.xmax) + ymin = jnp.maximum(self.ymin, other.ymin) + ymax = jnp.minimum(self.ymax, other.ymax) + if xmin > xmax or ymin > ymax: + return self.__class__() + else: + return self.__class__(xmin, xmax, ymin, ymax) def __add__(self, other): if isinstance(other, self.__class__): - # galsim logic is - # if not other.isDefined(): - # return self - # elif self.isDefined(): - # xmin = jnp.minimum(self.xmin, other.xmin) - # xmax = jnp.maximum(self.xmax, other.xmax) - # ymin = jnp.minimum(self.ymin, other.ymin) - # ymax = jnp.maximum(self.ymax, other.ymax) - # return self.__class__(xmin, xmax, ymin, ymax) - # else: - # return other - return self.__class__( - jax.lax.cond( - ~jnp.any(other.isDefined()), - lambda: BoundsD(self), - lambda: BoundsD( - jax.lax.cond( - jnp.any(self.isDefined()), - lambda: BoundsD( - jnp.minimum(self.xmin, other.xmin), - jnp.maximum(self.xmax, other.xmax), - jnp.minimum(self.ymin, other.ymin), - jnp.maximum(self.ymax, other.ymax), - ), - lambda: BoundsD(other), - ) - ), - ) - ) + if not other.isDefined(): + return self + elif self.isDefined(): + xmin = jnp.minimum(self.xmin, other.xmin) + xmax = jnp.maximum(self.xmax, other.xmax) + ymin = jnp.minimum(self.ymin, other.ymin) + ymax = jnp.maximum(self.ymax, other.ymax) + return self.__class__(xmin, xmax, ymin, ymax) + else: + return other elif isinstance(other, self._pos_class): - # the galsim logic is - # if self.isDefined(): - # xmin = jnp.minimum(self.xmin, other.x) - # xmax = jnp.maximum(self.xmax, other.x) - # ymin = jnp.minimum(self.ymin, other.y) - # ymax = jnp.maximum(self.ymax, other.y) - # return self.__class__(xmin, xmax, ymin, ymax) - # else: - # return self.__class__(other) - return self.__class__( - jax.lax.cond( - jnp.any(self.isDefined()), - lambda: BoundsD( - jnp.minimum(self.xmin, other.x), - jnp.maximum(self.xmax, other.x), - jnp.minimum(self.ymin, other.y), - jnp.maximum(self.ymax, other.y), - ), - lambda: BoundsD(other), - ) - ) + if self.isDefined(): + xmin = jnp.minimum(self.xmin, other.x) + xmax = jnp.maximum(self.xmax, other.x) + ymin = jnp.minimum(self.ymin, other.y) + ymax = jnp.maximum(self.ymax, other.y) + return self.__class__(xmin, xmax, ymin, ymax) + else: + return self.__class__(other) else: raise TypeError( "other must be either a %s or a %s" @@ -323,7 +192,7 @@ def __add__(self, other): ) def __repr__(self): - if self.isDefined(_static=True): + if self.isDefined(): return "galsim.%s(xmin=%r, xmax=%r, ymin=%r, ymax=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -335,7 +204,7 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(_static=True): + if self.isDefined(): return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -357,37 +226,14 @@ def __hash__(self): ) ) - def _getinitargs(self): - if self.isDefined(_static=True): - return (self.xmin, self.xmax, self.ymin, self.ymax) - else: - return () - - def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and ( - ( - np.array_equal(self.xmin, other.xmin, equal_nan=True) - and np.array_equal(self.xmax, other.xmax, equal_nan=True) - and np.array_equal(self.ymin, other.ymin, equal_nan=True) - and np.array_equal(self.ymax, other.ymax, equal_nan=True) - ) - or ( - (not self.isDefined(_static=True)) - and (not other.isDefined(_static=True)) - ) - ) - ) - - def __ne__(self, other): - return not self.__eq__(other) - def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.xmin, self.xmax, self.ymin, self.ymax) + if self.isDefined(): + children = (self.xmin, self.xmax, self.ymin, self.ymax) + else: + children = tuple() # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data) @@ -409,21 +255,18 @@ def from_galsim(cls, galsim_bounds): "galsim_bounds must be either a %s or a %s" % (_galsim.BoundsD.__name__, _galsim.BoundsI.__name__) ) - if not galsim_bounds.isDefined(): - return _cls() - else: + if galsim_bounds.isDefined(): return _cls( galsim_bounds.xmin, galsim_bounds.xmax, galsim_bounds.ymin, galsim_bounds.ymax, ) + else: + return _cls() -@_wraps( - _galsim.BoundsD, - lax_description="The JAX implementation of galsim.BoundsD does not always check for float values.", -) +@_wraps(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsD(Bounds): _pos_class = PositionD @@ -450,47 +293,33 @@ def _check_scalar(self, x, name): raise TypeError("%s must be a float value" % name) def _area(self): - return jax.lax.cond( - jnp.any(self.isDefined()), - lambda xmin, xmax, ymin, ymax: (xmax - xmin) * (ymax - ymin), - lambda xmin, xmax, ymin, ymax: jnp.zeros_like(xmin), - self.xmin, - self.xmax, - self.ymin, - self.ymax, - ) + return (self.xmax - self.xmin) * (self.ymax - self.ymin) @property def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) -@_wraps( - _galsim.BoundsI, - lax_description="The JAX implementation of galsim.BoundsI does not always check for integer values.", -) +@_wraps(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - - # best-effort error checking - raise_notint = False - try: - bnds = (self.xmin, self.xmax, self.ymin, self.ymax) - if not has_tracers(bnds) and np.all(np.isfinite(bnds)) & np.any( - (self.xmin != np.floor(self.xmin)) - | (self.xmax != np.floor(self.xmax)) - | (self.ymin != np.floor(self.ymin)) - | (self.ymax != np.floor(self.ymax)) - ): - raise_notint = True - except Exception: - pass - - if raise_notint: + # for simple inputs, we can check if the bounds are valid ints + if ( + isinstance(self.xmin, (float, int)) + and isinstance(self.xmax, (float, int)) + and isinstance(self.ymin, (float, int)) + and isinstance(self.ymax, (float, int)) + and ( + self.xmin != int(self.xmin) + or self.xmax != int(self.xmax) + or self.ymin != int(self.ymin) + or self.ymax != int(self.ymax) + ) + ): raise TypeError("BoundsI must be initialized with integer values") self.xmin = cast_to_int(self.xmin) @@ -512,38 +341,19 @@ def _check_scalar(self, x, name): pass raise TypeError("%s must be an integer value" % name) - def numpyShape(self, _static=False): + def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if _static: - if self.isDefined(_static=True): - return (self.ymax - self.ymin + 1, self.xmax - self.xmin + 1) - else: - return (0, 0) + if self.isDefined(): + return self.ymax - self.ymin + 1, self.xmax - self.xmin + 1 else: - return jax.lax.cond( - jnp.any(self.isDefined()), - lambda xmin, xmax, ymin, ymax: (ymax - ymin + 1, xmax - xmin + 1), - lambda xmin, xmax, ymin, ymax: ( - jnp.zeros_like(xmin), - jnp.zeros_like(xmin), - ), - self.xmin, - self.xmax, - self.ymin, - self.ymax, - ) + return 0, 0 def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - return jax.lax.cond( - jnp.any(self.isDefined()), - lambda xmin, xmax, ymin, ymax: (xmax - xmin + 1) * (ymax - ymin + 1), - lambda xmin, xmax, ymin, ymax: jnp.zeros_like(xmin), - self.xmin, - self.xmax, - self.ymin, - self.ymax, - ) + if not self.isDefined(): + return 0 + else: + return (self.xmax - self.xmin + 1) * (self.ymax - self.ymin + 1) @property def _center(self): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 6dafc034..624abefb 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -445,7 +445,7 @@ def _setup_image( ) # Resize the given image if necessary - if not image.bounds.isDefined(_static=True): + if not image.bounds.isDefined(): # Can't add to image if need to resize if add_to_image: raise _galsim.GalSimIncompatibleValuesError( @@ -478,7 +478,7 @@ def _setup_image( ny=ny, bounds=bounds, ) - if not bounds.isDefined(_static=True): + if not bounds.isDefined(): raise _galsim.GalSimValueError( "Cannot use undefined bounds", bounds ) @@ -515,7 +515,7 @@ def _local_wcs(self, wcs, image, offset, center, use_true_center, new_bounds): bounds = new_bounds else: bounds = image.bounds - if not bounds.isDefined(_static=True): + if not bounds.isDefined(): raise _galsim.GalSimIncompatibleValuesError( "Cannot provide non-local wcs with automatically sized image", wcs=wcs, @@ -556,7 +556,7 @@ def _parse_center(self, center): def _get_new_bounds(self, image, nx, ny, bounds, center): from jax_galsim.bounds import BoundsI - if image is not None and image.bounds.isDefined(_static=True): + if image is not None and image.bounds.isDefined(): return image.bounds elif nx is not None and ny is not None: b = BoundsI(1, nx, 1, ny) @@ -568,7 +568,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center): ) ) return b - elif bounds is not None and bounds.isDefined(_static=True): + elif bounds is not None and bounds.isDefined(): return bounds else: return BoundsI() @@ -576,7 +576,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center): def _adjust_offset(self, new_bounds, offset, center, use_true_center): # Note: this assumes self is in terms of image coordinates. if center is not None: - if new_bounds.isDefined(_static=True): + if new_bounds.isDefined(): offset += center - new_bounds.center else: # Then will be created as even sized image. @@ -590,7 +590,7 @@ def _adjust_offset(self, new_bounds, offset, center, use_true_center): # Also, remember that numpy's shape is ordered as [y,x] dx = offset.x dy = offset.y - shape = new_bounds.numpyShape(_static=True) + shape = new_bounds.numpyShape() dx -= 0.5 * ((shape[1] + 1) % 2) dy -= 0.5 * ((shape[0] + 1) % 2) @@ -918,7 +918,7 @@ def drawFFT_makeKImage(self, image): jnp.array( [ jnp.max(jnp.abs(jnp.array(image.bounds._getinitargs()))) * 2, - jnp.max(jnp.array(image.bounds.numpyShape(_static=True))), + jnp.max(jnp.array(image.bounds.numpyShape())), ] ) ) @@ -984,7 +984,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image): ) kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,)) real_image_arr = jnp.fft.fftshift( - jnp.fft.irfft2(kimg_shift, breal.numpyShape(_static=True)) + jnp.fft.irfft2(kimg_shift, breal.numpyShape()) ) real_image = Image( bounds=breal, array=real_image_arr, dtype=image.dtype, wcs=image.wcs @@ -1069,7 +1069,7 @@ def drawKImage( dk = self.stepk else: dk = scale - if image is not None and image.bounds.isDefined(_static=True): + if image is not None and image.bounds.isDefined(): dx = np.pi / (max(image.array.shape) // 2 * dk) elif scale is None or scale <= 0: dx = self.nyquist_scale @@ -1081,7 +1081,7 @@ def drawKImage( # If the profile needs to be constructed from scratch, the _setup_image function will # do that, but only if the profile is in image coordinates for the real space image. # So make that profile. - if image is None or not image.bounds.isDefined(_static=True): + if image is None or not image.bounds.isDefined(): real_prof = PixelScale(dx).profileToImage(self) dtype = np.complex128 if image is None else image.dtype image = real_prof._setup_image( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index d249fbf6..cfdad1a5 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -188,9 +188,7 @@ def __init__(self, *args, **kwargs): elif bounds is not None: if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - self._array = self._make_empty( - bounds.numpyShape(_static=True), dtype=self._dtype - ) + self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype) self._bounds = bounds if init_value: self._array = self._array + init_value @@ -263,7 +261,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): raise TypeError("bounds must be a galsim.BoundsI instance") # we use the static bounds check here since we cannot raise errors in jitted # code anyways - if check_bounds and b.isDefined(_static=True): + if check_bounds and b.isDefined(): # We need to disable this when jitting if b.xmax - b.xmin + 1 != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -281,7 +279,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): # this statement in JAX would change array sizes and so is not supported # so instead we check the static property set on construction for whether # the bounds are defined - if b.isDefined(_static=True): + if b.isDefined(): xmin = b.xmin ymin = b.ymin else: @@ -302,7 +300,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined(_static=True): + if self.bounds.isDefined(): s += ", array=\n%r" % np.array(self.array) s += ", wcs=%r" % self.wcs if self.isconst: @@ -522,9 +520,7 @@ def resize(self, bounds, wcs=None): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - self._array = self._make_empty( - shape=bounds.numpyShape(_static=True), dtype=self.dtype - ) + self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: self.wcs = wcs @@ -536,11 +532,11 @@ def subImage(self, bounds): """ if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) - if not self.bounds.includes(bounds, _static=True): + if not self.bounds.includes(bounds): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) @@ -564,17 +560,17 @@ def setSubImage(self, bounds, rhs): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(bounds, _static=True): + if not self.bounds.includes(bounds): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") - if bounds.numpyShape(_static=True) != rhs.bounds.numpyShape(_static=True): + if bounds.numpyShape() != rhs.bounds.numpyShape(): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, @@ -726,7 +722,7 @@ def _wrap(self, bounds, hermx, hermy): lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.", ) def calculate_fft(self): - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) @@ -778,7 +774,7 @@ def calculate_fft(self): @_wraps(_galsim.Image.calculate_inverse_fft) def calculate_inverse_fft(self): - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) @@ -790,7 +786,7 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if not self.bounds.includes(0, 0, _static=True): + if not self.bounds.includes(0, 0): raise _galsim.GalSimBoundsError( "calculate_inverse_fft requires that the image includes (0,0)", PositionI(0, 0), @@ -861,7 +857,7 @@ def copyFrom(self, rhs): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") - if self.bounds.numpyShape(_static=True) != rhs.bounds.numpyShape(_static=True): + if self.bounds.numpyShape() != rhs.bounds.numpyShape(): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, @@ -911,7 +907,7 @@ def view( # If currently empty, just return a new empty image. # we use the static bounds check set at construction # since the dynamic one in JAX would change array shape - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): return Image(wcs=wcs, dtype=dtype, make_const=make_const) # Recast the array type if necessary @@ -985,11 +981,11 @@ def __call__(self, *args, **kwargs): @_wraps(_galsim.Image.getValue) def getValue(self, x, y): - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(x, y, _static=True): + if not self.bounds.includes(x, y): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), @@ -1007,14 +1003,14 @@ def _getValue(self, x, y): def setValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos, _static=True): + if not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) @@ -1035,14 +1031,14 @@ def _setValue(self, x, y, value): def addValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos, _static=True): + if not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) @@ -1067,7 +1063,7 @@ def fill(self, value): """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) @@ -1091,7 +1087,7 @@ def invertSelf(self): """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) - if not self.bounds.isDefined(_static=True): + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) @@ -1130,8 +1126,7 @@ def __eq__(self, other): and self.bounds == other.bounds and self.wcs == other.wcs and ( - not self.bounds.isDefined(_static=True) - or jnp.array_equal(self.array, other.array) + not self.bounds.isDefined() or jnp.array_equal(self.array, other.array) ) and self.isconst == other.isconst ) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3be49772..3215daf0 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -427,7 +427,7 @@ def __init__( ) # it must have well-defined bounds, otherwise seg fault in SBInterpolatedImage constructor - if not image.bounds.isDefined(_static=True): + if not image.bounds.isDefined(): raise GalSimUndefinedBoundsError( "Supplied image does not have bounds defined." ) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 1d6b136a..1a2b7f5a 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -741,7 +741,7 @@ def __ne__(self, other): lax_description="The JAX equivalent of galsim.PhotonArray.addTo may not raise for undefined bounds.", ) def addTo(self, image): - if not image.bounds.isDefined(_static=True): + if not image.bounds.isDefined(): raise GalSimUndefinedBoundsError( "Attempting to PhotonArray::addTo an Image with undefined Bounds" ) diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 27f3552e..f588f1d2 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -17,7 +17,7 @@ def __init__(self): lax_description="The JAX equivalent of galsim.Sensor.accumulate does not raise for undefined bounds.", ) def accumulate(self, photons, image, orig_center=None, resume=False): - if not image.bounds.isDefined(_static=True): + if not image.bounds.isDefined(): raise GalSimUndefinedBoundsError( "Calling accumulate on image with undefined bounds" ) From f4573c6319f087814fa7d77ed4a3f71973f48066 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 00:27:11 -0600 Subject: [PATCH 46/85] REF centralize float and int casts --- jax_galsim/angle.py | 4 +--- jax_galsim/core/utils.py | 35 ++++++++++++++++++++++++++++------- jax_galsim/position.py | 4 ++-- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 8bfdc8af..cbb04c54 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -36,9 +36,7 @@ def __init__(self, value): """ if isinstance(value, AngleUnit): raise TypeError("Cannot construct AngleUnit from another AngleUnit") - self._value = 1.0 * cast_to_float( - value - ) # this will cause an exception if things are not numeric + self._value = cast_to_float(value) @property def value(self): diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index db0d7fe2..6dd252b7 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -54,11 +54,17 @@ def cast_to_python_float(x): try: return float(x) except TypeError: - return x + # this will return the same value for anything float-like that + # cannot be cast to float + # however, it will raise an error if something is not float-like + return 1.0 * x except ValueError as e: # we let NaNs through if " NaN " in str(e): - return x + # this will return the same value for anything float-like that + # cannot be cast to float + # however, it will raise an error if something is not float-like + return 1.0 * x else: raise e @@ -72,11 +78,17 @@ def cast_to_python_int(x): try: return int(x) except TypeError: - return x + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + return 1 * x except ValueError as e: # we let NaNs through if " NaN " in str(e): - return x + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + return 1 * x else: raise e @@ -89,7 +101,10 @@ def cast_to_float(x): try: return jnp.asarray(x, dtype=float) except Exception: - return x + # this will return the same value for anything float-like that + # cannot be cast to float + # however, it will raise an error if something is not float-like + return 1.0 * x def cast_to_int(x): @@ -101,9 +116,15 @@ def cast_to_int(x): if not jnp.any(jnp.isnan(x)): return jnp.asarray(x, dtype=int) else: - return x + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + return 1 * x except Exception: - return x + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + return 1 * x def is_equal_with_arrays(x, y): diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 01b264f5..8700d764 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -181,8 +181,8 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # Force conversion to float type in this case - self.x = 1.0 * cast_to_float(self.x) - self.y = 1.0 * cast_to_float(self.y) + self.x = cast_to_float(self.x) + self.y = cast_to_float(self.y) def _check_scalar(self, other, op): try: From f562a135ea630b81b83e179e5708a84c63fd187f Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 00:29:22 -0600 Subject: [PATCH 47/85] REF remove dead code --- jax_galsim/core/utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 6dd252b7..44f47271 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -7,14 +7,9 @@ def has_tracers(x): - """Return True if the input item is a JAX tracer, False otherwise.""" + """Return True if the input item is a JAX tracer or object, False otherwise.""" for item in tree_flatten(x)[0]: - if ( - isinstance(item, jax.core.Tracer) - or type(item) is object - # or isinstance(item, jax.core.ShapedArray) - # or isinstance(item, str) - ): + if isinstance(item, jax.core.Tracer) or type(item) is object: return True return False From 6c74671a903b7677ef57741f7152f4809057961e Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 00:31:15 -0600 Subject: [PATCH 48/85] DOC fix typo in doc string --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index cfdad1a5..0230e49f 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -55,7 +55,7 @@ class Image(object): valid_dtypes = _valid_dtypes def __init__(self, *args, **kwargs): - # this one is pecific to jax-galsim and is used to disable bounds checking + # this one is specific to jax-galsim and is used to disable bounds checking _check_bounds = kwargs.pop("_check_bounds", True) # Parse the args, kwargs From 785b02697219b681c9d1f945c29bf65aaafec5f9 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 00:34:06 -0600 Subject: [PATCH 49/85] Apply suggestions from code review --- jax_galsim/image.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 0230e49f..54c5d228 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -259,8 +259,6 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # we use the static bounds check here since we cannot raise errors in jitted - # code anyways if check_bounds and b.isDefined(): # We need to disable this when jitting if b.xmax - b.xmin + 1 != array.shape[1]: @@ -276,9 +274,6 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): bounds=b, ) - # this statement in JAX would change array sizes and so is not supported - # so instead we check the static property set on construction for whether - # the bounds are defined if b.isDefined(): xmin = b.xmin ymin = b.ymin @@ -867,12 +862,7 @@ def copyFrom(self, rhs): @_wraps( _galsim.Image.view, - lax_description="""\ -Contrary to GalSim, view - - - will create a copy of the orginal image - - will not check for undefined bounds -""", + lax_description="Contrary to GalSim, this will create a copy of the orginal image.", ) def view( self, From 1bc758c7af248db8282f8c8c385a701ccf1d04a6 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 00:35:42 -0600 Subject: [PATCH 50/85] Apply suggestions from code review --- jax_galsim/image.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 54c5d228..a63c6450 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -895,8 +895,6 @@ def view( dtype = dtype if dtype else self.dtype # If currently empty, just return a new empty image. - # we use the static bounds check set at construction - # since the dynamic one in JAX would change array shape if not self.bounds.isDefined(): return Image(wcs=wcs, dtype=dtype, make_const=make_const) @@ -1149,11 +1147,6 @@ def rot_ccw(self): @_wraps(_galsim.Image.rot_180) def rot_180(self): - """Return a version of the image rotated 180 degrees. - - Note: The returned image will have an undefined wcs. - If you care about the wcs, you will need to set it yourself. - """ return _Image(self.array.at[::-1, ::-1].get(), self._bounds, None) def tree_flatten(self): From a2eddb2e7e8cecb68223c03f862182a084e0f7a2 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 00:38:29 -0600 Subject: [PATCH 51/85] Apply suggestions from code review --- jax_galsim/interpolatedimage.py | 3 +-- jax_galsim/position.py | 4 ++-- jax_galsim/sensor.py | 5 +---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3215daf0..f213e0ea 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -60,8 +60,7 @@ def __dir__(cls): - noise padding - the pad_image options - depixelize - - most of the type checks and dtype casts done by galsim - - the image bounds are defined + - most of the bounds checks, type checks, and dtype casts done by galsim """ ), ) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 8700d764..f89d9c4f 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -206,8 +206,8 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # inputs must be ints - self.x = 1 * cast_to_int(self.x) - self.y = 1 * cast_to_int(self.y) + self.x = cast_to_int(self.x) + self.y = cast_to_int(self.y) def _check_scalar(self, other, op): try: diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index f588f1d2..80ffe1fa 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -12,10 +12,7 @@ class Sensor: def __init__(self): pass - @_wraps( - _galsim.Sensor.accumulate, - lax_description="The JAX equivalent of galsim.Sensor.accumulate does not raise for undefined bounds.", - ) + @_wraps(_galsim.Sensor.accumulate) def accumulate(self, photons, image, orig_center=None, resume=False): if not image.bounds.isDefined(): raise GalSimUndefinedBoundsError( From 17892b807ade03d1f9997e2549c8a0a792017cca Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 00:42:03 -0600 Subject: [PATCH 52/85] REF make sure tracing does not raise --- jax_galsim/core/utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 44f47271..240a20f5 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -99,7 +99,11 @@ def cast_to_float(x): # this will return the same value for anything float-like that # cannot be cast to float # however, it will raise an error if something is not float-like - return 1.0 * x + # we exclude object types since they are used in JAX tracing + if type(x) is object: + return x + else: + return 1.0 * x def cast_to_int(x): @@ -114,12 +118,18 @@ def cast_to_int(x): # this will return the same value for anything int-like that # cannot be cast to int # however, it will raise an error if something is not int-like - return 1 * x + if type(x) is object: + return x + else: + return 1 * x except Exception: # this will return the same value for anything int-like that # cannot be cast to int # however, it will raise an error if something is not int-like - return 1 * x + if type(x) is object: + return x + else: + return 1 * x def is_equal_with_arrays(x, y): From 093ef3f1cd909e4659629a0894d50f3fe6f3f261 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 00:46:50 -0600 Subject: [PATCH 53/85] Apply suggestions from code review --- jax_galsim/exponential.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 6743550b..584960c1 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -154,14 +154,14 @@ def _shoot_cdf(self): # # We are looking to draw from a distribution that is r * exp(-r). # This distribution is the radial PDF of an Exponential profile. - # The fact of r comes from the area element r * dr. + # The factor of r comes from the area element r * dr. # # We can compute the CDF of this distribution analytically, but we cannot # invert the CDF in closed form. Thus we invert it numerically using a table. # # One final detail is that we want the inversion to be accurate and are using # linear interpolation. Thus we use a change of variables r = -ln(1 - u) - # to make the CDF more linear. + # to make the CDF more linear and map it's domain to [0, 1) instead of [0, inf). # # Putting this all together, we get # From 86bee6202122d62c46277caaf1923ea35e0d9e47 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 06:49:35 -0600 Subject: [PATCH 54/85] Apply suggestions from code review --- jax_galsim/interpolant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index fb220232..a5881cec 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -1460,7 +1460,7 @@ def __str__(self): # this is a pure function and we apply JIT ahead of time since this # one is pretty slow @jax.jit - def _xval(n, conserve_dc, _K, x): + def _xval(x, n, conserve_dc, _K): x = jnp.abs(x) def _low(x, n): @@ -1533,7 +1533,7 @@ def _no_dcval(val, x, n, _K): ) def _xval_noraise(self, x): - return Lanczos._xval(self._n, self._conserve_dc, self._K_arr, x) + return Lanczos._xval(x, self._n, self._conserve_dc, self._K_arr) def _raw_uval(u, n): # this function is used in the init and so was causing a recursion depth error From 44fd8da94eb5f83a54c100d62f9786df85c2a56e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 06:51:46 -0600 Subject: [PATCH 55/85] Apply suggestions from code review --- jax_galsim/interpolatedimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f213e0ea..12521dab 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -886,8 +886,8 @@ def _shoot(self, photons, rng): ) # accounnt for offset - we add the offset to get to - # image pixels in xValue - # here we generate photons from the image and thus + # image pixels in the xValue method + # here we generate photons from the image and # so we need to subtract it to get back to get to x as # it would be input in xVal photons.x -= self._offset.x From d88f994c211608e57bb2321c6a0e749f4677b76b Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 06:54:25 -0600 Subject: [PATCH 56/85] Apply suggestions from code review --- jax_galsim/sum.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index a5f9133b..855f8590 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -193,7 +193,8 @@ def _shoot(self, photons, rng): fluxes = jnp.array( [obj.positive_flux + obj.negative_flux for obj in self.obj_list] ) - # for a sum of objects, we use a slightly different approach than galsim + # for a sum of objects, we use a slightly different approach than galsim did + # as of version 2.5 # galsim uses a binomial distribution to compute the number of photons per object # we take an equivalent but different approach in order to use fixed size arrays # of photons. it means we draw more photons but the code is JIT compilable and a bit simpler From a2b4dd73f9a154fb194b9568508083c1ef695669 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 06:58:04 -0600 Subject: [PATCH 57/85] Apply suggestions from code review --- jax_galsim/utilities.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 1c9fedb6..4ddb4dee 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -18,10 +18,13 @@ "The LAX version of this decorator uses an `_workspace` attribute " "attached to the object so that the cache can easily be discarded " "for certain operations. It also will not cache jax.core.Tracer objects " - "in order to avoid side-effects in jit/grad/vmap transformations." + "in order to avoid side-effects in jit/grad/vmap transformations " + "unless `cache_jax_tracers=True` is given." ), ) def lazy_property(func_=None, cache_jax_tracers=False): + # the extra layer of indirection here allows the decorator to + # take keyword arguments and also be used without them. # see https://stackoverflow.com/a/57268935 def _decorator(func): attname = func.__name__ + "_cached" From 328df692db71cb6ea393d9abd03398f30b33d14b Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 07:04:49 -0600 Subject: [PATCH 58/85] Apply suggestions from code review --- tests/conftest.py | 4 ++++ tests/jax/test_api.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index f12f9758..c23ebd89 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,10 @@ import jax_galsim # noqa: E402 +# this environment variable is used in the +# JAX-specific modifications to the GalSim +# test suite to change tests where +# jax-galsim is not compatible. os.environ["JAX_GALSIM_TESTING"] = "1" # Identify the path to this current file diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 77c7df07..836ce720 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -940,7 +940,6 @@ def test_api_noise(): ], ) def test_api_pickling_eval_repr_basic(obj1): - """This test is here until we run all of the galsim tests which cover this one.""" # test copied from galsim import copy import pickle From 0d25200f987cf77621d2eba5d79badf7151b5b4a Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 07:05:04 -0600 Subject: [PATCH 59/85] TST comment out old code --- tests/jax/test_photon_shooting_jax.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index bb2dd3b2..1965c358 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -151,18 +151,18 @@ def test_photon_shooting_jax_offset(offset): ) # code for testing - if not np.allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol): - import proplot as pplt + # if not np.allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol): + # import proplot as pplt - fig, axs = pplt.subplots(nrows=1, ncols=3) - axs[0].imshow(img_fft.array, origin="lower") - axs[1].imshow(img_phot.array, origin="lower") - axs[2].imshow(img_fft.array - img_phot.array, origin="lower") - fig.show() + # fig, axs = pplt.subplots(nrows=1, ncols=3) + # axs[0].imshow(img_fft.array, origin="lower") + # axs[1].imshow(img_phot.array, origin="lower") + # axs[2].imshow(img_fft.array - img_phot.array, origin="lower") + # fig.show() - import pdb + # import pdb - pdb.set_trace() + # pdb.set_trace() np.testing.assert_almost_equal( jnp.argmax(img_fft.array), From 560add87c163910e162168378f33f0dabc076b0e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Nov 2023 09:37:23 -0600 Subject: [PATCH 60/85] Apply suggestions from code review --- tests/jax/test_photon_shooting_jax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 1965c358..3c6192c5 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -104,6 +104,7 @@ def test_photon_shooting_jax_offset(offset): flux_max = iobj.max_sb * 0.2**2 atol = flux_max * rtol * 3 nphot = int((flux_tot / flux_max / rtol**2).item()) + rtol *= 3 with time_code_block(): img_phot = iobj.drawImage( @@ -207,14 +208,12 @@ def _draw(hlr, fwhm, shift, flux, seed): img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) with time_code_block("one"): img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) - print(img.array.shape, img.bounds, img.array.sum(), fluxes[0]) _vmap_draw = jax.jit(jax.vmap(_draw, in_axes=(0, 0, 0, 0, 0))) with time_code_block("vmap warmup"): imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) with time_code_block("vmap"): imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) - print(imgs.array.shape) np.testing.assert_allclose(img.array.sum(), imgs.array[0].sum()) From 8ac318d0ed01f26a476c1d7b830cb961960165b8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 09:38:10 -0600 Subject: [PATCH 61/85] TST remove duplicate test --- jax_galsim/utilities.py | 2 +- tests/jax/test_image_wrapping.py | 191 +------------------------------ 2 files changed, 2 insertions(+), 191 deletions(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 4ddb4dee..95a7a985 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -23,7 +23,7 @@ ), ) def lazy_property(func_=None, cache_jax_tracers=False): - # the extra layer of indirection here allows the decorator to + # the extra layer of indirection here allows the decorator to # take keyword arguments and also be used without them. # see https://stackoverflow.com/a/57268935 def _decorator(func): diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index f1f7ada4..14bfe100 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -1,6 +1,6 @@ import jax import numpy as np -from galsim_test_helpers import assert_raises, timer +from galsim_test_helpers import timer import jax_galsim as galsim from jax_galsim.core.wrap_image import ( @@ -142,192 +142,3 @@ def _wrapit(im): p, grad = jax.vjp(_wrapit, im3) grad(p) jax.jvp(_wrapit, (im3,), (im3 * 2,)) - - -@timer -def test_wrap_jax_simple_real(): - """Test the image.wrap() function.""" - # Start with a fairly simple test where the image is 4 copies of the same data: - im_orig = galsim.Image( - [ - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - [11.0, 12.0, 13.0, 14.0, 11.0, 12.0, 13.0, 14.0], - [21.0, 22.0, 23.0, 24.0, 21.0, 22.0, 23.0, 24.0], - [31.0, 32.0, 33.0, 34.0, 31.0, 32.0, 33.0, 34.0], - [41.0, 42.0, 43.0, 44.0, 41.0, 42.0, 43.0, 44.0], - ] - ) - im = im_orig.copy() - b = galsim.BoundsI(1, 4, 1, 4) - im_quad = im_orig[b] - im_wrap = im.wrap(b) - np.testing.assert_allclose(im_wrap.array, 4.0 * im_quad.array) - - # The same thing should work no matter where the lower left corner is: - for xmin, ymin in ((1, 5), (5, 1), (5, 5), (2, 3), (4, 1)): - b = galsim.BoundsI(xmin, xmin + 3, ymin, ymin + 3) - im_quad = im_orig[b] - im = im_orig.copy() - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - 4.0 * im_quad.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return the right subimage" % b, - ) - # this test passes even though we do not get a view - im[b].fill(0) - np.testing.assert_allclose( - im_wrap.array, - im[b].array, - err_msg="image.wrap(%s) did not return a view of the original" % b, - ) - - -@timer -def test_wrap_jax_weird_real(): - # Now test where the subimage is not a simple fraction of the original, and all the - # sizes are different. - im = galsim.ImageD(17, 23, xmin=0, ymin=0) - b = galsim.BoundsI(7, 9, 11, 18) - im_test = galsim.ImageD(b, init_value=0) - for i in range(17): - for j in range(23): - val = np.exp(i / 7.3) + (j / 12.9) ** 3 # Something randomly complicated... - im[i, j] = val - # Find the location in the sub-image for this point. - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - im_wrap = im.wrap(b) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - -@timer -def test_wrap_jax_complex(): - # For complex images (in particular k-space images), we often want the image to be implicitly - # Hermitian, so we only need to keep around half of it. - M = 38 - N = 25 - K = 8 - L = 5 - im = galsim.ImageCD(2 * M + 1, 2 * N + 1, xmin=-M, ymin=-N) # Explicitly Hermitian - im2 = galsim.ImageCD( - 2 * M + 1, N + 1, xmin=-M, ymin=0 - ) # Implicitly Hermitian across y axis - im3 = galsim.ImageCD( - M + 1, 2 * N + 1, xmin=0, ymin=-N - ) # Implicitly Hermitian across x axis - # print('im = ',im) - # print('im2 = ',im2) - # print('im3 = ',im3) - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - im_test = galsim.ImageCD(b, init_value=0) - for i in range(-M, M + 1): - for j in range(-N, N + 1): - # An arbitrary, complicated Hermitian function. - val = ( - np.exp((i / (2.3 * M)) ** 2 + 1j * (2.8 * i - 1.3 * j)) - + ((2 + 3j * j) / (1.9 * N)) ** 3 - ) - # val = 2*(i-j)**2 + 3j*(i+j) - - im[i, j] = val - if j >= 0: - im2[i, j] = val - if i >= 0: - im3[i, j] = val - - ii = (i - b.xmin) % (b.xmax - b.xmin + 1) + b.xmin - jj = (j - b.ymin) % (b.ymax - b.ymin + 1) + b.ymin - im_test.addValue(ii, jj, val) - # print("im = ",im.array) - - # Confirm that the image is Hermitian. - for i in range(-M, M + 1): - for j in range(-N, N + 1): - assert im(i, j) == im(-i, -j).conjugate() - - im_wrap = im.wrap(b) - # print("im_wrap = ",im_wrap.array) - np.testing.assert_allclose( - im_wrap.array, - im_test.array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im_wrap.array, im[b].array, "image.wrap(%s) did not return the right subimage" - ) - np.testing.assert_equal( - im_wrap.bounds, b, "image.wrap(%s) does not have the correct bounds" - ) - - im2_wrap = im2.wrap(b2, hermitian="y") - # print('im_test = ',im_test[b2].array) - # print('im2_wrap = ',im2_wrap.array) - # print('diff = ',im2_wrap.array-im_test[b2].array) - np.testing.assert_allclose( - im2_wrap.array, - im_test[b2].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im2_wrap.array, - im2[b2].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im2_wrap.bounds, b2, "image.wrap(%s) does not have the correct bounds" - ) - - im3_wrap = im3.wrap(b3, hermitian="x") - # print('im_test = ',im_test[b3].array) - # print('im3_wrap = ',im3_wrap.array) - # print('diff = ',im3_wrap.array-im_test[b3].array) - np.testing.assert_allclose( - im3_wrap.array, - im_test[b3].array, - err_msg="image.wrap(%s) did not match expectation" % b, - ) - np.testing.assert_array_equal( - im3_wrap.array, - im3[b3].array, - "image.wrap(%s) did not return the right subimage", - ) - np.testing.assert_equal( - im3_wrap.bounds, b3, "image.wrap(%s) does not have the correct bounds" - ) - - b = galsim.BoundsI(-K + 1, K, -L + 1, L) - b2 = galsim.BoundsI(-K + 1, K, 0, L) - b3 = galsim.BoundsI(0, K, -L + 1, L) - assert_raises(TypeError, im.wrap, bounds=None) - assert_raises(ValueError, im.wrap, b2, hermitian="y") - assert_raises(ValueError, im.wrap, b, hermitian="invalid") - assert_raises(ValueError, im.wrap, b3, hermitian="x") - - assert_raises(ValueError, im3.wrap, b, hermitian="x") - assert_raises(ValueError, im3.wrap, b2, hermitian="x") - assert_raises(ValueError, im2.wrap, b, hermitian="y") - assert_raises(ValueError, im2.wrap, b3, hermitian="y") - assert_raises(ValueError, im2.wrap, b2, hermitian="invalid") - assert_raises(ValueError, im3.wrap, b3, hermitian="invalid") From c9f4c8c175e1df615d6c41391c29d5832a3b0cb0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Nov 2023 17:00:34 -0600 Subject: [PATCH 62/85] TST added tests of seeding --- tests/jax/test_jitting.py | 106 +++++++++++++++++++++----- tests/jax/test_photon_shooting_jax.py | 80 +++++++++++++++++++ 2 files changed, 165 insertions(+), 21 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 6a427f94..77d03f65 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -225,7 +225,7 @@ def _draw_it_jit(obj, n, nfft): def test_jitting_draw_phot(): - def _build_and_draw(hlr, fwhm, jit=True): + def _build_and_draw(hlr, fwhm, jit=True, maxn=False): gal = galsim.Exponential( half_light_radius=hlr, flux=1000.0 ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) @@ -240,10 +240,14 @@ def _build_and_draw(hlr, fwhm, jit=True): final._flux_per_photon, final.max_sb, poisson_flux=False, + rng=galsim.BaseDeviate(1234), )[0].item() gain = 1.0 if jit: - return _draw_it_jit(final, n, n_photons, gain) + if maxn: + return _draw_it_jit_maxn(final, n, n_photons, gain) + else: + return _draw_it_jit(final, n, n_photons, gain) else: return final.drawImage( nx=n, @@ -253,10 +257,11 @@ def _build_and_draw(hlr, fwhm, jit=True): n_photons=n_photons, poisson_flux=False, gain=gain, + rng=galsim.BaseDeviate(42), ) @partial(jax.jit, static_argnums=(1, 2)) - def _draw_it_jit(obj, n, nphotons, gain): + def _draw_it_jit_maxn(obj, n, nphotons, gain): return obj.drawImage( nx=n, ny=n, @@ -266,28 +271,55 @@ def _draw_it_jit(obj, n, nphotons, gain): poisson_flux=False, gain=gain, maxN=101, + rng=galsim.BaseDeviate(2), + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit(obj, n, nphotons, gain): + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + n_photons=nphotons, + method="phot", + poisson_flux=False, + gain=gain, + maxN=None, + rng=galsim.BaseDeviate(42), ) with time_code_block("warmup no-jit"): - img = _build_and_draw(0.5, 1.0, jit=False) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img1 = _build_and_draw(0.5, 1.0, jit=False) with time_code_block("no-jit"): - img = _build_and_draw(0.5, 1.0, jit=False) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img2 = _build_and_draw(0.5, 1.0, jit=False) with time_code_block("warmup jit"): - img = _build_and_draw(0.5, 1.0) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img3 = _build_and_draw(0.5, 1.0) with time_code_block("jit"): - img = _build_and_draw(0.5, 1.0) + img4 = _build_and_draw(0.5, 1.0) + + with time_code_block("warmup jit"): + img5 = _build_and_draw(0.5, 1.0, maxn=True) + + with time_code_block("jit"): + img6 = _build_and_draw(0.5, 1.0, maxn=True) + + np.testing.assert_allclose(img1.array, img2.array) + np.testing.assert_allclose(img1.array, img3.array) + np.testing.assert_allclose(img1.array, img4.array) + + assert not np.allclose(img1.array, img5.array) + + np.testing.assert_allclose(img5.array, img6.array) - np.testing.assert_allclose(img.array.sum(), 1100.0) + np.testing.assert_allclose(img1.array.sum(), 1100.0) + np.testing.assert_allclose(img5.array.sum(), 1100.0) def test_jitting_draw_phot_fixed(): - def _build_and_draw(hlr, fwhm, jit=True): + def _build_and_draw(hlr, fwhm, jit=True, maxn=False): gal = galsim.Exponential( half_light_radius=hlr, flux=1000.0 ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) @@ -302,10 +334,14 @@ def _build_and_draw(hlr, fwhm, jit=True): final._flux_per_photon, final.max_sb, poisson_flux=False, + rng=galsim.BaseDeviate(1234), )[0] gain = 1.0 if jit: - return _draw_it_jit(final, n, n_photons, gain) + if maxn: + return _draw_it_jit_maxn(final, n, n_photons, gain) + else: + return _draw_it_jit(final, n, n_photons, gain) else: with fixed_photon_array_size(2048): return final.drawImage( @@ -316,10 +352,25 @@ def _build_and_draw(hlr, fwhm, jit=True): n_photons=n_photons, poisson_flux=False, gain=gain, + rng=galsim.BaseDeviate(42), ) @partial(jax.jit, static_argnums=(1, 2)) def _draw_it_jit(obj, n, nphotons, gain): + with fixed_photon_array_size(2048): + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + n_photons=nphotons, + method="phot", + poisson_flux=False, + gain=gain, + rng=galsim.BaseDeviate(42), + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit_maxn(obj, n, nphotons, gain): with fixed_photon_array_size(2048): return obj.drawImage( nx=n, @@ -330,21 +381,34 @@ def _draw_it_jit(obj, n, nphotons, gain): poisson_flux=False, gain=gain, maxN=101, + rng=galsim.BaseDeviate(42), ) with time_code_block("warmup no-jit"): - img = _build_and_draw(0.5, 1.0, jit=False) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img1 = _build_and_draw(0.5, 1.0, jit=False) with time_code_block("no-jit"): - img = _build_and_draw(0.5, 1.0, jit=False) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img2 = _build_and_draw(0.5, 1.0, jit=False) with time_code_block("warmup jit"): - img = _build_and_draw(0.5, 1.0) - np.testing.assert_allclose(img.array.sum(), 1100.0) + img3 = _build_and_draw(0.5, 1.0) with time_code_block("jit"): - img = _build_and_draw(0.5, 1.0) + img4 = _build_and_draw(0.5, 1.0) + + with time_code_block("warmup jit"): + img5 = _build_and_draw(0.5, 1.0, maxn=True) + + with time_code_block("jit"): + img6 = _build_and_draw(0.5, 1.0, maxn=True) + + np.testing.assert_allclose(img1.array, img2.array) + np.testing.assert_allclose(img1.array, img3.array) + np.testing.assert_allclose(img1.array, img4.array) + + assert not np.allclose(img1.array, img5.array) + + np.testing.assert_allclose(img5.array, img6.array) - np.testing.assert_allclose(img.array.sum(), 1100.0) + np.testing.assert_allclose(img1.array.sum(), 1100.0) + np.testing.assert_allclose(img5.array.sum(), 1100.0) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 3c6192c5..b61f0f0b 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -1,3 +1,5 @@ +from functools import partial + import galsim as _galsim import jax import jax.numpy as jnp @@ -6,6 +8,7 @@ from jax.tree_util import register_pytree_node_class import jax_galsim +from jax_galsim.core.draw import calculate_n_photons from jax_galsim.core.testing import time_code_block from jax_galsim.photon_array import fixed_photon_array_size @@ -235,3 +238,80 @@ def _draw_galsim(hlr, fwhm, shift, flux, seed): with time_code_block("galsim"): for i in range(n_stamps): _draw_galsim(hlrs[i], fwhms[i], shifts[i], fluxes[i], i + 1) + + +def test_photon_shooting_jax_rng_seed(): + def _build_and_draw(hlr, fwhm, jit=True, new_seed=False): + gal = jax_galsim.Exponential( + half_light_radius=hlr, flux=1000.0 + ) + jax_galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) + psf = jax_galsim.Gaussian(fwhm=fwhm, flux=1.0) + final = jax_galsim.Convolve( + [gal, psf], + ) + n = final.getGoodImageSize(0.2).item() + n += 1 + n_photons = calculate_n_photons( + final.flux, + final._flux_per_photon, + final.max_sb, + poisson_flux=True, + rng=jax_galsim.BaseDeviate(1234), + )[0].item() + if jit: + if new_seed: + return _draw_it_jit_new_seed(final, n_photons) + else: + return _draw_it_jit(final, n_photons) + else: + return final.makePhot( + local_wcs=jax_galsim.PixelScale(0.2), + n_photons=n_photons, + poisson_flux=False, + rng=jax_galsim.BaseDeviate(42), + ) + + @partial(jax.jit, static_argnums=(1,)) + def _draw_it_jit(obj, n_photons): + return obj.makePhot( + local_wcs=jax_galsim.PixelScale(0.2), + n_photons=n_photons, + poisson_flux=False, + rng=jax_galsim.BaseDeviate(42), + ) + + @partial(jax.jit, static_argnums=(1,)) + def _draw_it_jit_new_seed(obj, n_photons): + return obj.makePhot( + local_wcs=jax_galsim.PixelScale(0.2), + n_photons=n_photons, + poisson_flux=False, + rng=jax_galsim.BaseDeviate(2), + ) + + with time_code_block("warmup no-jit"): + pa1 = _build_and_draw(0.5, 1.0, jit=False) + + with time_code_block("no-jit"): + pa2 = _build_and_draw(0.5, 1.0, jit=False) + + with time_code_block("warmup jit"): + pa3 = _build_and_draw(0.5, 1.0) + + with time_code_block("jit"): + pa4 = _build_and_draw(0.5, 1.0) + + with time_code_block("warmup jit + new seed"): + pa5 = _build_and_draw(0.5, 1.0, new_seed=True, jit=True) + + with time_code_block("jit + new seed"): + pa6 = _build_and_draw(0.5, 1.0, new_seed=True, jit=True) + + for attr in ["x", "y"]: + np.testing.assert_allclose(getattr(pa1, attr), getattr(pa2, attr)) + np.testing.assert_allclose(getattr(pa1, attr), getattr(pa3, attr)) + np.testing.assert_allclose(getattr(pa1, attr), getattr(pa4, attr)) + + assert not np.allclose(getattr(pa1, attr), getattr(pa5, attr)) + + np.testing.assert_allclose(getattr(pa5, attr), getattr(pa6, attr)) From 03084fb537f224b4e729be73b99023d1730bdabc Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 06:26:56 -0600 Subject: [PATCH 63/85] BUG make sure we can return photons and added flux --- jax_galsim/image.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index a63c6450..c378573e 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1154,6 +1154,13 @@ def tree_flatten(self): # Define the children nodes of the PyTree that need tracing children = (self.array, self.wcs) aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} + if hasattr(self, "added_flux"): + children += (self.added_flux,) + if hasattr(self, "header"): + aux_data["header"] = self.header + if hasattr(self, "photons"): + children += (self.photons,) + return (children, aux_data) @classmethod @@ -1165,6 +1172,12 @@ def tree_unflatten(cls, aux_data, children): obj._bounds = aux_data["bounds"] obj._dtype = aux_data["dtype"] obj._is_const = aux_data["isconst"] + if len(children) > 2: + obj.added_flux = children[2] + if "header" in aux_data: + obj.header = aux_data["header"] + if len(children) > 3: + obj.photons = children[3] return obj @classmethod From 4c24af7c30aaadd114bd346f8c8caa5ca40fe806 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 06:57:26 -0600 Subject: [PATCH 64/85] BUG remove workspace when pickling --- jax_galsim/interpolant.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index a5881cec..efaf1e73 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -30,6 +30,17 @@ def __init__(self): "Use one of the subclasses instead, or use the `from_name` factory function." ) + def __getstate__(self): + d = self.__dict__.copy() + d["had_workspace"] = "_workspace" in d + d.pop("_workspace", None) + return d + + def __setstate__(self, d): + if d.pop("had_workspace", False): + d["_workspace"] = {} + self.__dict__ = d + @staticmethod def from_name(name, tol=None, gsparams=None): """A factory function to create an `Interpolant` of the correct type according to From effaa1d7607114bc0c0b6f3ed634d29905491ea5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 07:10:42 -0600 Subject: [PATCH 65/85] BUG handle maxN with fixed array sizes correctly --- jax_galsim/gsobject.py | 2 ++ tests/jax/test_photon_shooting_jax.py | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 624abefb..081a93ee 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1288,6 +1288,8 @@ def drawPhot( if not add_to_image: image.setZero() + maxN = pa._JAX_GALSIM_PHOTON_ARRAY_SIZE or maxN + if maxN is None: ( added_flux, diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index b61f0f0b..94aa14ec 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -176,7 +176,8 @@ def test_photon_shooting_jax_offset(offset): np.testing.assert_allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol) -def test_photon_shooting_jax_vmapping(): +@pytest.mark.parametrize("max_n_phot", [10, 100, 1000]) +def test_photon_shooting_jax_vmapping(max_n_phot): n_stamps = 100 rng = np.random.RandomState(1234) shifts = jnp.array(rng.uniform(-1, 1, size=(n_stamps, 2))) @@ -187,7 +188,6 @@ def test_photon_shooting_jax_vmapping(): seeds = [] for i in range(n_stamps): seeds.append(jax.random.key(i + 1)) - max_n_phot = 2048 seeds = jnp.array(seeds) @jax.jit @@ -200,8 +200,8 @@ def _draw(hlr, fwhm, shift, flux, seed): ) with fixed_photon_array_size(max_n_phot): return obj.drawImage( - nx=33, - ny=33, + nx=53, + ny=53, scale=0.2, method="phot", rng=jax_galsim.BaseDeviate(seed), @@ -228,8 +228,8 @@ def _draw_galsim(hlr, fwhm, shift, flux, seed): ] ) return obj.drawImage( - nx=33, - ny=33, + nx=53, + ny=53, scale=0.2, method="phot", rng=_galsim.BaseDeviate(seed), From bcc0e2911b53f72725270b1d5b4960c5e66c6f9d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 11:45:48 -0600 Subject: [PATCH 66/85] BUG compute minN properly --- jax_galsim/gsobject.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 081a93ee..3ac1559c 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1288,7 +1288,10 @@ def drawPhot( if not add_to_image: image.setZero() - maxN = pa._JAX_GALSIM_PHOTON_ARRAY_SIZE or maxN + if maxN is not None and pa._JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: + maxN = min(maxN, pa._JAX_GALSIM_PHOTON_ARRAY_SIZE) + else: + maxN = pa._JAX_GALSIM_PHOTON_ARRAY_SIZE or maxN if maxN is None: ( From 0d16f9955f23489f4ef10b31210fb7760a7d8b32 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 14:15:03 -0600 Subject: [PATCH 67/85] TST added tests of APIs --- jax_galsim/photon_array.py | 6 +++--- tests/jax/test_api.py | 23 +++++++++++++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 1a2b7f5a..583f7753 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -106,11 +106,11 @@ def __init__( if time is not None: self.time = time + @classmethod @_wraps( _galsim.PhotonArray.fromArrays, lax_description="JAX-GalSim does not do input type/size checking.", ) - @classmethod def fromArrays( cls, x, @@ -128,8 +128,8 @@ def fromArrays( x, y, flux, dxdz, dydz, wavelength, pupil_u, pupil_v, time, is_corr ) - @_wraps(_galsim.PhotonArray._fromArrays) @classmethod + @_wraps(_galsim.PhotonArray._fromArrays) def _fromArrays( cls, x, @@ -473,8 +473,8 @@ def _sort_by_nokeep(self): return self + @_wraps(_galsim.PhotonArray.assignAt) def assignAt(self, istart, rhs): - """Assign the contents of another `PhotonArray` to this one starting at istart.""" from .deprecated import depr depr( diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 836ce720..8c958434 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -125,7 +125,7 @@ def _run_object_checks(obj, cls, kind): # check that we can hash the object hash(obj) - elif kind == "pickle-eval-repr-img": + elif kind == "pickle-eval-repr-img" or kind == "pickle-eval-repr-nohash": from numpy import array # noqa: F401 # eval repr is identity mapping @@ -147,6 +147,9 @@ def _run_object_checks(obj, cls, kind): # check that we cannot hash the object hash(obj) + elif kind == "jax-compatible": + # JAX tracing should be an identity + assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj elif kind == "vmap-jit-grad": # JAX tracing should be an identity assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj @@ -362,7 +365,15 @@ def _reg_fun(p): line.strip() and line not in _galsim.utilities.lazy_property.__doc__ ): - assert line.strip() in getattr(cls, method).__doc__ + assert line.strip() in getattr(cls, method).__doc__, ( + cls.__name__ + + "." + + method + + " doc string does not match galsim." + + gscls.__name__ + + "." + + method + ) else: assert method not in dir(gscls), cls.__name__ + "." + method else: @@ -1024,3 +1035,11 @@ def func(x): else: # Even if we're not actually doing the test, still make the repr to check for syntax errors. repr(obj1) + + +def test_api_photon_array(): + pa = jax_galsim.PhotonArray(101) + + _run_object_checks(pa, pa.__class__, "docs-methods") + _run_object_checks(pa, pa.__class__, "pickle-eval-repr-nohash") + _run_object_checks(pa, pa.__class__, "jax-compatible") From d9da5759a982bad83a192fcf9eb9b6b39466f315 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Nov 2023 14:18:32 -0600 Subject: [PATCH 68/85] TST add API tests for sensors --- tests/jax/test_api.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 8c958434..92846953 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -1043,3 +1043,24 @@ def test_api_photon_array(): _run_object_checks(pa, pa.__class__, "docs-methods") _run_object_checks(pa, pa.__class__, "pickle-eval-repr-nohash") _run_object_checks(pa, pa.__class__, "jax-compatible") + + +def test_api_sensor(): + classes = [] + for item in sorted(dir(jax_galsim)): + cls = getattr(jax_galsim, item) + if inspect.isclass(cls) and issubclass(cls, jax_galsim.sensor.Sensor): + classes.append(getattr(jax_galsim.sensor, item)) + + tested = set() + for cls in classes: + obj = cls() + print(obj) + tested.add(cls.__name__) + _run_object_checks(obj, obj.__class__, "docs-methods") + _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "jax-compatible") + + assert { + "Sensor", + } <= tested From c85fe946c51d9bbfea03a70ac093248c1b5f3772 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 1 Dec 2023 16:10:57 -0600 Subject: [PATCH 69/85] ENH comments and tests --- jax_galsim/core/draw.py | 3 + jax_galsim/gsobject.py | 87 +++++++++++--- jax_galsim/image.py | 26 +++-- jax_galsim/interpolatedimage.py | 17 +++ jax_galsim/noise.py | 1 + jax_galsim/photon_array.py | 153 +++++++++++++++++++------ jax_galsim/random.py | 2 +- tests/jax/test_photon_array_masking.py | 109 ++++++++++++++++++ 8 files changed, 340 insertions(+), 58 deletions(-) create mode 100644 tests/jax/test_photon_array_masking.py diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index f54fa08c..e665fa74 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -11,6 +11,7 @@ def draw_by_xValue( gsobject, image, jacobian=jnp.eye(2), offset=jnp.zeros(2), flux_scaling=1.0 ): """Utility function to draw a real-space GSObject into an Image.""" + # putting the import here to avoid circular imports from jax_galsim import Image, PositionD # Applies flux scaling to compensate for pixel scale @@ -41,6 +42,7 @@ def draw_by_xValue( def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): + # putting the import here to avoid circular imports from jax_galsim import Image, PositionD # Create an array of coordinates @@ -59,6 +61,7 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): + # putting the import here to avoid circular imports from jax_galsim import Image, PositionD # Create an array of coordinates diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 3ac1559c..81f61e3d 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -642,10 +642,13 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): @_wraps( _galsim.GSObject.drawImage, lax_description="""\ -The JAX-GalSim version of `drawImage` does not +The JAX-GalSim version of `drawImage` - - do extensive (any?) checking of the input settings. - - does not support the deprecated `surface_ops` argument + - does not do extensive (any?) checking of the input settings. + - uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain + - requires that the maxN option must be a constant """, ) def drawImage( @@ -676,12 +679,19 @@ def drawImage( save_photons=False, bandpass=None, setup_only=False, + surface_ops=None, ): from jax_galsim.box import Pixel from jax_galsim.convolve import Convolution, Convolve from jax_galsim.image import Image from jax_galsim.wcs import PixelScale + if surface_ops is not None: + from .deprecated import depr + + depr("surface_ops", 2.3, "photon_ops") + photon_ops = surface_ops + if image is not None and not isinstance(image, Image): raise TypeError("image is not an Image instance", image) @@ -1168,7 +1178,6 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): lax_description="""\ The JAX-GalSim version of `makePhot` - - does not support the deprecated surface_ops argument - does little to no error checking on the inputs - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined @@ -1196,11 +1205,16 @@ def makePhot( poisson_flux = n_photons is None if n_photons is not None: + # n_photons is the length of an array so it is a python int and + # and thus a constant wrt to JIT Ntot = int(n_photons + 0.5) _, g = self._calculate_nphotons( n_photons, poisson_flux, max_extra_noise, rng ) else: + # here Ntot can be a traced value + # one thus must use the fixed_photon_array_size context manager + # to ensure that the size of the photon array is fixed if using JIT Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) try: @@ -1230,13 +1244,11 @@ def makePhot( lax_description="""\ The JAX-GalSim version of `drawPhot` - - does not support the deprecated surface_ops argument - does little to no error checking on the inputs - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain - - the maxN option requires the use of fixed photon array sizes or a fixed - number of photons + - requires that the maxN option must be a constant """, ) def drawPhot( @@ -1253,7 +1265,14 @@ def drawPhot( maxN=None, orig_center=PositionI(0, 0), local_wcs=None, + surface_ops=None, ): + if surface_ops is not None: + from .deprecated import depr + + depr("surface_ops", 2.3, "photon_ops") + photon_ops = surface_ops + # If n_photons is given and poisson_flux is None, poisson_flux = False if poisson_flux is None: poisson_flux = n_photons is None @@ -1270,11 +1289,17 @@ def drawPhot( raise TypeError("The sensor provided is not a Sensor instance") if n_photons is not None: + # n_photons is the length of an array so it is a python int and + # and thus a constant wrt to JIT Ntot = int(n_photons + 0.5) _, g = self._calculate_nphotons( n_photons, poisson_flux, max_extra_noise, rng ) else: + # here Ntot can be a traced value + # one thus must use the fixed_photon_array_size context manager + # or the maxN option to ensure that the size of the photon array is fixed if using JIT + Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) g = jax.lax.cond( @@ -1288,12 +1313,19 @@ def drawPhot( if not add_to_image: image.setZero() + # both maxN and _JAX_GALSIM_PHOTON_ARRAY_SIZE can be used to fix the sizes + # of the photon arrays for use with JIT if maxN is not None and pa._JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: + # if both maxN and _JAX_GALSIM_PHOTON_ARRAY_SIZE are set, we use the smaller + # of the two maxN = min(maxN, pa._JAX_GALSIM_PHOTON_ARRAY_SIZE) else: + # otherwise we use the one that is set maxN = pa._JAX_GALSIM_PHOTON_ARRAY_SIZE or maxN if maxN is None: + # if neither maxN nor _JAX_GALSIM_PHOTON_ARRAY_SIZE are set + # we drae Ntot photons all at once ( added_flux, _image, @@ -1318,6 +1350,9 @@ def drawPhot( 0.0, ) else: + # if maxN or _JAX_GALSIM_PHOTON_ARRAY_SIZE is set + # we draw a fixed number of photons at a time in a while + # loop until we have drawn Ntot photons ( photons, _rng, @@ -1354,13 +1389,13 @@ def drawPhot( "Non-default sensors that carry state are not yet supported in jax-galsim." ) - return added_flux, photons or None # Just in case Nleft is already 0. + return added_flux, photons @_wraps(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): photons = pa.PhotonArray(n_photons) - if photons._x.shape[0] > 0: + if photons.x.shape[0] > 0: _rng = BaseDeviate(rng) self._shoot(photons, _rng) if rng is not None: @@ -1376,14 +1411,31 @@ def _shoot(self, photons, rng): @_wraps(_galsim.GSObject.applyTo) def applyTo(self, photon_array, local_wcs=None, rng=None): + # galsim does not deal with dxdz and dydz here - IDK why p1 = pa.PhotonArray(len(photon_array)) - if photon_array.hasAllocatedWavelengths(): - p1._wave = photon_array._wave - if photon_array.hasAllocatedPupil(): - p1._pupil_u = photon_array._pupil_u - p1._pupil_v = photon_array._pupil_v - if photon_array.hasAllocatedTimes(): - p1._time = photon_array._time + p1._wave = jax.lax.cond( + photon_array.hasAllocatedWavelengths(), + lambda pa_wave, p1_wave: pa_wave, + lambda pa_wave, p1_wave: p1_wave, + photon_array._wave, + p1._wave, + ) + p1._pupil_u, p1._pupil_v = jax.lax.cond( + photon_array.hasAllocatedPupil(), + lambda pa_u, pa_v, p1_u, p1_v: (pa_u, pa_v), + lambda pa_u, pa_v, p1_u, p1_v: (p1_u, p1_v), + photon_array._pupil_u, + photon_array._pupil_v, + p1._pupil_u, + p1._pupil_v, + ) + p1._time = jax.lax.cond( + photon_array.hasAllocatedTimes(), + lambda pa_time, p1_time: pa_time, + lambda pa_time, p1_time: p1_time, + photon_array._time, + p1._time, + ) obj = local_wcs.toImage(self) if local_wcs is not None else self obj._shoot(p1, rng) photon_array.convolve(p1, rng) @@ -1418,6 +1470,7 @@ def _draw_phot_while_loop_shoot( resume, added_flux, ): + """This helper function shoots thisN photons and accumulates them into the image.""" try: photons = obj.shoot(maxN, rng) except (GalSimError, NotImplementedError) as e: @@ -1482,6 +1535,8 @@ def _draw_phot_while_loop( sensor, orig_center, ): + """This helper function shoots photons until Ntot is reached.""" + def _cond_fun(args): ( photons, diff --git a/jax_galsim/image.py b/jax_galsim/image.py index c378573e..612cef10 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -11,20 +11,21 @@ from jax_galsim.utilities import parse_pos_args from jax_galsim.wcs import BaseWCS, PixelScale - -@_wraps( - _galsim.Image, - lax_description=""" +IMAGE_LAX_DOCS = """\ Contrary to GalSim native Image, this implementation does not support sharing of the underlying numpy array between different Images or Views. This is due to the fact that in JAX numpy arrays are immutable, so any operation applied to this Image will create a new jnp.ndarray. - In particular the followong methods will create a copy of the Image: - - Image.view() - - Image.subImage() +In particular the followong methods will create a copy of the Image: + - Image.view() + - Image.subImage() +""" + -""", +@_wraps( + _galsim.Image, + lax_description=IMAGE_LAX_DOCS, ) @register_pytree_node_class class Image(object): @@ -56,6 +57,7 @@ class Image(object): def __init__(self, *args, **kwargs): # this one is specific to jax-galsim and is used to disable bounds checking + # we use an underscore to denote that it is a private argument _check_bounds = kwargs.pop("_check_bounds", True) # Parse the args, kwargs @@ -458,6 +460,7 @@ def imag(self): self.array.imag, bounds=self.bounds, wcs=self.wcs, + # for real images, the imaginary part is always zero and immutable make_const=self._is_const or (not self.iscomplex), ) @@ -1154,6 +1157,8 @@ def tree_flatten(self): # Define the children nodes of the PyTree that need tracing children = (self.array, self.wcs) aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} + # other routines may add these attributes to images on the fly + # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): children += (self.added_flux,) if hasattr(self, "header"): @@ -1193,7 +1198,10 @@ def from_galsim(cls, galsim_image): return im -@_wraps(_galsim._Image) +@_wraps( + _galsim._Image, + lax_description=IMAGE_LAX_DOCS, +) def _Image(array, bounds, wcs): ret = Image.__new__(Image) ret.wcs = wcs diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 12521dab..95eb8d68 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -878,6 +878,23 @@ def _shoot(self, photons, rng): ud = UniformDeviate(rng) photons.x = ud.generate(photons.x) + xedges[xinds] photons.y = ud.generate(photons.y) + yedges[yinds] + # this magic set of factors comes from the galsim C++ code in + # a few spots it is + # + # - the sign of the photon flux + # - the flux per photon = 1 - 2 neg / (pos + neg) + # - the total absolute flux in the image = (pos + neg) + # - the number of photons to draw = photons.size() + # + # If you inpack it all, then you get + # + # sign * (1 - 2 neg / (pos + neg)) * (pos + neg) / photons.size() + # = sign * (pos + neg - 2 neg) / (pos + neg) * (pos + neg) / photons.size() + # = sign * (pos - neg) / photons.size() + # + # So what we have is a sign that oscillates between -1 and 1 with each photon getting + # the flux of the object divided by the number of photons (which is inflated to get the total flux + # correct by other bits of the code) photons.flux = ( jnp.sign(img.array.ravel())[inds] * self._flux_per_photon() diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 5d517d19..f04acd94 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -47,6 +47,7 @@ def __init__(self, rng=None): else: if not isinstance(rng, BaseDeviate): raise TypeError("rng must be a galsim.BaseDeviate instance.") + # we link the noise fields to the RNG state self._rng = rng @property diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 583f7753..2b398091 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -41,6 +41,33 @@ def fixed_photon_array_size(size): - They always copy input data and operations on them always copy. - They (usually) do not do any type/size checking on input data. - They do not support indexed assignement directly on the attributes. + - The additional properties `dxdz`, `dydz`, `wavelength`, `pupil_u`, `pupil_v`, + and `time` are set to arrays of NaNs by default. They are thus always allocated. + However, the methods like `hasAllocatedAngles` etc. return false if the arrays + are all NaNs. + +Further, a context manager `fixed_photon_array_size` is provided to temporarily +set a fixed size for photon arrays. + + - This functionality is useful when apply JIT to operations that vary the + number of photons drawn using Poisson statistics. + - When using this context manager, the attribute `_nokeep` stores a boolean mask + indicating which photons are to be kept. + - The attribute `_num_keep` stores the number of photons to be kept. If you set + this attribute, the `_nokeep` mask is updated by sorting _nokeep so that things + to be kept are at the start, the first `_num_keep` photons are marked to be kept, + and finally the array is sorted back to its original order. + - You may get an error if you ask for more photons than the fixed size, but not always, + especially in JITed code. + - Operations on photon arrays with fixed sizes but different `_num_keep` values are not + defined and will not raise an error. + - The `.flux` property scales `._flux` by the ratio of the fixed size to the number of kept photons + and sets non-kept photons to zero flux. Setting `.flux` to `._flux` will break things badly. + - Profiles should always draw the full number of photons given by `.size()` or `len()` + so that they use fixed array sizes and things are JIT compatible. + +**The `_nokeep`, `_num_keep`, and associated methods are private and should not be set by hand +unless you know what you are doing!** """, ) @register_pytree_node_class @@ -60,15 +87,19 @@ def __init__( _nokeep=None, ): self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N - # if ( - # _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None - # and isinstance(N, int) - # and N > _JAX_GALSIM_PHOTON_ARRAY_SIZE - # ): - # raise GalSimValueError( - # f"The given photon array size {N} is larger than " - # f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." - # ) + if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: + try: + # this will raise a boolean conversion error in JAX + # which we swallow + err_cond = (N > _JAX_GALSIM_PHOTON_ARRAY_SIZE) or False + except Exception: + err_cond = False + + if err_cond: + raise GalSimValueError( + f"The given photon array size {N} is larger than " + f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." + ) if _nokeep is not None: self._nokeep = _nokeep else: @@ -241,7 +272,10 @@ def _num_keep(self): @_num_keep.setter def _num_keep(self, num_keep): """Set the number of actual photons in the array.""" + sinds = jnp.argsort(self._nokeep) + self._sort_by_nokeep(sinds=sinds) self._nokeep = jnp.arange(self._Ntot) >= num_keep + self._set_self_at_inds(sinds) @property def x(self): @@ -278,7 +312,13 @@ def flux(self): @flux.setter def flux(self, value): - self._flux = self._flux.at[:].set(value) + self._flux = self._flux.at[:].set( + value + # scale it down to account for scaling in flux getter above + # this factor has to be computed after _nokeep is set above + # so that _num_keep is the right value + / (self._Ntot / self._num_keep) + ) @property def dxdz(self): @@ -457,9 +497,10 @@ def scaleXY(self, scale): return self - def _sort_by_nokeep(self): + def _sort_by_nokeep(self, sinds=None): # now sort things to keep to the left - sinds = jnp.argsort(self._nokeep) + if sinds is None: + sinds = jnp.argsort(self._nokeep) self._x = self._x.at[sinds].get() self._y = self._y.at[sinds].get() self._flux = self._flux.at[sinds].get() @@ -473,6 +514,20 @@ def _sort_by_nokeep(self): return self + def _set_self_at_inds(self, sinds): + self._x = self._x.at[sinds].set(self._x) + self._y = self._y.at[sinds].set(self._y) + self._flux = self._flux.at[sinds].set(self._flux) + self._nokeep = self._nokeep.at[sinds].set(self._nokeep) + self._dxdz = self._dxdz.at[sinds].set(self._dxdz) + self._dydz = self._dydz.at[sinds].set(self._dydz) + self._wave = self._wave.at[sinds].set(self._wave) + self._pupil_u = self._pupil_u.at[sinds].set(self._pupil_u) + self._pupil_v = self._pupil_v.at[sinds].set(self._pupil_v) + self._time = self._time.at[sinds].set(self._time) + + return self + @_wraps(_galsim.PhotonArray.assignAt) def assignAt(self, istart, rhs): from .deprecated import depr @@ -487,18 +542,8 @@ def assignAt(self, istart, rhs): "The given rhs does not fit into this array starting at %d" % istart, rhs, ) - self._x = self._x.at[istart : istart + rhs.size()].set(rhs.x) - self._y = self._y.at[istart : istart + rhs.size()].set(rhs.y) - self._flux = self._flux.at[istart : istart + rhs.size()].set(rhs.flux) - self._nokeep = self._nokeep.at[istart : istart + rhs.size()].set(rhs._nokeep) - self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) - self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) - self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) - self._pupil_u = self._pupil_u.at[istart : istart + rhs.size()].set(rhs.pupil_u) - self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set(rhs.pupil_v) - self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) - - return self._sort_by_nokeep() + s = slice(istart, istart + rhs.size()) + return self._copyFrom(rhs, s, slice(None)) @_wraps( _galsim.PhotonArray.copyFrom, @@ -543,12 +588,39 @@ def _cond_set_indices(arr1, arr2, cond_val): arr2, ) + old_flux_ratio = self._Ntot / self._num_keep + + if do_xy or do_flux or do_other: + self._nokeep = self._nokeep.at[s1].set(rhs._nokeep.at[s2].get()) + + new_flux_ratio = self._Ntot / self._num_keep + if do_xy: self._x = self._x.at[s1].set(rhs.x.at[s2].get()) self._y = self._y.at[s1].set(rhs.y.at[s2].get()) if do_flux: - self._flux = self._flux.at[s1].set(rhs.flux.at[s2].get()) + # we first scale the existing fluxes to account for the change in num_keep + self._flux = ( + self._flux + # this factor gets us back to true flux + * old_flux_ratio + # this factor gets us back to the internal units + / new_flux_ratio + ) + + # next we assign the RHS fluxes accounting for the change in num_keep from the + # RHS to the new flux_ratio + self._flux = self._flux.at[s1].set( + rhs._flux.at[s2].get() + # these factors conserve the flux of the assigned photons + # gets us to the true flux of the photon + * (rhs._Ntot / rhs._num_keep) + # scale it back down to account for scaling later + # this factor has to be computed after _nokeep is set above + # so that _num_keep is the right value + / (self._Ntot / self._num_keep) + ) if do_other: self._dxdz = _cond_set_indices( @@ -570,20 +642,26 @@ def _cond_set_indices(arr1, arr2, cond_val): self._time, rhs.time, rhs.hasAllocatedTimes() ) - if do_xy or do_flux or do_other: - self._nokeep = self._nokeep.at[s1].set(rhs._nokeep.at[s2].get()) - - return self._sort_by_nokeep() + return self def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): """Assign the contents of another `PhotonArray` to this one at locations where cat_ind == cat_ind_to_assign. """ msk = cat_ind_to_assign == cat_inds + old_flux_ratio = self._Ntot / self._num_keep + self._nokeep = jnp.where(msk, rhs._nokeep, self._nokeep) + new_flux_ratio = self._Ntot / self._num_keep + + rhs_flux_ratio = rhs._Ntot / rhs._num_keep + self._x = jnp.where(msk, rhs._x, self._x) self._y = jnp.where(msk, rhs._y, self._y) - self._flux = jnp.where(msk, rhs._flux, self._flux) - self._nokeep = jnp.where(msk, rhs._nokeep, self._nokeep) + self._flux = jnp.where( + msk, + rhs._flux * rhs_flux_ratio / new_flux_ratio, + self._flux * old_flux_ratio / new_flux_ratio, + ) self._dxdz = jnp.where(msk, rhs._dxdz, self._dxdz) self._dydz = jnp.where(msk, rhs._dydz, self._dydz) @@ -592,7 +670,7 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): self._pupil_v = jnp.where(msk, rhs._pupil_v, self._pupil_v) self._time = jnp.where(msk, rhs._time, self._time) - return self._sort_by_nokeep() + return self def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. @@ -607,6 +685,13 @@ def convolve(self, rhs, rng=None): "PhotonArray.convolve with unequal size arrays", self_pa=self, rhs=rhs ) + # We need to make sure that the arrays are sorted by _nokeep before convolving + # we sort them back to their original order after convolving + self_sinds = jnp.argsort(self._nokeep) + rhs_sinds = jnp.argsort(rhs._nokeep) + self._sort_by_nokeep(sinds=self_sinds) + rhs._sort_by_nokeep(sinds=rhs_sinds) + rng = BaseDeviate(rng) rsinds = jrng.choice( rng._state.split_one(), @@ -684,6 +769,10 @@ def convolve(self, rhs, rng=None): self._y = self._y + rhs._y.at[sinds].get() self._flux = self._flux * rhs._flux.at[sinds].get() * self.size() + # sort the arrays back to their original order + self._set_self_at_inds(self_sinds) + rhs._set_self_at_inds(rhs_sinds) + return self def __repr__(self): diff --git a/jax_galsim/random.py b/jax_galsim/random.py index d45a9169..03735294 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -25,7 +25,7 @@ Seeding the JAX-GalSim PRNG can be done in a few ways: - - pass seed=None (This is equivalent to passing seed=0) + - pass seed=None (This is equivalent to passing seed=0.) - pass an integer seed (This method will throw errors if the integer is traced by JAX.) - pass another JAX-GalSim PRNG - pass a JAX PRNG key made via `jax.random.key`. diff --git a/tests/jax/test_photon_array_masking.py b/tests/jax/test_photon_array_masking.py new file mode 100644 index 00000000..cd293965 --- /dev/null +++ b/tests/jax/test_photon_array_masking.py @@ -0,0 +1,109 @@ +import jax_galsim +import jax.numpy as jnp +import numpy as np + + +def _gen_photon_array(n_photons, rng): + pa = jax_galsim.PhotonArray(n_photons) + pa.x = rng.uniform(size=n_photons) + pa.y = rng.uniform(size=n_photons) + pa.flux = rng.uniform(size=n_photons) + pa.wavelength = rng.uniform(size=n_photons) + pa.dxdz = rng.uniform(size=n_photons) + pa.dydz = rng.uniform(size=n_photons) + pa.pupil_u = rng.uniform(size=n_photons) + pa.pupil_v = rng.uniform(size=n_photons) + pa.time = rng.uniform(size=n_photons) + pa._nokeep = jnp.array( + rng.uniform(size=n_photons) > 0.5, + dtype=bool, + ) + return pa + + +def test_photon_array_masking_sort(): + rng = np.random.RandomState(seed=42) + pa = _gen_photon_array(10, rng) + pas = jax_galsim.PhotonArray(10) + pas = pas.copyFrom(pa) + + sinds = jnp.argsort(pas._nokeep) + pas._sort_by_nokeep() + np.testing.assert_allclose(pas.x, pa.x[sinds]) + np.testing.assert_allclose(pas.y, pa.y[sinds]) + np.testing.assert_allclose(pas.flux, pa.flux[sinds]) + np.testing.assert_allclose(pas.wavelength, pa.wavelength[sinds]) + np.testing.assert_allclose(pas.dxdz, pa.dxdz[sinds]) + np.testing.assert_allclose(pas.dydz, pa.dydz[sinds]) + np.testing.assert_allclose(pas.pupil_u, pa.pupil_u[sinds]) + np.testing.assert_allclose(pas.pupil_v, pa.pupil_v[sinds]) + np.testing.assert_allclose(pas.time, pa.time[sinds]) + np.testing.assert_allclose(pas._nokeep, pa._nokeep[sinds]) + + pas._set_self_at_inds(sinds) + np.testing.assert_allclose(pas.x, pa.x) + np.testing.assert_allclose(pas.y, pa.y) + np.testing.assert_allclose(pas.flux, pa.flux) + np.testing.assert_allclose(pas.wavelength, pa.wavelength) + np.testing.assert_allclose(pas.dxdz, pa.dxdz) + np.testing.assert_allclose(pas.dydz, pa.dydz) + np.testing.assert_allclose(pas.pupil_u, pa.pupil_u) + np.testing.assert_allclose(pas.pupil_v, pa.pupil_v) + np.testing.assert_allclose(pas.time, pa.time) + np.testing.assert_allclose(pas._nokeep, pa._nokeep) + + +def test_photon_array_masking_set_num_keep(): + rng = np.random.RandomState(seed=42) + pa = _gen_photon_array(10, rng) + pas = jax_galsim.PhotonArray(10) + pas = pas.copyFrom(pa) + + pas._num_keep = 2 + np.testing.assert_allclose(pas.x, pa.x) + np.testing.assert_allclose(pas.y, pa.y) + assert not np.allclose(pas.flux, pa.flux) + np.testing.assert_allclose(pas.wavelength, pa.wavelength) + np.testing.assert_allclose(pas.dxdz, pa.dxdz) + np.testing.assert_allclose(pas.dydz, pa.dydz) + np.testing.assert_allclose(pas.pupil_u, pa.pupil_u) + np.testing.assert_allclose(pas.pupil_v, pa.pupil_v) + np.testing.assert_allclose(pas.time, pa.time) + assert not np.allclose(pas._nokeep, pa._nokeep) + + assert pas._num_keep == 2 + assert pas._Ntot == 10 + assert pa._num_keep == pa._Ntot - np.sum(pa._nokeep) + assert pa._num_keep != pas._num_keep + + +def test_photon_array_masking_copyFrom_flux_handling(): + rng = np.random.RandomState(seed=42) + pal = _gen_photon_array(6, rng) + pal._nokeep = jnp.array( + [True] * 2 + [False] * 4, + dtype=bool, + ) + + par = _gen_photon_array(4, rng) + par._nokeep = jnp.array( + [False] * 1 + [True] * 3, + dtype=bool, + ) + + pa = jax_galsim.PhotonArray(10) + pa.copyFrom(pal, slice(0, 6)) + pa.copyFrom(par, slice(6, 10)) + + np.testing.assert_allclose(pa.x, np.hstack([pal.x, par.x])) + np.testing.assert_allclose(pa.y, np.hstack([pal.y, par.y])) + np.testing.assert_allclose(pa.flux, np.hstack([pal.flux, par.flux])) + np.testing.assert_allclose(pa.wavelength, np.hstack([pal.wavelength, par.wavelength])) + np.testing.assert_allclose(pa.dxdz, np.hstack([pal.dxdz, par.dxdz])) + np.testing.assert_allclose(pa.dydz, np.hstack([pal.dydz, par.dydz])) + np.testing.assert_allclose(pa.pupil_u, np.hstack([pal.pupil_u, par.pupil_u])) + np.testing.assert_allclose(pa.pupil_v, np.hstack([pal.pupil_v, par.pupil_v])) + np.testing.assert_allclose(pa.time, np.hstack([pal.time, par.time])) + np.testing.assert_allclose(pa._nokeep, np.hstack([pal._nokeep, par._nokeep])) + + assert np.sum(pal.flux) + np.sum(par.flux) == np.sum(pa.flux) From 05b61da5202c7ac1c99a893e363b1c0795f27480 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 1 Dec 2023 16:18:00 -0600 Subject: [PATCH 70/85] STY blacken --- tests/jax/test_photon_array_masking.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_photon_array_masking.py b/tests/jax/test_photon_array_masking.py index cd293965..7ab3428e 100644 --- a/tests/jax/test_photon_array_masking.py +++ b/tests/jax/test_photon_array_masking.py @@ -1,7 +1,8 @@ -import jax_galsim import jax.numpy as jnp import numpy as np +import jax_galsim + def _gen_photon_array(n_photons, rng): pa = jax_galsim.PhotonArray(n_photons) @@ -98,7 +99,9 @@ def test_photon_array_masking_copyFrom_flux_handling(): np.testing.assert_allclose(pa.x, np.hstack([pal.x, par.x])) np.testing.assert_allclose(pa.y, np.hstack([pal.y, par.y])) np.testing.assert_allclose(pa.flux, np.hstack([pal.flux, par.flux])) - np.testing.assert_allclose(pa.wavelength, np.hstack([pal.wavelength, par.wavelength])) + np.testing.assert_allclose( + pa.wavelength, np.hstack([pal.wavelength, par.wavelength]) + ) np.testing.assert_allclose(pa.dxdz, np.hstack([pal.dxdz, par.dxdz])) np.testing.assert_allclose(pa.dydz, np.hstack([pal.dydz, par.dydz])) np.testing.assert_allclose(pa.pupil_u, np.hstack([pal.pupil_u, par.pupil_u])) From 888ba98981ca38addbbf5e30c7905bc0489ae0f1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 1 Dec 2023 18:24:32 -0600 Subject: [PATCH 71/85] DOC add docs for routines for n photons --- jax_galsim/core/draw.py | 311 ++++++++++++++++++++++++++++------------ jax_galsim/gsobject.py | 15 +- 2 files changed, 229 insertions(+), 97 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index e665fa74..8cddfc37 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -import numpy as np from jax_galsim.random import PoissonDeviate @@ -88,12 +87,12 @@ def phase(kpos): ) -NPhotonsData = namedtuple( +_NPhotonsData = namedtuple( "NPhotonsData", [ "n_photons", "flux", - "flux_per_photon", + "flux_per_photon", # also called eta_factor below "max_sb", "rng", "poisson_flux", @@ -102,72 +101,250 @@ def phase(kpos): ) -def calculate_n_photons( +def calculate_mean_n_photons( flux, - eta_factor, + flux_per_photon, max_sb, - rng=None, - max_extra_noise=0, - poisson_flux=True, ): - """ - Calculate the number of photons to shoot for photon shooting. + """Calculate the mean number of photons to shoot for photon shooting. - This routine is pure Python and is not JAX-compatible. + This routine can be used to group objects together by the typical number of photons + they will shoot when drawing objects in bulk. Parameters: - flux: The flux of the GSObject (e.g., ``obj.flux``). - eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). - max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). - rng: If provided, a random number generator to use for photon shooting, - which may be any kind of `BaseDeviate` object. If ``rng`` is None, one - will be automatically created, using the time as a seed. - [default: None] - max_extra_noise: If provided, the allowed extra noise in each pixel when photon - shooting. This is only relevant if ``n_photons=0``, so the number of - photons is being automatically calculated. In that case, if the image - noise is dominated by the sky background, then you can get away with - using fewer shot photons than the full ``n_photons = flux``. - Essentially each shot photon can have a ``flux > 1``, which increases - the noise in each pixel. The ``max_extra_noise`` parameter specifies - how much extra noise per pixel is allowed because of this approximation. - A typical value for this might be ``max_extra_noise = sky_level / 100`` - where ``sky_level`` is the flux per pixel due to the sky. Note that - this uses a "variance" definition of noise, not a "sigma" definition. - [default: 0.] - poisson_flux: Whether to allow total object flux scaling to vary according to - Poisson statistics for ``n_photons`` samples when photon shooting. - [default: True, unless ``n_photons`` is given, in which case the default - is False] + flux: The flux of the GSObject (e.g., ``obj.flux``). + flux_per_photon: The flux per photon (e.g., ``obj._flux_per_photon``). + max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). Returns: n_photons: The number of photons. - g: The gain to use when shooting the photons. """ - n_photons, g, _ = _calculate_n_photons( - flux, - eta_factor, - max_sb, - rng, - max_extra_noise, - poisson_flux, + npd = _NPhotonsData( + n_photons=0.0, + poisson_flux=False, + max_extra_noise=0.0, + rng=None, + flux=flux, + flux_per_photon=flux_per_photon, + max_sb=max_sb, ) - return np.atleast_1d(n_photons).ravel()[0], np.atleast_1d(g).ravel()[0] + return _sample_zero(npd)[0] @jax.jit -def get_n_photons(n_photons_data): +def calculate_n_photons( + flux, + flux_per_photon, + max_sb, + n_photons=0, + rng=None, + max_extra_noise=0.0, + poisson_flux=True, +): + """Calculate the number of photons to shoot for an object when photon shooting according to the + code in ``galsim.GSObject._calculate_n_photons``. See the notes section below for more details. + + Parameters: + flux: The flux of the GSObject (e.g., ``obj.flux``). + flux_per_photon: The flux per photon (e.g., ``obj._flux_per_photon``). + max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). + n_photons: If provided, the number of photons to use for photon shooting. + If not provided (i.e. ``n_photons = 0``), use as many photons as + necessary to result in an image with the correct Poisson shot + noise for the object's flux. For positive definite profiles, this + is equivalent to ``n_photons = flux``. However, some profiles need + more than this because some of the shot photons are negative + (usually due to interpolants). [default: 0] + rng: If provided, a random number generator to use for photon shooting, + which may be any kind of `BaseDeviate` object. If ``rng`` is None, one + will be automatically created, using the time as a seed. + [default: None] + max_extra_noise: If provided, the allowed extra noise in each pixel when photon + shooting. This is only relevant if ``n_photons=0``, so the number of + photons is being automatically calculated. In that case, if the image + noise is dominated by the sky background, then you can get away with + using fewer shot photons than the full ``n_photons = flux``. + Essentially each shot photon can have a ``flux > 1``, which increases + the noise in each pixel. The ``max_extra_noise`` parameter specifies + how much extra noise per pixel is allowed because of this approximation. + A typical value for this might be ``max_extra_noise = sky_level / 100`` + where ``sky_level`` is the flux per pixel due to the sky. Note that + this uses a "variance" definition of noise, not a "sigma" definition. + [default: 0.] + poisson_flux: Whether to allow total object flux scaling to vary according to + Poisson statistics for ``n_photons`` samples when photon shooting. + [default: True, unless ``n_photons`` is given, in which case the default + is False] + + Returns: + n_photons: The number of photons. + g: The flux ratio to use. Combine with a pre-existing gain via ``g /= gain`` and then multiply + the flux per photon by ``g``. + rng: The final random number generator used. + + + Notes: + + It is easiest to simply copy the original code from GSObject._calculate_nphotons + into the doc string here in order to document what this function does. + + # the old doc string: + Calculate how many photons to shoot and what flux_ratio (called g) each one should + have in order to produce an image with the right S/N and total flux. + + This routine is normally called by `drawPhot`. + + Returns: + n_photons, g + + # For profiles that are positive definite, then N = flux. Easy. + # + # However, some profiles shoot some of their photons with negative flux. This means that + # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the + # fraction of shot photons that have negative flux. + # + # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 + # N^2 = Var(S) = (N+ + N-) = Ntot + # + # So flux = (S/N)^2 = Ntot (1-2eta)^2 + # Ntot = flux / (1-2eta)^2 + # + # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). + # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right + # total flux. + # + # That's all the easy case. The trickier case is when we are sky-background dominated. + # Then we can usually get away with fewer shot photons than the above. In particular, + # if the noise from the photon shooting is much less than the sky noise, then we can + # use fewer shot photons and essentially have each photon have a flux > 1. This is ok + # as long as the additional noise due to this approximation is "much less than" the + # noise we'll be adding to the image for the sky noise. + # + # Let's still have Ntot photons, but now each with a flux of g. And let's look at the + # noise we get in the brightest pixel that has a nominal total flux of Imax. + # + # The number of photons hitting this pixel will be Imax/flux * Ntot. + # The variance of this number is the same thing (Poisson counting). + # So the noise in that pixel is: + # + # N^2 = Imax/flux * Ntot * g^2 + # + # And the signal in that pixel will be: + # + # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so + # g = flux / Ntot(1-2eta) + # N^2 = Imax/Ntot * flux / (1-2eta)^2 + # + # As expected, we see that lowering Ntot will increase the noise in that (and every + # other) pixel. + # The input max_extra_noise parameter is the maximum value of spurious noise we want + # to allow. + # + # So setting N^2 = Imax + nu, we get + # + # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) + # g = (1 - 2eta) * (1 + nu/Imax) + # + # Returns the total flux placed inside the image bounds by photon shooting. + # + + flux = self.flux + if flux == 0.0: + return 0, 1.0 + + # The _flux_per_photon property is (1-2eta) + # This factor will already be accounted for by the shoot function, so don't include + # that as part of our scaling here. There may be other adjustments though, so g=1 here. + eta_factor = self._flux_per_photon + mod_flux = flux / (eta_factor * eta_factor) + g = 1. + + # If requested, let the target flux value vary as a Poisson deviate + if poisson_flux: + # If we have both positive and negative photons, then the mix of these + # already gives us some variation in the flux value from the variance + # of how many are positive and how many are negative. + # The number of negative photons varies as a binomial distribution. + # = eta * Ntot * g + # = (1-eta) * Ntot * g + # = (1-2eta) * Ntot * g = flux + # Var(F-) = eta * (1-eta) * Ntot * g^2 + # F+ = Ntot * g - F- is not an independent variable, so + # Var(F+ - F-) = Var(Ntot*g - 2*F-) + # = 4 * Var(F-) + # = 4 * eta * (1-eta) * Ntot * g^2 + # = 4 * eta * (1-eta) * flux + # We want the variance to be equal to flux, so we need an extra: + # delta Var = (1 - 4*eta + 4*eta^2) * flux + # = (1-2eta)^2 * flux + absflux = abs(flux) + mean = eta_factor*eta_factor * absflux + pd = PoissonDeviate(rng, mean) + pd_val = pd() - mean + absflux + ratio = pd_val / absflux + g *= ratio + mod_flux *= ratio + + if n_photons == 0.: + n_photons = abs(mod_flux) + if max_extra_noise > 0.: + gfactor = 1. + max_extra_noise / abs(self.max_sb) + n_photons /= gfactor + g *= gfactor + + # Make n_photons an integer. + iN = int(n_photons + 0.5) + + return iN, g + """ + + n_photons_data = _NPhotonsData( + n_photons=n_photons, + poisson_flux=poisson_flux, + max_extra_noise=max_extra_noise, + rng=rng, + flux=flux, + flux_per_photon=flux_per_photon, + max_sb=max_sb, + ) + _n_photons, g, _rng = jax.lax.cond( n_photons_data.n_photons == 0.0, _sample_zero, _sample_nonzero, n_photons_data, ) + if rng is not None: + rng._state = _rng._state + return _n_photons, g, rng + + +@jax.jit +def _sample_zero(n_photons_data): + _n_photons, _g, _rng = jax.lax.cond( + n_photons_data.flux == 0.0, + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: ( + 0, + 1.0, + rng, + ), + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: _calculate_n_photons_flux_nonzero( + flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng + ), + n_photons_data.flux, + n_photons_data.flux_per_photon, + n_photons_data.max_sb, + n_photons_data.poisson_flux, + n_photons_data.max_extra_noise, + n_photons_data.rng, + ) if n_photons_data.rng is not None: n_photons_data.rng._state = _rng._state - return _n_photons, g, n_photons_data.rng + return _n_photons, _g, n_photons_data.rng + +@jax.jit def _sample_nonzero(n_photons_data): g, _rng = jax.lax.cond( n_photons_data.poisson_flux, @@ -183,50 +360,6 @@ def _sample_nonzero(n_photons_data): return vals -@jax.jit -def _sample_zero(n_photons_data): - Ntot, g, _rng = _calculate_n_photons( - n_photons_data.flux, - n_photons_data.flux_per_photon, - n_photons_data.max_sb, - rng=n_photons_data.rng, - max_extra_noise=n_photons_data.max_extra_noise, - poisson_flux=n_photons_data.poisson_flux, - ) - return Ntot, g, _rng - - -@jax.jit -def _calculate_n_photons( - flux, - eta_factor, - max_sb, - rng, - max_extra_noise, - poisson_flux, -): - _n_photons, _g, _rng = jax.lax.cond( - flux == 0.0, - lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: ( - 0, - 1.0, - rng, - ), - lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: _calculate_n_photons_flux_nonzero( - flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng - ), - flux, - eta_factor, - max_sb, - poisson_flux, - max_extra_noise, - rng, - ) - if rng is not None: - rng._state = _rng._state - return _n_photons, _g, rng - - @jax.jit def _sample_poisson_flux(flux, eta_factor, rng): # If we have both positive and negative photons, then the mix of these diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 81f61e3d..b007f0fd 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -7,7 +7,7 @@ from jax._src.numpy.util import _wraps import jax_galsim.photon_array as pa -from jax_galsim.core.draw import NPhotonsData, get_n_photons +from jax_galsim.core.draw import calculate_n_photons from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.errors import ( GalSimError, @@ -1159,16 +1159,15 @@ def _drawKImage( @_wraps(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - npd = NPhotonsData( + n_photons, g, _rng = calculate_n_photons( + self.flux, + self._flux_per_photon, + self.max_sb, n_photons=n_photons, - poisson_flux=poisson_flux, - max_extra_noise=max_extra_noise, rng=rng, - flux=self.flux, - flux_per_photon=self._flux_per_photon, - max_sb=self.max_sb, + max_extra_noise=max_extra_noise, + poisson_flux=poisson_flux, ) - n_photons, g, _rng = get_n_photons(npd) if rng is not None: rng._state = _rng._state return n_photons, g From e3f37b9a4e4eb502a3aedc522dcd4bd894b25de8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 2 Dec 2023 05:36:37 -0600 Subject: [PATCH 72/85] BUG hashable type for jit --- tests/jax/test_jitting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 77d03f65..e4955e30 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -335,7 +335,7 @@ def _build_and_draw(hlr, fwhm, jit=True, maxn=False): final.max_sb, poisson_flux=False, rng=galsim.BaseDeviate(1234), - )[0] + )[0].item() gain = 1.0 if jit: if maxn: @@ -396,10 +396,10 @@ def _draw_it_jit_maxn(obj, n, nphotons, gain): with time_code_block("jit"): img4 = _build_and_draw(0.5, 1.0) - with time_code_block("warmup jit"): + with time_code_block("warmup jit+maxn"): img5 = _build_and_draw(0.5, 1.0, maxn=True) - with time_code_block("jit"): + with time_code_block("jit+maxn"): img6 = _build_and_draw(0.5, 1.0, maxn=True) np.testing.assert_allclose(img1.array, img2.array) From 5aaf4351297fbe16291437673566af0956d4178b Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 2 Dec 2023 05:50:04 -0600 Subject: [PATCH 73/85] DOC update change log --- CHANGELOG.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b3a6818..844574b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,19 +11,22 @@ * `Shear` * `Convolve` * `InterpolatedImage` and `Interpolant` + * `PhotonArray` + * `Sensor` * Added implementation of fundamental operations: * `drawImage` * `drawReal` * `drawFFT` * `drawKImage` + * `makePhot` + * `drawPhot` * Added implementation of simple light profiles: - * `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat` + * `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`, `DeltaFunction` * Added implementation of simple WCS: - * `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS` + * `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS`, `GSFitsWCS`, `FitsWCS`, `TanWCS` * Added automated suite of tests against reference GalSim * Added support for the `galsim.fits` module * Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects * Caveats - * Real space convolution and photon shooting methods are not - yet implemented in drawImage. + * Real space convolution are not yet implemented in `drawImage``. From c8dfa26d9ad33409efc24f3cdfe3dd2f32da72cc Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 2 Dec 2023 05:52:14 -0600 Subject: [PATCH 74/85] DOC update change log --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 844574b1..7f447e6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ * `InterpolatedImage` and `Interpolant` * `PhotonArray` * `Sensor` + * `AngleUnit`, `Angle`, and `CelestialCoord` + * `BaseDeviate` and child classes + * `BaseNoise` and child classes * Added implementation of fundamental operations: * `drawImage` * `drawReal` @@ -24,7 +27,7 @@ * `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`, `DeltaFunction` * Added implementation of simple WCS: * `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS`, `GSFitsWCS`, `FitsWCS`, `TanWCS` - * Added automated suite of tests against reference GalSim + * Added automated suite of tests using the reference GalSim and LSSTDESC-Coord test suites * Added support for the `galsim.fits` module * Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects From 595482e879b88f1855bcf3898a6903fd4ad16d29 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 13 Dec 2023 12:34:23 -0600 Subject: [PATCH 75/85] ENH first pass at code review response Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- jax_galsim/convolve.py | 2 +- jax_galsim/core/draw.py | 16 ---------------- jax_galsim/gsobject.py | 4 ++++ jax_galsim/image.py | 4 ++-- jax_galsim/interpolant.py | 6 +++++- jax_galsim/interpolatedimage.py | 4 ++-- jax_galsim/photon_array.py | 4 ++-- tests/galsim_tests_config.yaml | 4 ++-- 8 files changed, 18 insertions(+), 26 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6902168f..855807a3 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -311,7 +311,7 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): def _shoot(self, photons, rng): self.obj_list[0]._shoot(photons, rng) # It may be necessary to shuffle when convolving because we do not have a - # gaurantee that the convolvee's photons are uncorrelated, e.g., they might + # guarantee that the convolvee's photons are uncorrelated, e.g., they might # both have their negative ones at the end. # However, this decision is now made by the convolve method. for obj in self.obj_list[1:]: diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 8cddfc37..86eb2cf4 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -362,22 +362,6 @@ def _sample_nonzero(n_photons_data): @jax.jit def _sample_poisson_flux(flux, eta_factor, rng): - # If we have both positive and negative photons, then the mix of these - # already gives us some variation in the flux value from the variance - # of how many are positive and how many are negative. - # The number of negative photons varies as a binomial distribution. - # = eta * Ntot * g - # = (1-eta) * Ntot * g - # = (1-2eta) * Ntot * g = flux - # Var(F-) = eta * (1-eta) * Ntot * g^2 - # F+ = Ntot * g - F- is not an independent variable, so - # Var(F+ - F-) = Var(Ntot*g - 2*F-) - # = 4 * Var(F-) - # = 4 * eta * (1-eta) * Ntot * g^2 - # = 4 * eta * (1-eta) * flux - # We want the variance to be equal to flux, so we need an extra: - # delta Var = (1 - 4*eta + 4*eta^2) * flux - # = (1-2eta)^2 * flux absflux = jnp.abs(flux) mean = eta_factor * eta_factor * absflux pd = PoissonDeviate(rng, mean) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index b007f0fd..2bdce318 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1225,6 +1225,8 @@ def makePhot( "Deconvolve objects.\nOriginal error: %r" % (e) ) + # jax.lax.cond doesn't evaluate both of the branches + # and this call can save computations for common cases. photons = jax.lax.cond( g == 1.0, lambda photons, g: photons, @@ -1301,6 +1303,8 @@ def drawPhot( Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) + # this call can save computations for the + # common case of gain == 1.0 g = jax.lax.cond( gain != 1.0, lambda g, gain: g / gain, diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 612cef10..541cddc9 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -17,7 +17,7 @@ This is due to the fact that in JAX numpy arrays are immutable, so any operation applied to this Image will create a new jnp.ndarray. -In particular the followong methods will create a copy of the Image: +In particular the following methods will create a copy of the Image: - Image.view() - Image.subImage() """ @@ -496,7 +496,7 @@ def get_pixel_centers(self): def _make_empty(self, shape, dtype): """Helper function to make an empty numpy array of the given shape.""" if np.prod(shape) == 0: - # galsim forces degenrate images to have at least 1 pixel + # galsim forces degenerate images to have at least 1 pixel return jnp.zeros(shape=(1, 1), dtype=dtype) else: return jnp.zeros(shape=shape, dtype=dtype) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index efaf1e73..a9c063bd 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -277,7 +277,7 @@ def _shoot_cdf(self): px = jnp.abs(self._xval_noraise(jnp.abs(x))) dx = x[1] - x[0] # cumulative trapezoidal rule - # see scipy.integrate.cumulative_trapezoidal + # see scipy.integrate.cumulative_trapezoid cdfx = jnp.concatenate( [jnp.array([0]), jnp.cumsum((px[1:] + px[:-1]) * 0.5 * dx)] ) @@ -285,6 +285,8 @@ def _shoot_cdf(self): return x, cdfx def _shoot(self, photons, rng): + # this is a generic method used for kernels without easy + # analytic ways to draw from them (currently Cubic, Quintic, and Lanczos) x, cdfx = self._shoot_cdf ud = UniformDeviate(rng) ux = ud.generate(photons.x) @@ -352,6 +354,8 @@ def ixrange(self): return 0 def _shoot(self, photons, rng): + # PhotonArray class does the correct thing here, setting + # the whole array to zero photons.x = 0.0 photons.y = 0.0 photons.flux = 1.0 / photons.size() diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 95eb8d68..3ca1cdde 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -886,7 +886,7 @@ def _shoot(self, photons, rng): # - the total absolute flux in the image = (pos + neg) # - the number of photons to draw = photons.size() # - # If you inpack it all, then you get + # If you unpack it all, then you get # # sign * (1 - 2 neg / (pos + neg)) * (pos + neg) / photons.size() # = sign * (pos + neg - 2 neg) / (pos + neg) * (pos + neg) / photons.size() @@ -902,7 +902,7 @@ def _shoot(self, photons, rng): / photons.size() ) - # accounnt for offset - we add the offset to get to + # account for offset - we add the offset to get to # image pixels in the xValue method # here we generate photons from the image and # so we need to subtract it to get back to get to x as diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 2b398091..59493c06 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -619,7 +619,7 @@ def _cond_set_indices(arr1, arr2, cond_val): # scale it back down to account for scaling later # this factor has to be computed after _nokeep is set above # so that _num_keep is the right value - / (self._Ntot / self._num_keep) + / new_flux_ratio ) if do_other: @@ -994,7 +994,7 @@ def _add_photons_to_image(x, y, flux, xmin, ymin, arr): xinds = jnp.floor(x - xmin + 0.5).astype(int) yinds = jnp.floor(y - ymin + 0.5).astype(int) # the jax documentation says that they drop out of bounds indices, - # but the galsim unit tests reveal that withoout the check below, + # but the galsim unit tests reveal that without the check below, # the indices are not dropped. # I think maybe it is only indices beyond the end of the array that are # dropped and negative indices wrap around diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 9860c962..2deca931 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -19,8 +19,8 @@ enabled_tests: - test_random.py - test_noise.py - test_image.py - - test_phpton_array.py - - "*" + - test_photon_array.py + - "*" # means all tests from galsim coord: - test_angle.py - test_angleunit.py From 2dd3ee3e405f19a64c5b86d04066e9b694e9270d Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 Dec 2023 15:11:35 -0600 Subject: [PATCH 76/85] Update tests/jax/test_ref_impl.py Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- tests/jax/test_ref_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_ref_impl.py b/tests/jax/test_ref_impl.py index 7af83b8a..e6cea66f 100644 --- a/tests/jax/test_ref_impl.py +++ b/tests/jax/test_ref_impl.py @@ -38,7 +38,7 @@ def conv2(galsim): conv2(ref_galsim), conv2(jax_galsim), 5, - err_msg=" GSObject Convolve(psf,pixel) disagrees with expected result", + err_msg="GSObject Convolve(psf,pixel) disagrees with expected result", ) From abb9e381fb4afd105e64cc5368abb84e8d157348 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 Dec 2023 15:12:20 -0600 Subject: [PATCH 77/85] Apply suggestions from code review Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- tests/jax/test_photon_shooting_jax.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 94aa14ec..676e70f4 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -154,19 +154,6 @@ def test_photon_shooting_jax_offset(offset): ), ) - # code for testing - # if not np.allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol): - # import proplot as pplt - - # fig, axs = pplt.subplots(nrows=1, ncols=3) - # axs[0].imshow(img_fft.array, origin="lower") - # axs[1].imshow(img_phot.array, origin="lower") - # axs[2].imshow(img_fft.array - img_phot.array, origin="lower") - # fig.show() - - # import pdb - - # pdb.set_trace() np.testing.assert_almost_equal( jnp.argmax(img_fft.array), From abda85ca504df8c86a86d5105e8ba9caf8b374d9 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 Dec 2023 15:12:32 -0600 Subject: [PATCH 78/85] Update tests/jax/test_photon_shooting_jax.py Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- tests/jax/test_photon_shooting_jax.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 676e70f4..0c2f89ac 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -32,18 +32,6 @@ def test_photon_shooting_jax_make_from_image_notranspose(): image2 = jax_galsim.Image(np.zeros_like(ref_array)) photons.addTo(image2) - # code for testing - # if not np.allclose(image2.array, ref_array) and False: - # import proplot as pplt - - # fig, axs = pplt.subplots(nrows=1, ncols=3) - # axs[0].imshow(ref_array) - # axs[1].imshow(image2.array) - # axs[2].imshow(image2.array - ref_array) - - # import pdb - - # pdb.set_trace() np.testing.assert_allclose(image2.array, ref_array) From 89ac8fb69e22ea7d28b2b4471a69826b3df7ac0f Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 06:20:07 -0600 Subject: [PATCH 79/85] STY blacken --- jax_galsim/gsobject.py | 2 +- tests/jax/test_photon_shooting_jax.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 2bdce318..9a9adad8 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1303,7 +1303,7 @@ def drawPhot( Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) - # this call can save computations for the + # this call can save computations for the # common case of gain == 1.0 g = jax.lax.cond( gain != 1.0, diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 0c2f89ac..a406e8c3 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -32,7 +32,6 @@ def test_photon_shooting_jax_make_from_image_notranspose(): image2 = jax_galsim.Image(np.zeros_like(ref_array)) photons.addTo(image2) - np.testing.assert_allclose(image2.array, ref_array) @@ -142,7 +141,6 @@ def test_photon_shooting_jax_offset(offset): ), ) - np.testing.assert_almost_equal( jnp.argmax(img_fft.array), jnp.argmax(img_phot.array), From 4d366a652293075909c27e5235e8ed4165907d76 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 06:41:43 -0600 Subject: [PATCH 80/85] ENH respond to more CR --- jax_galsim/gsobject.py | 4 +++- jax_galsim/interpolant.py | 1 + jax_galsim/photon_array.py | 4 ++++ tests/jax/test_interpolatedimage_utils.py | 11 +---------- tests/jax/test_jitting.py | 6 +++++- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 9a9adad8..11c62ea2 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -648,7 +648,8 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain - - requires that the maxN option must be a constant + - requires that the maxN option be a constant since PhotonArrays are allocated + with `maxN` photons when this option is used and arrays in JAX must have static sizes. """, ) def drawImage( @@ -1387,6 +1388,7 @@ def drawPhot( image._array = _image._array # TODO: how to update the sensor? + # https://github.com/GalSim-developers/JAX-GalSim/issues/85 if sensor.__class__ is not Sensor: raise GalSimNotImplementedError( "Non-default sensors that carry state are not yet supported in jax-galsim." diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index a9c063bd..fd8b560c 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -468,6 +468,7 @@ def _comp_fluxes(self): narr = jnp.arange(n) val = (si(jnp.pi * (narr + 1)) - si(jnp.pi * (narr))) / jnp.pi + # this computed flux is a constant and so we do not propagate gradients self._positive_flux = ( jax.lax.stop_gradient(jnp.sum(jnp.where(val > 0, val, 0.0))) * 2.0 ) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 59493c06..66874fd0 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -302,6 +302,8 @@ def y(self, value): @property def flux(self): """The flux of the photons.""" + # we use jax.lax.cond to save some multiplications when + # there are no masked photos. return jax.lax.cond( self._Ntot == self._num_keep, lambda flux, ratio: flux, @@ -838,6 +840,8 @@ def addTo(self, image): _arr, _flux_sum = _add_photons_to_image( self._x, self._y, + # this computation is the same as self.flux, but we've left it duplicated here + # so that we don't change this line to self._flux only by accident in the future jnp.where(self._nokeep, 0.0, self._flux) * self._Ntot / self._num_keep, image.bounds.xmin, image.bounds.ymin, diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index c76278ef..7043615f 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -330,6 +330,7 @@ def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): ], ) def test_interpolatedimage_interpolant_sample(interp): + """Sample from the interpolation kernel and compare a histogram of the samples to the expected fistribution.""" from jax_galsim.photon_array import PhotonArray from jax_galsim.random import BaseDeviate @@ -353,13 +354,3 @@ def test_interpolatedimage_interpolant_sample(interp): fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") - - if interp.__class__.__name__ in ["Quintic", "Lanczos"] and False: - import proplot as pplt - - fig, axs = pplt.subplots(figsize=(6, 6)) - axs.hist(photons.x, bins=500, log=True) - axs.plot(mid, yv, color="k") - axs.format(title=interp.__class__.__name__) - fig.show() - breakpoint() diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index e4955e30..1b2cc337 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -213,14 +213,18 @@ def _draw_it_jit(obj, n, nfft): with time_code_block("warmup no-jit"): img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 0) + with time_code_block("no-jit"): img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 0) with time_code_block("warmup jit"): img = _build_and_draw(0.5, 1.0) + np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 0) + with time_code_block("jit"): img = _build_and_draw(0.5, 1.0) - np.testing.assert_array_almost_equal(img.array.sum(), 1000.0, 0) From 7f3432432a4a8c08fe1cfa85f85f9bf779981e11 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 07:03:07 -0600 Subject: [PATCH 81/85] ENH respond to more CR --- jax_galsim/photon_array.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 66874fd0..fda8bda8 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -694,6 +694,28 @@ def convolve(self, rhs, rng=None): self._sort_by_nokeep(sinds=self_sinds) rhs._sort_by_nokeep(sinds=rhs_sinds) + # When two photon arrays are convolved, you basically perturb the positions of one + # by adding the positions of the other. For example, if you have a delta function + # and want to convolve with a Gaussian, then the photon arrays are an array of zeros + # for the delta function and an array of Gaussian draws for the Gaussian. The convolution + # is then implemented by adding the positions of the two arrays. + + # The edge case here is if the photons in anb array are correlated. for example, if + # you draw photons from a sum of two profiles, you could have the photons from one + # of the components only at the start of the array and the photons from the other + # component only at the end of the array like this + # + # [A, A, A, ..., A, B, B, B. ..., B] + # + # where A and B represent which component the photon came from. If you convolve two + # photon arrays where both arrays have intenral correlations in the ordering of the + # photons, then you need to randomly sort one of the arrays before the convolution. + # Otherwise you won't properly be adding a random draew from one profile to the other. + + # the indexing and PRNG code snippets below handle this case of convolving two internally + # correlated photon arrays. + + # these are indicies that randomly sort the RHS's photons. rng = BaseDeviate(rng) rsinds = jrng.choice( rng._state.split_one(), @@ -701,8 +723,14 @@ def convolve(self, rhs, rng=None): shape=(self.size(),), replace=False, ) + # these indices do not randomly sort the RHS's photons nrsinds = jnp.arange(self.size()) + # now we randomly sort if both arrays are internally correlated + # however there is a catch. The RHS may not be keeping all of its photons + # (i.e., rhs._nokeep is True for some photons). In this case, we additionally + # sort the random indices by the value of rhs._nokeep so that the photons to be + # kept are still at the front of the array but are in a new random order. sinds = jax.lax.cond( self._is_corr & rhs._is_corr, lambda nrsinds, rsinds: rsinds.at[ From 7b60ab9bf36a6b75216333d2b4eefc991ee7a2e2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 08:11:42 -0600 Subject: [PATCH 82/85] refactor repeated logicx --- jax_galsim/photon_array.py | 39 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index fda8bda8..c0eb1136 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -330,12 +330,7 @@ def dxdz(self): @dxdz.setter def dxdz(self, value): self._dxdz = self._dxdz.at[:].set(value) - self._dydz = jax.lax.cond( - jnp.any(jnp.isfinite(self._dxdz)) & jnp.all(~jnp.isfinite(self._dydz)), - lambda dydz: jnp.zeros_like(dydz), - lambda dydz: dydz, - self._dydz, - ) + self._dydz = _zero_if_needed_on_set(self._dxdz, self._dydz) @property def dydz(self): @@ -345,12 +340,7 @@ def dydz(self): @dydz.setter def dydz(self, value): self._dydz = self._dydz.at[:].set(value) - self._dxdz = jax.lax.cond( - jnp.any(jnp.isfinite(self._dydz)) & jnp.all(~jnp.isfinite(self._dxdz)), - lambda dxdz: jnp.zeros_like(dxdz), - lambda dxdz: dxdz, - self._dxdz, - ) + self._dxdz = _zero_if_needed_on_set(self._dydz, self._dxdz) @property def wavelength(self): @@ -369,13 +359,7 @@ def pupil_u(self): @pupil_u.setter def pupil_u(self, value): self._pupil_u = self._pupil_u.at[:].set(value) - self._pupil_v = jax.lax.cond( - jnp.any(jnp.isfinite(self._pupil_u)) - & jnp.all(~jnp.isfinite(self._pupil_v)), - lambda pupil_v: jnp.zeros_like(pupil_v), - lambda pupil_v: pupil_v, - self._pupil_v, - ) + self._pupil_v = _zero_if_needed_on_set(self._pupil_u, self._pupil_v) @property def pupil_v(self): @@ -385,13 +369,7 @@ def pupil_v(self): @pupil_v.setter def pupil_v(self, value): self._pupil_v = self._pupil_v.at[:].set(value) - self._pupil_u = jax.lax.cond( - jnp.any(jnp.isfinite(self._pupil_v)) - & jnp.all(~jnp.isfinite(self._pupil_u)), - lambda pupil_u: jnp.zeros_like(pupil_u), - lambda pupil_u: pupil_u, - self._pupil_u, - ) + self._pupil_u = _zero_if_needed_on_set(self._pupil_v, self._pupil_u) @property def time(self): @@ -1035,3 +1013,12 @@ def _add_photons_to_image(x, y, flux, xmin, ymin, arr): _arr = arr.at[yinds, xinds].add(_flux.astype(arr.dtype)) return _arr, _flux.sum() + + +def _zero_if_needed_on_set(arr_to_test, arr_to_zero): + return jax.lax.cond( + jnp.any(jnp.isfinite(arr_to_test)) & jnp.all(~jnp.isfinite(arr_to_zero)), + lambda atz: jnp.zeros_like(atz), + lambda atz: atz, + arr_to_zero, + ) From db36427ca687a21c7e6325e1356c725643bf3de8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 17:06:10 -0600 Subject: [PATCH 83/85] REF refactor loops to make code easier to read --- jax_galsim/gsobject.py | 252 ++++++++++++++++------------------------- 1 file changed, 99 insertions(+), 153 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 11c62ea2..85c1bbc2 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1,3 +1,4 @@ +from collections import namedtuple from functools import partial import galsim as _galsim @@ -1330,62 +1331,46 @@ def drawPhot( if maxN is None: # if neither maxN nor _JAX_GALSIM_PHOTON_ARRAY_SIZE are set # we drae Ntot photons all at once - ( - added_flux, - _image, - _sensor, - _photon_ops, - _rng, - _, - photons, - ) = _draw_phot_while_loop_shoot( - Ntot, - Ntot, - Ntot, - self, - rng, - g, - image, - photon_ops, - sensor, - orig_center, - local_wcs, - False, - 0.0, + _dfret = _draw_phot_while_loop_shoot( + maxN=Ntot, + thisN=Ntot, + Ntot=Ntot, + obj=self, + rng=rng, + g=g, + image=image, + photon_ops=photon_ops, + sensor=sensor, + orig_center=orig_center, + local_wcs=local_wcs, + resume=False, + added_flux=0.0, ) else: # if maxN or _JAX_GALSIM_PHOTON_ARRAY_SIZE is set # we draw a fixed number of photons at a time in a while # loop until we have drawn Ntot photons - ( - photons, - _rng, - added_flux, - _Nleft, - _image, - _photon_ops, - _sensor, - ) = _draw_phot_while_loop( - PhotonArray(maxN), - rng, - self, - image, - g, - Ntot, - maxN, - photon_ops, - local_wcs, - sensor, - orig_center, + _dfret = _draw_phot_while_loop( + photons=PhotonArray(maxN), + rng=rng, + obj=self, + image=image, + g=g, + Ntot=Ntot, + maxN=maxN, + photon_ops=photon_ops, + local_wcs=local_wcs, + sensor=sensor, + orig_center=orig_center, ) if rng is not None: - rng._state = _rng._state + rng._state = _dfret.rng._state else: - rng = _rng + rng = _dfret.rng for i in range(len(photon_ops)): - photon_ops[i] = _photon_ops[i] + photon_ops[i] = _dfret.photon_ops[i] - image._array = _image._array + image._array = _dfret.image._array # TODO: how to update the sensor? # https://github.com/GalSim-developers/JAX-GalSim/issues/85 @@ -1394,7 +1379,7 @@ def drawPhot( "Non-default sensors that carry state are not yet supported in jax-galsim." ) - return added_flux, photons + return _dfret.added_flux, _dfret.photons @_wraps(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): @@ -1460,7 +1445,22 @@ def tree_unflatten(cls, aux_data, children): return cls(**(children[0]), **aux_data) +_DrawPhotReturnTuple = namedtuple( + "_DrawPhotReturnTuple", + [ + "photons", + "rng", + "added_flux", + "image", + "photon_ops", + "sensor", + "resume", + ], +) + + def _draw_phot_while_loop_shoot( + *, maxN, thisN, Ntot, @@ -1474,6 +1474,8 @@ def _draw_phot_while_loop_shoot( local_wcs, resume, added_flux, + Nleft=0, + photons=None, ): """This helper function shoots thisN photons and accumulates them into the image.""" try: @@ -1523,11 +1525,14 @@ def _draw_phot_while_loop_shoot( added_flux += sensor.accumulate(photons, im1, orig_center) image += im1 - return added_flux, image, sensor, photon_ops, rng, resume, photons + return _DrawPhotReturnTuple( + photons, rng, added_flux, image, photon_ops, sensor, resume + ) @partial(jax.jit, static_argnames=("maxN",)) def _draw_phot_while_loop( + *, photons, rng, obj, @@ -1542,116 +1547,57 @@ def _draw_phot_while_loop( ): """This helper function shoots photons until Ntot is reached.""" - def _cond_fun(args): - ( - photons, - rng, - added_flux, - obj, - Nleft, - resume, - image, - g, - photon_ops, - local_wcs, - sensor, - orig_center, - ) = args - return Nleft > 0 - - def _body_fun(args): - ( - photons, - rng, - added_flux, - obj, - Nleft, - resume, - image, - g, - photon_ops, - local_wcs, - sensor, - orig_center, - ) = args - # Shoot at most maxN at a time - thisN = jnp.minimum(maxN, Nleft) - - ( - _added_flux, - _image, - _sensor, - _photon_ops, - _rng, - _resume, - _photons, - ) = _draw_phot_while_loop_shoot( - maxN, - thisN, - Ntot, - obj, - rng, - g, - image, - photon_ops, - sensor, - orig_center, - local_wcs, - resume, - added_flux, - ) - - Nleft -= thisN + def _cond_fun(kwargs): + return kwargs["Nleft"] > 0 - return ( - _photons, - _rng, - _added_flux, - obj, - Nleft, - _resume, - _image, - g, - _photon_ops, - local_wcs, - _sensor, - orig_center, + def _body_fun(kwargs): + # Shoot at most maxN at a time + thisN = jnp.minimum(maxN, kwargs["Nleft"]) + + _dfret = _draw_phot_while_loop_shoot(maxN=maxN, thisN=thisN, **kwargs) + + return dict( + photons=_dfret.photons, + rng=_dfret.rng, + added_flux=_dfret.added_flux, + obj=kwargs["obj"], + Nleft=kwargs["Nleft"] - thisN, + resume=_dfret.resume, + image=_dfret.image, + g=kwargs["g"], + photon_ops=_dfret.photon_ops, + local_wcs=kwargs["local_wcs"], + sensor=_dfret.sensor, + orig_center=kwargs["orig_center"], + Ntot=kwargs["Ntot"], ) - added_flux = jnp.array(0) - Nleft = jnp.array(Ntot) - resume = jnp.array(False) - rng = BaseDeviate(rng) - ( - photons, - rng, - added_flux, - obj, - Nleft, - resume, - image, - g, - photon_ops, - local_wcs, - sensor, - orig_center, - ) = jax.lax.while_loop( + ret_kwargs = jax.lax.while_loop( _cond_fun, _body_fun, - ( - photons, - rng, - added_flux, - obj, - Nleft, - resume, - image, - g, - photon_ops, - local_wcs, - sensor, - orig_center, + dict( + photons=photons, + rng=BaseDeviate(rng), + added_flux=jnp.array(0), + obj=obj, + Nleft=jnp.array(Ntot), + resume=jnp.array(False), + image=image, + g=g, + photon_ops=photon_ops, + local_wcs=local_wcs, + sensor=sensor, + orig_center=orig_center, + Ntot=Ntot, ), ) - return photons, rng, added_flux, Nleft, image, photon_ops, sensor + return _DrawPhotReturnTuple( + ret_kwargs["photons"], + ret_kwargs["rng"], + ret_kwargs["added_flux"], + ret_kwargs["image"], + ret_kwargs["photon_ops"], + ret_kwargs["sensor"], + False, + ) From bc7726ceb64c4e98fa0292fd80350b49448595f8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 17:16:43 -0600 Subject: [PATCH 84/85] ENH finish code review response --- jax_galsim/interpolant.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index fd8b560c..7b1b3b16 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -294,6 +294,10 @@ def _shoot(self, photons, rng): photons.x = jnp.interp(ux, cdfx, x) photons.y = jnp.interp(uy, cdfx, x) if photons.size() > 0: + # remember we are using a product of 1D interpolants in eahc direction + # thus the total flux is an integral over x and y of the product of the + # 1D interpolants. Thus we square the total absolute flux of the 1D interpolants + # to get the total absolute flux of the 2D interpolant. flux_per_photon = ( self.positive_flux + self.negative_flux ) ** 2 / photons.size() From d06205aad08a09b866c5f75dacbba6f539c8e3a0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 16 Dec 2023 17:17:09 -0600 Subject: [PATCH 85/85] ENH finish code review response --- jax_galsim/interpolant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 7b1b3b16..50d472c1 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -294,7 +294,7 @@ def _shoot(self, photons, rng): photons.x = jnp.interp(ux, cdfx, x) photons.y = jnp.interp(uy, cdfx, x) if photons.size() > 0: - # remember we are using a product of 1D interpolants in eahc direction + # remember we are using a product of 1D interpolants in each direction # thus the total flux is an integral over x and y of the product of the # 1D interpolants. Thus we square the total absolute flux of the 1D interpolants # to get the total absolute flux of the 2D interpolant.