Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH enable noise fields and fix drawing bugz #79

Merged
merged 11 commits into from
Nov 15, 2023
8 changes: 8 additions & 0 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
WeibullDeviate,
BinomialDeviate,
)
from .noise import (
BaseNoise,
GaussianNoise,
DeviateNoise,
PoissonNoise,
VariableGaussianNoise,
CCDNoise,
)

# Basic building blocks
from .bounds import Bounds, BoundsD, BoundsI
Expand Down
62 changes: 43 additions & 19 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,13 +594,23 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None):
wcs = image.wcs

# If the input scale <= 0, or wcs is still None at this point, then use the Nyquist scale:
# TODO: we will need to remove this test of scale for jitting
# if wcs is None or (wcs.isPixelScale and wcs.scale <= 0):
if wcs is None:
if default_wcs is None:
wcs = PixelScale(self.nyquist_scale)
else:
wcs = default_wcs

if wcs.isPixelScale() and wcs.isLocal():
wcs = jax.lax.cond(
wcs.scale <= 0,
lambda wcs, nqs: PixelScale(jnp.float_(nqs))
if default_wcs is None
else default_wcs,
lambda wcs, nqs: PixelScale(jnp.float_(wcs.scale)),
wcs,
self.nyquist_scale,
)

return wcs

@_wraps(_galsim.GSObject.drawImage)
Expand Down Expand Up @@ -636,8 +646,12 @@ def drawImage(
):
from jax_galsim.box import Pixel
from jax_galsim.convolve import Convolve
from jax_galsim.image import Image
from jax_galsim.wcs import PixelScale

if image is not None and not isinstance(image, Image):
raise TypeError("image is not an Image instance", image)

# Figure out what wcs we are going to use.
wcs = self._determine_wcs(scale, wcs, image)

Expand Down Expand Up @@ -710,9 +724,9 @@ def drawImage(
image.wcs = PixelScale(1.0)

if prof.is_analytic_x:
added_photons, image = prof.drawReal(image, add_to_image)
added_photons = prof.drawReal(image, add_to_image)
else:
added_photons, image = prof.drawFFT(image, add_to_image)
added_photons = prof.drawFFT(image, add_to_image)

image.added_flux = added_photons / flux_scale
# Restore the original center and wcs
Expand All @@ -722,21 +736,25 @@ def drawImage(
# Update image_in to satisfy GalSim API
image_in._array = image._array
image_in.added_flux = image.added_flux
image_in._bounds = image.bounds
image_in._bounds = image._bounds
image_in.wcs = image.wcs
image_in._dtype = image._dtype
return image

@_wraps(_galsim.GSObject.drawReal)
def drawReal(self, image, add_to_image=False):
if image.wcs is None or not image.wcs.isPixelScale:
if image.wcs is None or not image.wcs.isPixelScale():
raise _galsim.GalSimValueError(
"drawReal requires an image with a PixelScale wcs", image
)
im1 = self._drawReal(image)
temp = im1.subImage(image.bounds)
if add_to_image:
return im1.array.sum(dtype=float), image + im1
image._array = image._array + temp._array
else:
return im1.array.sum(dtype=float), im1
image._array = temp._array

return temp.array.sum(dtype=float)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
"""A version of `drawReal` without the sanity checks or some options.
Expand Down Expand Up @@ -856,11 +874,11 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
# Add (a portion of) this to the original image.
temp = real_image.subImage(image.bounds)
if add_to_image:
image += temp
image._array = image._array + temp._array
else:
image = temp
added_photons = temp.array.sum(dtype=float)
return added_photons, image
image._array = temp._array

return temp.array.sum(dtype=float)

def drawFFT(self, image, add_to_image=False):
"""
Expand Down Expand Up @@ -888,7 +906,7 @@ def drawFFT(self, image, add_to_image=False):
Returns:
The total flux drawn inside the image bounds.
"""
if image.wcs is None or not image.wcs.isPixelScale:
if image.wcs is None or not image.wcs.isPixelScale():
raise _galsim.GalSimValueError(
"drawFFT requires an image with a PixelScale wcs", image
)
Expand Down Expand Up @@ -985,16 +1003,22 @@ def drawKImage(

if setup_only:
return image

# For GalSim compatibility, we will attempt to update the input image
image_in = image
if not add_to_image and image.iscontiguous:
image = self._drawKImage(image)
im2 = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale)
im2 = self._drawKImage(im2)

if not add_to_image:
image._array = im2._array
else:
im2 = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale)
im2 = self._drawKImage(im2)
image += im2
image._array = im2._array + image._array

image_in._array = image._array
image_in._bounds = image.bounds
image_in._bounds = image._bounds
image_in.wcs = image.wcs
image_in._dtype = image._dtype

return image

@_wraps(_galsim.GSObject._drawKImage)
Expand Down
36 changes: 27 additions & 9 deletions jax_galsim/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,23 @@ def __init__(self, *args, **kwargs):

# Check that we got them all
if kwargs:
if "copy" in kwargs.keys():
if "copy" in kwargs.keys() and not kwargs["copy"]:
raise TypeError(
"'copy' is not a valid keyword argument for the JAX-GalSim version of the Image constructor"
"'copy=False' is not a valid keyword argument for the JAX-GalSim version of the Image constructor"
)
else:
# remove it since we used it
kwargs.pop("copy", None)

if "make_const" in kwargs.keys():
raise TypeError(
"'make_const' is not a valid keyword argument for the JAX-GalSim version of the Image constructor"
)
raise TypeError(
"Image constructor got unexpected keyword arguments: %s", kwargs
)
raise TypeError(
"Image constructor got unexpected keyword arguments: %s", kwargs
)

if kwargs:
raise TypeError(
"Image constructor got unexpected keyword arguments: %s", kwargs
)

# Figure out what dtype we want:
dtype = self._alias_dtypes.get(dtype, dtype)
Expand All @@ -150,6 +153,14 @@ def __init__(self, *args, **kwargs):
if not array.dtype.isnative:
array = array.astype(array.dtype.newbyteorder("="))
self._dtype = array.dtype.type
elif image is not None:
if not isinstance(image, Image):
raise TypeError("image must be an Image")
# we do less checking here since we already have a valid image
if dtype is None:
self._dtype = image.dtype
else:
self._dtype = dtype
elif dtype is not None:
self._dtype = dtype
else:
Expand Down Expand Up @@ -669,7 +680,10 @@ def _wrap(self, bounds, hermx, hermy):

return self.subImage(bounds)

@_wraps(_galsim.Image.calculate_fft)
@_wraps(
_galsim.Image.calculate_fft,
lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.",
)
def calculate_fft(self):
if not self.bounds.isDefined():
raise _galsim.GalSimUndefinedBoundsError(
Expand All @@ -681,6 +695,10 @@ def calculate_fft(self):
raise _galsim.GalSimError(
"calculate_fft requires that the image has a PixelScale wcs."
)
if self.dtype in [np.complex64, np.complex128, complex]:
raise _galsim.GalSimNotImplementedError(
"JAX-GalSim does not support forward FFTs of complex dtypes."
)

No2 = max(
max(
Expand Down
Loading