Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add InterpolatedImage #60

Merged
merged 84 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
c622679
WIP start adding interpolated images
beckermr Sep 17, 2023
4c07bc9
Merge branch 'main' into iimage
beckermr Sep 19, 2023
5b8638e
Merge branch 'main' into iimage
beckermr Sep 20, 2023
c32c8fa
ENH k-space wrapping for hermitian images
beckermr Sep 22, 2023
c1e653b
Update tests/jax/galsim/test_shear_position_jax.py
beckermr Sep 22, 2023
8225500
Update tests/jax/galsim/test_shear_position_jax.py
beckermr Sep 22, 2023
1c8e3e2
TST clean up the test suite a bit
beckermr Sep 22, 2023
35b3164
Update tests/jax/test_image_wrapping.py
beckermr Sep 22, 2023
991db78
STY blacken
beckermr Sep 22, 2023
ca38852
TST add test of rev-mode autodiff
beckermr Sep 22, 2023
4940d77
Merge branch 'main' into iimage
beckermr Sep 22, 2023
48729e8
Merge branch 'main' into iimage
ismael-mendoza Sep 22, 2023
bb261f3
Merge branch 'main' into iimage
ismael-mendoza Sep 24, 2023
579c151
WIP getting closer
beckermr Sep 25, 2023
ce080f0
merged
beckermr Sep 25, 2023
f537ace
STY black, sort, and flake
beckermr Sep 25, 2023
cad9dc3
WIP what exists so far
beckermr Sep 27, 2023
37c8f25
WIP it works
beckermr Sep 27, 2023
8438433
WIP it works maybe
beckermr Sep 29, 2023
47c06a7
new submodule
beckermr Oct 3, 2023
bdbf688
merged
beckermr Oct 3, 2023
a2c54db
TST fix the tests
beckermr Oct 3, 2023
cbb8659
ENH enable all tests
beckermr Oct 3, 2023
9a0e5ab
Merge branch 'main' into iimage
beckermr Oct 5, 2023
8138f08
merged
beckermr Oct 6, 2023
f5098a6
merged
beckermr Oct 6, 2023
ac8129d
Merge branch 'angle' into iimage
beckermr Oct 6, 2023
df356ab
Merge branch 'main' into iimage
beckermr Oct 14, 2023
58a0416
ENH update submodules
beckermr Oct 14, 2023
92cb2c4
ENH update submodule
beckermr Oct 16, 2023
395708e
try this for ci
beckermr Oct 16, 2023
eae7ca5
try this
beckermr Oct 16, 2023
b17e85e
try this
beckermr Oct 16, 2023
48c29fa
try this
beckermr Oct 16, 2023
13585ec
updte tests
beckermr Oct 16, 2023
1596518
try thids
beckermr Oct 16, 2023
ec81ca6
update submodule
beckermr Oct 16, 2023
8a33a70
WIP tryign to get stuff to pass
beckermr Oct 17, 2023
fde774f
merge
beckermr Oct 20, 2023
48a409d
update submodule
beckermr Oct 20, 2023
21db995
reset transofrm to current main
beckermr Oct 20, 2023
6a000e7
TST do not ignore these errors
beckermr Oct 20, 2023
8bac999
back out changes to test file
beckermr Oct 20, 2023
7b97e9c
Update tests/galsim_tests_config.yaml
beckermr Oct 20, 2023
c2cff10
STY blacken
beckermr Oct 20, 2023
0bb6e70
TST use more specific error
beckermr Oct 20, 2023
f9cb0e7
TST use more specific error
beckermr Oct 20, 2023
42344f9
TST make tests pass
beckermr Oct 20, 2023
9e0df20
Merge branch 'main' into iimage
beckermr Oct 20, 2023
d52c79c
TST make tests pass and refactor a bit
beckermr Oct 20, 2023
186f620
Update jax_galsim/core/utils.py
beckermr Oct 21, 2023
e44a658
ENH it works omg
beckermr Oct 23, 2023
9fd9865
Merge branch 'iimage' of https://github.com/beckermr/JAX-GalSim into …
beckermr Oct 23, 2023
619ea10
Update jax_galsim/transform.py
beckermr Oct 23, 2023
4014d5f
Update jax_galsim/interpolatedimage.py
beckermr Oct 23, 2023
47b0280
BUG cache kids
beckermr Oct 23, 2023
caa5dbb
Merge branch 'iimage' of https://github.com/beckermr/JAX-GalSim into …
beckermr Oct 23, 2023
898ca57
remove extra function
beckermr Oct 23, 2023
7ca8f1e
BUG make sure to have cached nz_bounds
beckermr Oct 23, 2023
6b359f7
fix another bug
beckermr Oct 24, 2023
d5ec4ef
TST add test of metacal
beckermr Oct 24, 2023
c4b1b14
TST add metacal tests
beckermr Oct 24, 2023
1581523
remove cache
beckermr Oct 25, 2023
20bc706
TST add metacal tests
beckermr Oct 25, 2023
5037a50
REF use vectorized drawing directly without vmap
beckermr Oct 25, 2023
8e0669f
REF use lazy property
beckermr Oct 26, 2023
b9c9f08
TST add extra jit for testing
beckermr Oct 26, 2023
75c332b
REF do not use lazy property here
beckermr Oct 26, 2023
69be657
REF do not cache items which have gradients
beckermr Oct 26, 2023
d38086d
ENH add lazy property decorator with explicit workspace
beckermr Oct 27, 2023
326de37
ENH add lazy property decorator with explicit workspace
beckermr Oct 27, 2023
13edd60
PERF faster tests
beckermr Oct 28, 2023
4138b03
TST make tests faster by skipping some at random
beckermr Oct 29, 2023
b478eb5
TST faster tests
beckermr Oct 29, 2023
32d49d2
Update test_metacal.py
beckermr Oct 29, 2023
0410565
Update test_metacal.py
beckermr Oct 29, 2023
cfb6770
TST metacal passes
beckermr Nov 1, 2023
6129483
merged
beckermr Nov 1, 2023
bf2c17d
STY isort
beckermr Nov 1, 2023
c48ba92
TST add test of fwd mode too
beckermr Nov 1, 2023
cd444a8
ENH respond to code review
beckermr Nov 9, 2023
cb94ce6
ENH respond to CR
beckermr Nov 9, 2023
0e86571
ENH respond to CR
beckermr Nov 9, 2023
e0fd4ce
DOC typo in comment
beckermr Nov 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* `Transformation`
* `Shear`
* `Convolve`
* `InterpolatedImage` and `Interpolant`
* Added implementation of fundamental operations:
* `drawImage`
* `drawReal`
Expand All @@ -24,10 +25,5 @@
* Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects

* Caveats
* Currently the FFT convolution does not perform kwrapping of hermitian images,
so it will lead to erroneous results on underesolved images that need k-space wrapping.
Wrapping for real images is implemented. K-space images arise from doing convolutions
via FFTs and so one would expect that underresolved images with convolutions may not be
rendered as accurately.
* Real space convolution and photon shooting methods are not
yet implemented in drawImage.
8 changes: 5 additions & 3 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,14 @@
from .gaussian import Gaussian
from .box import Box, Pixel
from .gsobject import GSObject

# Interpolation
from .moffat import Moffat

from .sum import Add, Sum
from .transform import Transform, Transformation
from .convolve import Convolve, Convolution, Deconvolution, Deconvolve

# WCS
from .wcs import (
BaseWCS,
AffineTransform,
JacobianWCS,
OffsetWCS,
Expand All @@ -77,7 +75,11 @@
Quintic,
Lanczos,
)
from .interpolatedimage import InterpolatedImage, _InterpolatedImage

# packages kept separate
from . import bessel
from . import fits

# this one is specific to jax_galsim
from . import core
8 changes: 4 additions & 4 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs):
# Save the construction parameters (as they are at this point) as attributes so they
# can be inspected later if necessary.
if bool(real_space):
raise NotImplementedError("Real space convolutions are not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")
self._real_space = bool(real_space)

# Figure out what gsparams to use
Expand Down Expand Up @@ -296,7 +296,7 @@ def _max_sb(self):
return self.flux / jnp.sum(jnp.array(area_list))

def _xValue(self, pos):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")

def _kValue(self, kpos):
kv_list = [
Expand All @@ -305,10 +305,10 @@ def _kValue(self, kpos):
return jnp.prod(jnp.array(kv_list))

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")

def _shoot(self, photons, rng):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Photon shooting convolutions are not implemented")

def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
Expand Down
58 changes: 58 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def compute_major_minor_from_jacobian(jac):
beckermr marked this conversation as resolved.
Show resolved Hide resolved
h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0])
h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0])
major = 0.5 * jnp.abs(h1 + h2)
minor = 0.5 * jnp.abs(h1 - h2)
return major, minor


def convert_to_float(x):
Expand Down Expand Up @@ -43,6 +53,54 @@ def cast_scalar_to_int(x):
return x


def is_equal_with_arrays(x, y):
"""Return True if the data is equal, False otherwise. Handles jax.Array types."""
if isinstance(x, list):
if isinstance(y, list) and len(x) == len(y):
for vx, vy in zip(x, y):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, tuple):
if isinstance(y, tuple) and len(x) == len(y):
for vx, vy in zip(x, y):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, set):
if isinstance(y, set) and len(x) == len(y):
for vx, vy in zip(sorted(x), sorted(y)):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, dict):
if isinstance(y, dict) and len(x) == len(y):
for kx, vx in x.items():
if kx not in y or (not is_equal_with_arrays(vx, y[kx])):
return False
return True
else:
return False
elif isinstance(x, jax.Array) and jnp.ndim(x) > 0:
if isinstance(y, jax.Array) and y.shape == x.shape:
return jnp.array_equal(x, y)
else:
return False
elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or (
isinstance(y, jax.Array) and jnp.ndim(y) == 0
):
# this case covers comparing an array scalar to a python scalar or vice versa
return jnp.array_equal(x, y)
else:
return x == y


def _recurse_list_to_tuple(x):
if isinstance(x, list):
return tuple(_recurse_list_to_tuple(v) for v in x)
Expand Down
106 changes: 105 additions & 1 deletion jax_galsim/core/wrap_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@jax.jit
def wrap_nonhermition(im, xmin, ymin, nxwrap, nywrap):
def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
def _body_j(j, vals):
i, im = vals

Expand Down Expand Up @@ -33,3 +33,107 @@ def _body_i(i, vals):

im = jax.lax.fori_loop(0, im.shape[0], _body_i, im)
return im


@jax.jit
def expand_hermitian_x(im):
return jnp.concatenate([im[:, 1:][::-1, ::-1].conjugate(), im], axis=1)


@jax.jit
def contract_hermitian_x(im):
return im[:, im.shape[1] // 2 :]


@jax.jit
def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
im_exp = expand_hermitian_x(im)
im_exp = wrap_nonhermitian(
im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny
)
return contract_hermitian_x(im_exp)


@jax.jit
def expand_hermitian_y(im):
return jnp.concatenate([im[1:, :][::-1, ::-1].conjugate(), im], axis=0)


@jax.jit
def contract_hermitian_y(im):
return im[im.shape[0] // 2 :, :]


@jax.jit
def wrap_hermitian_y(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
im_exp = expand_hermitian_y(im)
im_exp = wrap_nonhermitian(
im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny
)
return contract_hermitian_y(im_exp)


# I am leaving this code here for posterity. It has a bug that I cannot find.
# It tries to be more clever instead of simply expanding the hermitian image to
# it's full shape, wrapping everything, and then contracting. -MRB
# @jax.jit
# def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
# def _body_j(j, vals):
# i, im = vals

# # first do zero or positive x freq
# im_y = i + im_ymin
# im_x = j + im_xmin
# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin
# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin
# wrap_yind = wrap_y - im_ymin
# wrap_xind = wrap_x - im_xmin
# im = jax.lax.cond(
# wrap_xind >= 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y) != 0,
# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j]),
# lambda im, wrap_yind, wrap_xind: im,
# im,
# wrap_yind,
# wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# )

# # now do neg x freq
# im_y = -im_y
# im_x = -im_x
# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin
# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin
# wrap_yind = wrap_y - im_ymin
# wrap_xind = wrap_x - im_xmin
# im = jax.lax.cond(
# im_x != 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# wrap_xind >= 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# (jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y)) != 0,
# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j].conjugate()),
# lambda im, wrap_yind, wrap_xind: im,
# im,
# wrap_yind,
# wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# )

# return [i, im]

# def _body_i(i, vals):
# im = vals
# _, im = jax.lax.fori_loop(0, im.shape[1], _body_j, [i, im])
# return im

# im = jax.lax.fori_loop(0, im.shape[0], _body_i, im)
# return im
5 changes: 3 additions & 2 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from jax._src.numpy.util import _wraps

from jax_galsim.core.utils import is_equal_with_arrays
from jax_galsim.gsparams import GSParams
from jax_galsim.position import Position, PositionD, PositionI
from jax_galsim.utilities import parse_pos_args
Expand Down Expand Up @@ -178,7 +179,7 @@ def __neg__(self):
def __eq__(self, other):
return (self is other) or (
(type(other) is self.__class__)
and (self.tree_flatten() == other.tree_flatten())
and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten())
)

@_wraps(_galsim.GSObject.xValue)
Expand Down Expand Up @@ -771,7 +772,7 @@ def drawFFT_makeKImage(self, image):
with jax.ensure_compile_time_eval():
Nk = self.gsparams.maximum_fft_size
N = Nk
dk = 2.0 * np.pi / (N * image.scale)
dk = 2.0 * np.pi / (N * image.scale)
beckermr marked this conversation as resolved.
Show resolved Hide resolved
else:
# Start with what this profile thinks a good size would be given the image's pixel scale.
N = self.getGoodImageSize(image.scale)
Expand Down
Loading