Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add Interpolant classes #56

Merged
merged 44 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
1421167
WIP add interpolant classes
beckermr Sep 13, 2023
c500962
STY blacken
beckermr Sep 13, 2023
a5b0f99
ENH add cubic
beckermr Sep 13, 2023
3adb74d
WIP start Lanczos
beckermr Sep 13, 2023
f6ef7ca
Merge branch 'main' into interp
beckermr Sep 14, 2023
4d3bc8b
ENH finish lanczos
beckermr Sep 14, 2023
afe7d8a
Merge branch 'interp' of https://github.com/beckermr/JAX-GalSim into …
beckermr Sep 14, 2023
e823cc8
STY make code a bit nicer
beckermr Sep 14, 2023
af65350
Update tests/galsim_tests_config.yaml
beckermr Sep 14, 2023
2f5a611
remove this
beckermr Sep 14, 2023
8c04db0
Update tests/conftest.py
beckermr Sep 14, 2023
c523ecd
Update jax_galsim/gsparams.py
beckermr Sep 14, 2023
8e56143
Update jax_galsim/gsobject.py
beckermr Sep 14, 2023
77d50f8
still need this
beckermr Sep 14, 2023
cab7290
revert this
beckermr Sep 14, 2023
0913bd7
Merge branch 'eval-repr' into interp
beckermr Sep 14, 2023
b8cf765
Merge branch 'main' into interp
beckermr Sep 14, 2023
a5bdf87
REF properly use pytree stuff
beckermr Sep 15, 2023
8e4d8e5
Merge branch 'interp' of https://github.com/beckermr/JAX-GalSim into …
beckermr Sep 15, 2023
6e1763e
Update jax_galsim/interpolant.py
beckermr Sep 15, 2023
baf1bab
BUG get sinc fluxes right
beckermr Sep 15, 2023
c4e22a3
Merge branch 'interp' of https://github.com/beckermr/JAX-GalSim into …
beckermr Sep 15, 2023
5e85f0f
BUG cannot return a pure python scalar
beckermr Sep 15, 2023
8c56481
TST a few more tests
beckermr Sep 15, 2023
a77c55d
ENH add optional interpolation
beckermr Sep 16, 2023
cbb8b4f
PERF fine tune JIT for performance
beckermr Sep 16, 2023
f9791bc
PERF tune things a bit
beckermr Sep 16, 2023
63c36f5
Update jax_galsim/interpolant.py
beckermr Sep 16, 2023
93433bb
REF no arrays in pytree
beckermr Sep 16, 2023
3f87885
REF no arrays in pytree
beckermr Sep 16, 2023
f196ab0
Merge branch 'interp' of https://github.com/beckermr/JAX-GalSim into …
beckermr Sep 16, 2023
0bfa625
PERF always jit pure functions
beckermr Sep 17, 2023
02dfd2c
Update jax_galsim/interpolant.py
beckermr Sep 17, 2023
c0767a5
Merge branch 'main' into interp
beckermr Sep 19, 2023
5ae7403
Merge branch 'main' into interp
beckermr Sep 20, 2023
333cc57
Merge branch 'main' into interp
ismael-mendoza Sep 22, 2023
8cb2d0b
Merge branch 'main' into interp
ismael-mendoza Sep 22, 2023
9e805ab
Merge branch 'main' into interp
ismael-mendoza Sep 24, 2023
5c52ba9
Merge branch 'main' into interp
beckermr Oct 5, 2023
8357d32
latest tests
beckermr Oct 6, 2023
2573eba
ENH respond to CR
beckermr Oct 6, 2023
1ea3e5b
STY blacken, sorten, etc
beckermr Oct 6, 2023
78f5c66
Merge branch 'main' into interp
beckermr Oct 14, 2023
6e42af2
Merge branch 'main' into interp
ismael-mendoza Oct 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,19 @@

# Shear
from .shear import Shear, _Shear

# Interpolations
from .interpolant import (
Interpolant,
Delta,
Nearest,
SincInterpolant,
Linear,
Cubic,
Quintic,
Lanczos,
)

# packages kept separate
from . import bessel
from . import fits
104 changes: 104 additions & 0 deletions jax_galsim/bessel.py
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps

import galsim as _galsim


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
@jax.jit
def _f_pade(x, x2):
# fmt: off
y = 1. / x2
f = (
(1. + # noqa: W504, E126, E226
y*(7.44437068161936700618e2 + # noqa: W504, E126, E226
y*(1.96396372895146869801e5 + # noqa: W504, E126, E226
y*(2.37750310125431834034e7 + # noqa: W504, E126, E226
y*(1.43073403821274636888e9 + # noqa: W504, E126, E226
y*(4.33736238870432522765e10 + # noqa: W504, E126, E226
y*(6.40533830574022022911e11 + # noqa: W504, E126, E226
y*(4.20968180571076940208e12 + # noqa: W504, E126, E226
y*(1.00795182980368574617e13 + # noqa: W504, E126, E226
y*(4.94816688199951963482e12 + # noqa: W504, E126, E226
y*(-4.94701168645415959931e11))))))))))) # noqa: W504, E126, E226
/ (x*(1. + # noqa: W504, E126, E226
y*(7.46437068161927678031e2 + # noqa: W504, E126, E226
y*(1.97865247031583951450e5 + # noqa: W504, E126, E226
y*(2.41535670165126845144e7 + # noqa: W504, E126, E226
y*(1.47478952192985464958e9 + # noqa: W504, E126, E226
y*(4.58595115847765779830e10 + # noqa: W504, E126, E226
y*(7.08501308149515401563e11 + # noqa: W504, E126, E226
y*(5.06084464593475076774e12 + # noqa: W504, E126, E226
y*(1.43468549171581016479e13 + # noqa: W504, E126, E226
y*(1.11535493509914254097e13))))))))))) # noqa: W504, E126, E226
)
# fmt: on
return f


@jax.jit
def _g_pade(x, x2):
# fmt: off
y = 1. / x2
g = (
y*(1. + # noqa: W504, E126, E226
y*(8.1359520115168615e2 + # noqa: W504, E126, E226
y*(2.35239181626478200e5 + # noqa: W504, E126, E226
y*(3.12557570795778731e7 + # noqa: W504, E126, E226
y*(2.06297595146763354e9 + # noqa: W504, E126, E226
y*(6.83052205423625007e10 + # noqa: W504, E126, E226
y*(1.09049528450362786e12 + # noqa: W504, E126, E226
y*(7.57664583257834349e12 + # noqa: W504, E126, E226
y*(1.81004487464664575e13 + # noqa: W504, E126, E226
y*(6.43291613143049485e12 + # noqa: W504, E126, E226
y*(-1.36517137670871689e12))))))))))) # noqa: W504, E126, E226
/ (1. + # noqa: W504, E126, E226
y*(8.19595201151451564e2 + # noqa: W504, E126, E226
y*(2.40036752835578777e5 + # noqa: W504, E126, E226
y*(3.26026661647090822e7 + # noqa: W504, E126, E226
y*(2.23355543278099360e9 + # noqa: W504, E126, E226
y*(7.87465017341829930e10 + # noqa: W504, E126, E226
y*(1.39866710696414565e12 + # noqa: W504, E126, E226
y*(1.17164723371736605e13 + # noqa: W504, E126, E226
y*(4.01839087307656620e13 + # noqa: W504, E126, E226
y*(3.99653257887490811e13)))))))))) # noqa: W504, E126, E226
)
# fmt: on
return g


@jax.jit
def _si_small_pade(x, x2):
# fmt: off
return (
x*(1. + # noqa: W504, E126, E226
x2*(-4.54393409816329991e-2 + # noqa: W504, E126, E226
x2*(1.15457225751016682e-3 + # noqa: W504, E126, E226
x2*(-1.41018536821330254e-5 + # noqa: W504, E126, E226
x2*(9.43280809438713025e-8 + # noqa: W504, E126, E226
x2*(-3.53201978997168357e-10 + # noqa: W504, E126, E226
x2*(7.08240282274875911e-13 + # noqa: W504, E126, E226
x2*(-6.05338212010422477e-16)))))))) # noqa: W504, E126, E226
/ (1. + # noqa: W504, E126, E226
x2*(1.01162145739225565e-2 + # noqa: W504, E126, E226
x2*(4.99175116169755106e-5 + # noqa: W504, E126, E226
x2*(1.55654986308745614e-7 + # noqa: W504, E126, E226
x2*(3.28067571055789734e-10 + # noqa: W504, E126, E226
x2*(4.5049097575386581e-13 + # noqa: W504, E126, E226
x2*(3.21107051193712168e-16))))))) # noqa: W504, E126, E226
)
# fmt: on


@_wraps(_galsim.bessel.si)
@jax.jit
def si(x):
x2 = x * x
return jnp.where(
x2 > 16.0,
jnp.sign(x) * (jnp.pi / 2)
- _f_pade(x, x2) * jnp.cos(x)
- _g_pade(x, x2) * jnp.sin(x),
_si_small_pade(x, x2),
)
Loading
Loading