Skip to content

Commit

Permalink
Merge branch 'main' into interp
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza authored Oct 19, 2023
2 parents 78f5c66 + 13754a6 commit 6e42af2
Show file tree
Hide file tree
Showing 23 changed files with 3,551 additions and 331 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ name: Python package

on:
push:
branches: ["main"]
branches:
- main
pull_request:
branches: ["main"]

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -27,7 +27,7 @@ jobs:
python -m pip install .
- name: Ensure black formatting
run: |
black --check jax_galsim/ tests/ --exclude tests/GalSim/
black --check jax_galsim/ tests/ --exclude "tests/GalSim/|tests/Coord/|tests/jax/galsim/"
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand All @@ -42,4 +42,4 @@ jobs:
- name: Test with pytest
run: |
git submodule update --init --recursive
pytest
pytest --durations=0
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ repos:
hooks:
- id: black
language: python
exclude: tests/GalSim/|tests/Coord/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
entry: pflake8
additional_dependencies: [pyproject-flake8]
exclude: tests/
exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: tests/Galsim|tests/Coord/
exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/
38 changes: 19 additions & 19 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# Inherit all Exception and Warning classes from galsim
from galsim.errors import (
GalSimBoundsError,
GalSimConfigError,
GalSimConfigValueError,
GalSimDeprecationWarning,
GalSimError,
GalSimFFTSizeError,
GalSimHSMError,
GalSimImmutableError,
GalSimIncompatibleValuesError,
GalSimIndexError,
GalSimKeyError,
GalSimNotImplementedError,
GalSimRangeError,
GalSimSEDError,
GalSimUndefinedBoundsError,
GalSimValueError,
GalSimWarning,
# Exception and Warning classes
from .errors import GalSimError, GalSimRangeError, GalSimValueError
from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError
from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError
from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError
from .errors import GalSimFFTSizeError
from .errors import GalSimConfigError, GalSimConfigValueError
from .errors import GalSimWarning, GalSimDeprecationWarning

# noise
from .random import (
BaseDeviate,
UniformDeviate,
GaussianDeviate,
PoissonDeviate,
Chi2Deviate,
GammaDeviate,
WeibullDeviate,
BinomialDeviate,
)

# Basic building blocks
Expand Down
10 changes: 8 additions & 2 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,20 @@ def __init__(self, obj, gsparams=None, propagate_gsparams=True):

# Save the original object as an attribute, so it can be inspected later if necessary.
self._gsparams = GSParams.check(gsparams, obj.gsparams)
self._min_acc_kvalue = obj.flux * self.gsparams.kvalue_accuracy
self._inv_min_acc_kvalue = 1.0 / self._min_acc_kvalue
self._propagate_gsparams = propagate_gsparams
if self._propagate_gsparams:
self._orig_obj = obj.withGSParams(self._gsparams)
else:
self._orig_obj = obj

@property
def _min_acc_kvalue(self):
return self._orig_obj.flux * self.gsparams.kvalue_accuracy

@property
def _inv_min_acc_kvalue(self):
return 1.0 / self._min_acc_kvalue

@property
def orig_obj(self):
"""The original object that is being deconvolved."""
Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/core/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)):
return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False)


def apply_kImage_phases(gsobject, image, jacobian=jnp.eye(2)):
def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)):
# Create an array of coordinates
kcoords = jnp.stack(image.get_pixel_centers(), axis=-1)
kcoords = kcoords * image.scale # Scale by the image pixel scale
kcoords = jnp.dot(kcoords, jacobian)
cenx, ceny = gsobject._offset.x, gsobject._offset.y
cenx, ceny = offset.x, offset.y

#
# flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) )
Expand Down
53 changes: 51 additions & 2 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from functools import partial

import jax


def convert_to_float(x):
if isinstance(x, jax.Array):
if x.shape == ():
return x.item()
else:
return x[0].astype(float).item()
else:
return float(x)


def cast_scalar_to_float(x):
"""Cast the input to a float. Works on python floats and jax arrays."""
if isinstance(x, float):
return float(x)
if isinstance(x, jax.Array):
return x.astype(float)
elif hasattr(x, "astype"):
return x.astype(float)
else:
Expand Down Expand Up @@ -51,3 +63,40 @@ def ensure_hashable(v):
return v
else:
return v


@partial(jax.jit, static_argnames=("niter",))
def bisect_for_root(func, low, high, niter=75):
def _func(i, args):
func, low, flow, high, fhigh = args
mid = (low + high) / 2.0
fmid = func(mid)
return jax.lax.cond(
fmid * fhigh < 0,
lambda func, low, flow, mid, fmid, high, fhigh: (
func,
mid,
fmid,
high,
fhigh,
),
lambda func, low, flow, mid, fmid, high, fhigh: (
func,
low,
flow,
mid,
fmid,
),
func,
low,
flow,
mid,
fmid,
high,
fhigh,
)

flow = func(low)
fhigh = func(high)
args = (func, low, flow, high, fhigh)
return jax.lax.fori_loop(0, niter, _func, args)[-2]
19 changes: 19 additions & 0 deletions jax_galsim/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from galsim.errors import ( # noqa: F401
GalSimBoundsError,
GalSimConfigError,
GalSimConfigValueError,
GalSimDeprecationWarning,
GalSimError,
GalSimFFTSizeError,
GalSimHSMError,
GalSimImmutableError,
GalSimIncompatibleValuesError,
GalSimIndexError,
GalSimKeyError,
GalSimNotImplementedError,
GalSimRangeError,
GalSimSEDError,
GalSimUndefinedBoundsError,
GalSimValueError,
GalSimWarning,
)
78 changes: 11 additions & 67 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import jax.scipy as jsc
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps
from jax.tree_util import Partial as partial
Expand All @@ -10,8 +9,9 @@
from jax_galsim.core.bessel import j0
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
from jax_galsim.gsobject import GSObject
from jax_galsim.position import PositionD


@jax.jit
Expand Down Expand Up @@ -269,75 +269,21 @@ def __str__(self):
s += ")"
return s

@property
def _maxk_untrunc(self):
"""untruncated Moffat maxK
The 2D Fourier Transform of f(r)=C (1+(r/rd)^2)^(-beta) leads
C rd^2 = Flux (beta-1)/pi (no truc)
and
f(k) = C rd^2 int_0^infty (1+x^2)^(-beta) J_0(krd x) x dx
= 2 F (k rd /2)^(\beta-1) K[beta-1, k rd]/Gamma[beta-1]
with k->infty asymptotic behavior
f(k)/F \approx sqrt(pi)/Gamma(beta-1) e^(-k') (k'/2)^(beta -3/2) with k' = k rd
So we solve f(maxk)/F = thr (aka maxk_threshold in gsparams.py)
leading to the iterative search of
let alpha = -log(thr Gamma(beta-1)/sqrt(pi))
k = (\beta -3/2)log(k/2) + alpha
starting with k = alpha
note : in the code "alternative code" is related to issue #1208 in GalSim github
"""

def body(i, val):
kcur, alpha = val
knew = (self.beta - 0.5) * jnp.log(kcur) + alpha
# knew = (self.beta -1.5)* jnp.log(kcur/2) + alpha # alternative code
return knew, alpha

# alpha = -jnp.log(self.gsparams.maxk_threshold
# * jnp.exp(jsc.special.gammaln(self._beta-1))/jnp.sqrt(jnp.pi) ) # alternative code

alpha = -jnp.log(
self.gsparams.maxk_threshold
* jnp.power(2.0, self.beta - 0.5)
* jnp.exp(jsc.special.gammaln(self.beta - 1))
/ (2 * jnp.sqrt(jnp.pi))
)

val_init = (
alpha,
alpha,
)
val = jax.lax.fori_loop(0, 5, body, val_init)
maxk, alpha = val
return maxk / self._r0

@property
def _prefactor(self):
return 2.0 * (self.beta - 1.0) / (self._fluxFactor)

@property
def _maxk_trunc(self):
"""truncated Moffat maxK"""
# a for gaussian profile... this is f(k_max)/Flux = maxk_threshold
maxk_val = self.gsparams.maxk_threshold
dk = self.gsparams.table_spacing * jnp.sqrt(
jnp.sqrt(self.gsparams.kvalue_accuracy / 10.0)
@jax.jit
def _maxk_func(self, k):
return (
jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux)
- self.gsparams.maxk_threshold
)
# 50 is a max (GalSim) but it may be lowered if necessary
ki = jnp.arange(0.0, 50.0, dk)
quad = ClenshawCurtisQuad.init(150)
g = partial(_xMoffatIntegrant, beta=self.beta, rmax=self._maxRrD, quad=quad)
fki_1 = jax.jit(jax.vmap(g))(ki)
fki = fki_1 * self._prefactor
cond = jnp.abs(fki) > maxk_val
maxk = ki[cond][-1]
return maxk / self._r0

@property
@jax.jit
def _maxk(self):
return jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc)
return bisect_for_root(partial(self._maxk_func), 0.0, 1e5, niter=75)

@property
def _stepk_lowbeta(self):
Expand All @@ -353,12 +299,10 @@ def _stepk_highbeta(self):
jnp.power(self.gsparams.folding_threshold, 0.5 / (1.0 - self.beta))
* self._r0
)
if R > self._maxR:
R = self._maxR
R = jnp.minimum(R, self._maxR)
# at least R should be 5 HLR
R5hlr = self.gsparams.stepk_minimum_hlr * self.half_light_radius
if R < R5hlr:
R = R5hlr
R = jnp.maximum(R, R5hlr)
return jnp.pi / R

@property
Expand Down
Loading

0 comments on commit 6e42af2

Please sign in to comment.