Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add photon shooting #82

Merged
merged 95 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
cc4d715
ENH add photon shooting
beckermr Nov 13, 2023
d7c2969
Merge branch 'main' into photon-array
beckermr Nov 13, 2023
bd703e2
Update jax_galsim/random.py
beckermr Nov 13, 2023
a2d0cfb
Update jax_galsim/random.py
beckermr Nov 13, 2023
bcb1eca
Apply suggestions from code review
beckermr Nov 13, 2023
9c311ff
Update jax_galsim/random.py
beckermr Nov 13, 2023
2e61945
STY blacken
beckermr Nov 13, 2023
a791f13
ENH finished photon shooting for interpolated images
beckermr Nov 14, 2023
c51c144
STY blacken
beckermr Nov 14, 2023
dfbf59f
tests of jit
beckermr Nov 15, 2023
36ccd9f
STY please the flake8
beckermr Nov 15, 2023
2d45b6e
Merge branch 'main' into photon-array
beckermr Nov 15, 2023
289001b
TST jit w/ photon shooting
beckermr Nov 17, 2023
f7ae8aa
REF change repr just a bit
beckermr Nov 17, 2023
42c4ddb
BUG wrong pixel location
beckermr Nov 17, 2023
fe73df2
REF more jit in photon ops
beckermr Nov 17, 2023
8694436
TST add sum to jit test
beckermr Nov 17, 2023
026a03d
STY blacken
beckermr Nov 17, 2023
d9a9d4c
Update jax_galsim/core/utils.py
beckermr Nov 17, 2023
7cc9384
Update jax_galsim/core/utils.py
beckermr Nov 17, 2023
110442f
Update jax_galsim/core/utils.py
beckermr Nov 17, 2023
0cb4de9
Update jax_galsim/noise.py
beckermr Nov 17, 2023
f2ea9b2
REF simpler
beckermr Nov 18, 2023
eada628
Merge branch 'photon-array' of https://github.com/beckermr/JAX-GalSim…
beckermr Nov 18, 2023
1416586
DOC added comments
beckermr Nov 18, 2023
ca3c02c
TST added test for raising
beckermr Nov 18, 2023
8a7a8f3
TST add test of shooting in jit
beckermr Nov 18, 2023
37047c3
TST simpler test code
beckermr Nov 18, 2023
65995b2
BUG make sure tests do not raise even if they fail
beckermr Nov 18, 2023
28a6174
DOC added comments
beckermr Nov 18, 2023
05558ae
DOC added comments
beckermr Nov 18, 2023
ba61a08
TST make test fail
beckermr Nov 18, 2023
8d4bc55
TST added tests of offsets
beckermr Nov 18, 2023
802c1c1
PERF extra jitting
beckermr Nov 18, 2023
454f692
WIP enable fixed size photon arrays
beckermr Nov 19, 2023
3fdd163
REF a lot of changes to enable vmap, clean up APIs, etc.
beckermr Nov 20, 2023
9be573f
ENH use higher order integration method
beckermr Nov 20, 2023
63e40d5
Merge branch 'main' into photon-array
beckermr Nov 21, 2023
de1aad9
Merge pull request #2 from beckermr/photon-fixed
beckermr Nov 21, 2023
8f3331e
TST enable rest of galsim test suit
beckermr Nov 26, 2023
2ed69af
update to latest test suite
beckermr Nov 26, 2023
4f19dd9
merged
beckermr Nov 26, 2023
a8e669a
STY blacken
beckermr Nov 26, 2023
19357b3
TST next round of test fixes
beckermr Nov 27, 2023
68b2f7e
REF get rid of warning
beckermr Nov 27, 2023
9b9f030
TST new test suite
beckermr Nov 27, 2023
40167b9
TST new test suite
beckermr Nov 27, 2023
b222992
TST new test suite and better deprecation warnings
beckermr Nov 27, 2023
a87846e
TST new test suite
beckermr Nov 27, 2023
fd00e94
TST update test suite
beckermr Nov 27, 2023
325c91e
Merge pull request #3 from beckermr/more-tests
beckermr Nov 28, 2023
05ece3a
REF put bounds back
beckermr Nov 28, 2023
309daa5
Merge pull request #4 from beckermr/more-tests
beckermr Nov 28, 2023
f4573c6
REF centralize float and int casts
beckermr Nov 28, 2023
f562a13
REF remove dead code
beckermr Nov 28, 2023
6c74671
DOC fix typo in doc string
beckermr Nov 28, 2023
785b026
Apply suggestions from code review
beckermr Nov 28, 2023
1bc758c
Apply suggestions from code review
beckermr Nov 28, 2023
a2eddb2
Apply suggestions from code review
beckermr Nov 28, 2023
17892b8
REF make sure tracing does not raise
beckermr Nov 28, 2023
093ef3f
Apply suggestions from code review
beckermr Nov 28, 2023
86bee62
Apply suggestions from code review
beckermr Nov 28, 2023
44fd8da
Apply suggestions from code review
beckermr Nov 28, 2023
d88f994
Apply suggestions from code review
beckermr Nov 28, 2023
a2b4dd7
Apply suggestions from code review
beckermr Nov 28, 2023
328df69
Apply suggestions from code review
beckermr Nov 28, 2023
0d25200
TST comment out old code
beckermr Nov 28, 2023
5e46f4d
Merge branch 'photon-array' of https://github.com/beckermr/JAX-GalSim…
beckermr Nov 28, 2023
560add8
Apply suggestions from code review
beckermr Nov 28, 2023
8ac318d
TST remove duplicate test
beckermr Nov 28, 2023
5336e79
Merge branch 'photon-array' of https://github.com/beckermr/JAX-GalSim…
beckermr Nov 28, 2023
c9f4c8c
TST added tests of seeding
beckermr Nov 28, 2023
03084fb
BUG make sure we can return photons and added flux
beckermr Nov 29, 2023
4c24af7
BUG remove workspace when pickling
beckermr Nov 29, 2023
effaa1d
BUG handle maxN with fixed array sizes correctly
beckermr Nov 29, 2023
bcc0e29
BUG compute minN properly
beckermr Nov 29, 2023
0d16f99
TST added tests of APIs
beckermr Nov 29, 2023
d9da575
TST add API tests for sensors
beckermr Nov 29, 2023
c85fe94
ENH comments and tests
beckermr Dec 1, 2023
05b61da
STY blacken
beckermr Dec 1, 2023
888ba98
DOC add docs for routines for n photons
beckermr Dec 2, 2023
e3f37b9
BUG hashable type for jit
beckermr Dec 2, 2023
5aaf435
DOC update change log
beckermr Dec 2, 2023
c8dfa26
DOC update change log
beckermr Dec 2, 2023
595482e
ENH first pass at code review response
beckermr Dec 13, 2023
2dd3ee3
Update tests/jax/test_ref_impl.py
beckermr Dec 15, 2023
abb9e38
Apply suggestions from code review
beckermr Dec 15, 2023
abda85c
Update tests/jax/test_photon_shooting_jax.py
beckermr Dec 15, 2023
89ac8fb
STY blacken
beckermr Dec 16, 2023
4d366a6
ENH respond to more CR
beckermr Dec 16, 2023
7f34324
ENH respond to more CR
beckermr Dec 16, 2023
7b60ab9
refactor repeated logicx
beckermr Dec 16, 2023
db36427
REF refactor loops to make code easier to read
beckermr Dec 16, 2023
bc7726c
ENH finish code review response
beckermr Dec 16, 2023
d06205a
ENH finish code review response
beckermr Dec 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from .sum import Add, Sum
from .transform import Transform, Transformation
from .convolve import Convolve, Convolution, Deconvolution, Deconvolve
from .deltafunction import DeltaFunction

# WCS
from .wcs import (
Expand All @@ -68,6 +69,7 @@
)
from .fits import FitsHeader
from .celestial import CelestialCoord
from .fitswcs import TanWCS, FitsWCS, GSFitsWCS

# Shear
from .shear import Shear, _Shear
Expand All @@ -85,6 +87,10 @@
)
from .interpolatedimage import InterpolatedImage, _InterpolatedImage

# Photon Shooting
from .photon_array import PhotonArray
from .sensor import Sensor

# packages kept separate
from . import bessel
from . import fits
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax._src.numpy.util import _wraps
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_scalar_to_float, ensure_hashable
from jax_galsim.core.utils import cast_to_array_scalar, ensure_hashable


@_wraps(_galsim.AngleUnit)
Expand All @@ -34,7 +34,7 @@ def __init__(self, value):
"""
:param value: The measure of the unit in radians.
"""
self._value = cast_scalar_to_float(value)
self._value = cast_to_array_scalar(value, dtype=float)

@property
def value(self):
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(self, theta, unit=None):
raise TypeError("Invalid unit %s of type %s" % (unit, type(unit)))
else:
# Normal case
self._rad = cast_scalar_to_float(theta) * unit.value
self._rad = cast_to_array_scalar(theta, dtype=float) * unit.value

@property
def rad(self):
Expand Down
22 changes: 9 additions & 13 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from jax._src.numpy.util import _wraps
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
cast_scalar_to_float,
cast_scalar_to_int,
ensure_hashable,
)
from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
from jax_galsim.position import Position, PositionD, PositionI


Expand Down Expand Up @@ -264,10 +260,10 @@ class BoundsD(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
self.xmin = cast_scalar_to_float(self.xmin)
self.xmax = cast_scalar_to_float(self.xmax)
self.ymin = cast_scalar_to_float(self.ymin)
self.ymax = cast_scalar_to_float(self.ymax)
self.xmin = cast_to_float(self.xmin)
self.xmax = cast_to_float(self.xmax)
self.ymin = cast_to_float(self.ymin)
self.ymax = cast_to_float(self.ymax)

def _check_scalar(self, x, name):
try:
Expand Down Expand Up @@ -298,10 +294,10 @@ class BoundsI(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
self.xmin = cast_scalar_to_int(self.xmin)
self.xmax = cast_scalar_to_int(self.xmax)
self.ymin = cast_scalar_to_int(self.ymin)
self.ymax = cast_scalar_to_int(self.ymax)
self.xmin = cast_to_int(self.xmin)
self.xmax = cast_to_int(self.xmax)
self.ymin = cast_to_int(self.ymin)
self.ymax = cast_to_int(self.ymax)

def _check_scalar(self, x, name):
try:
Expand Down
10 changes: 10 additions & 0 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate


@_wraps(_galsim.Box)
Expand Down Expand Up @@ -115,6 +116,15 @@ def tree_unflatten(cls, aux_data, children):
**aux_data
)

@_wraps(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

# this does not fill arrays like in galsim
photons.x = (ud.generate(photons.x) - 0.5) * self.width
photons.y = (ud.generate(photons.y) - 0.5) * self.height
photons.flux = self.flux / photons.size()


@_wraps(_galsim.Pixel)
@register_pytree_node_class
Expand Down
11 changes: 10 additions & 1 deletion jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.photon_array import PhotonArray


@_wraps(
Expand Down Expand Up @@ -308,7 +309,15 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Real-space convolutions are not implemented")

def _shoot(self, photons, rng):
raise NotImplementedError("Photon shooting convolutions are not implemented")
self.obj_list[0]._shoot(photons, rng)
beckermr marked this conversation as resolved.
Show resolved Hide resolved
# It may be necessary to shuffle when convolving because we do not have a
# gaurantee that the convolvee's photons are uncorrelated, e.g., they might
beckermr marked this conversation as resolved.
Show resolved Hide resolved
# both have their negative ones at the end.
# However, this decision is now made by the convolve method.
for obj in self.obj_list[1:]:
p1 = PhotonArray(len(photons))
obj._shoot(p1, rng)
photons.convolve(p1, rng)

def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
Expand Down
Loading