Skip to content

Commit

Permalink
Merge pull request #73 from beckermr/fix-tests-again
Browse files Browse the repository at this point in the history
TST fix CI tests after rebase on main
  • Loading branch information
beckermr authored Oct 19, 2023
2 parents e1c6c73 + ab8589e commit 13754a6
Show file tree
Hide file tree
Showing 14 changed files with 248 additions and 166 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Python package

on:
push:
branches: ["main"]
branches:
- main
pull_request:
branches: ["main"]

jobs:
build:
Expand Down Expand Up @@ -42,4 +42,4 @@ jobs:
- name: Test with pytest
run: |
git submodule update --init --recursive
pytest
pytest --durations=0
4 changes: 2 additions & 2 deletions jax_galsim/core/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)):
return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False)


def apply_kImage_phases(gsobject, image, jacobian=jnp.eye(2)):
def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)):
# Create an array of coordinates
kcoords = jnp.stack(image.get_pixel_centers(), axis=-1)
kcoords = kcoords * image.scale # Scale by the image pixel scale
kcoords = jnp.dot(kcoords, jacobian)
cenx, ceny = gsobject._offset.x, gsobject._offset.y
cenx, ceny = offset.x, offset.y

#
# flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) )
Expand Down
39 changes: 39 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import jax


Expand Down Expand Up @@ -61,3 +63,40 @@ def ensure_hashable(v):
return v
else:
return v


@partial(jax.jit, static_argnames=("niter",))
def bisect_for_root(func, low, high, niter=75):
def _func(i, args):
func, low, flow, high, fhigh = args
mid = (low + high) / 2.0
fmid = func(mid)
return jax.lax.cond(
fmid * fhigh < 0,
lambda func, low, flow, mid, fmid, high, fhigh: (
func,
mid,
fmid,
high,
fhigh,
),
lambda func, low, flow, mid, fmid, high, fhigh: (
func,
low,
flow,
mid,
fmid,
),
func,
low,
flow,
mid,
fmid,
high,
fhigh,
)

flow = func(low)
fhigh = func(high)
args = (func, low, flow, high, fhigh)
return jax.lax.fori_loop(0, niter, _func, args)[-2]
78 changes: 11 additions & 67 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import jax.scipy as jsc
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps
from jax.tree_util import Partial as partial
Expand All @@ -10,8 +9,9 @@
from jax_galsim.core.bessel import j0
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
from jax_galsim.gsobject import GSObject
from jax_galsim.position import PositionD


@jax.jit
Expand Down Expand Up @@ -269,75 +269,21 @@ def __str__(self):
s += ")"
return s

@property
def _maxk_untrunc(self):
"""untruncated Moffat maxK
The 2D Fourier Transform of f(r)=C (1+(r/rd)^2)^(-beta) leads
C rd^2 = Flux (beta-1)/pi (no truc)
and
f(k) = C rd^2 int_0^infty (1+x^2)^(-beta) J_0(krd x) x dx
= 2 F (k rd /2)^(\beta-1) K[beta-1, k rd]/Gamma[beta-1]
with k->infty asymptotic behavior
f(k)/F \approx sqrt(pi)/Gamma(beta-1) e^(-k') (k'/2)^(beta -3/2) with k' = k rd
So we solve f(maxk)/F = thr (aka maxk_threshold in gsparams.py)
leading to the iterative search of
let alpha = -log(thr Gamma(beta-1)/sqrt(pi))
k = (\beta -3/2)log(k/2) + alpha
starting with k = alpha
note : in the code "alternative code" is related to issue #1208 in GalSim github
"""

def body(i, val):
kcur, alpha = val
knew = (self.beta - 0.5) * jnp.log(kcur) + alpha
# knew = (self.beta -1.5)* jnp.log(kcur/2) + alpha # alternative code
return knew, alpha

# alpha = -jnp.log(self.gsparams.maxk_threshold
# * jnp.exp(jsc.special.gammaln(self._beta-1))/jnp.sqrt(jnp.pi) ) # alternative code

alpha = -jnp.log(
self.gsparams.maxk_threshold
* jnp.power(2.0, self.beta - 0.5)
* jnp.exp(jsc.special.gammaln(self.beta - 1))
/ (2 * jnp.sqrt(jnp.pi))
)

val_init = (
alpha,
alpha,
)
val = jax.lax.fori_loop(0, 5, body, val_init)
maxk, alpha = val
return maxk / self._r0

@property
def _prefactor(self):
return 2.0 * (self.beta - 1.0) / (self._fluxFactor)

@property
def _maxk_trunc(self):
"""truncated Moffat maxK"""
# a for gaussian profile... this is f(k_max)/Flux = maxk_threshold
maxk_val = self.gsparams.maxk_threshold
dk = self.gsparams.table_spacing * jnp.sqrt(
jnp.sqrt(self.gsparams.kvalue_accuracy / 10.0)
@jax.jit
def _maxk_func(self, k):
return (
jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux)
- self.gsparams.maxk_threshold
)
# 50 is a max (GalSim) but it may be lowered if necessary
ki = jnp.arange(0.0, 50.0, dk)
quad = ClenshawCurtisQuad.init(150)
g = partial(_xMoffatIntegrant, beta=self.beta, rmax=self._maxRrD, quad=quad)
fki_1 = jax.jit(jax.vmap(g))(ki)
fki = fki_1 * self._prefactor
cond = jnp.abs(fki) > maxk_val
maxk = ki[cond][-1]
return maxk / self._r0

@property
@jax.jit
def _maxk(self):
return jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc)
return bisect_for_root(partial(self._maxk_func), 0.0, 1e5, niter=75)

@property
def _stepk_lowbeta(self):
Expand All @@ -353,12 +299,10 @@ def _stepk_highbeta(self):
jnp.power(self.gsparams.folding_threshold, 0.5 / (1.0 - self.beta))
* self._r0
)
if R > self._maxR:
R = self._maxR
R = jnp.minimum(R, self._maxR)
# at least R should be 5 HLR
R5hlr = self.gsparams.stepk_minimum_hlr * self.half_light_radius
if R < R5hlr:
R = R5hlr
R = jnp.maximum(R, R5hlr)
return jnp.pi / R

@property
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def _drawKImage(self, image, jac=None):
image = self._original._drawKImage(image, jac1)

_jac = jnp.eye(2) if jac is None else jac
image = apply_kImage_phases(self, image, _jac)
image = apply_kImage_phases(self.offset, image, _jac)

image = image * self._flux_scaling
return image
Expand Down
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,15 @@ skip = [
"tests/Galsim/",
"tests/Coord/",
]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = [
"tests/GalSim/tests/",
"tests/jax",
"tests/Coord/tests/",
]
filterwarnings = [
"ignore::DeprecationWarning",
]
10 changes: 0 additions & 10 deletions pytest.ini

This file was deleted.

2 changes: 1 addition & 1 deletion tests/GalSim
Submodule GalSim updated 1139 files
26 changes: 24 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
import inspect
import os
import sys
from functools import lru_cache
from unittest.mock import patch

import galsim
import pytest
import yaml

# Define the accuracy for running the tests
from jax.config import config

import jax_galsim

config.update("jax_enable_x64", True)

# Identify the path to this current file
test_directory = os.path.dirname(os.path.abspath(__file__))


# Loading which tests to run
with open(os.path.join(test_directory, "galsim_tests_config.yaml"), "r") as f:
test_config = yaml.safe_load(f)

# we need to patch the galsim utilities check_pickle function
# to use jax_galsim. it has an import inside a function so
# we patch sys.modules.
# see https://stackoverflow.com/questions/34213088/mocking-a-module-imported-inside-of-a-function
orig_check_pickle = galsim.utilities.check_pickle
orig_check_pickle.__globals__["BaseDeviate"] = jax_galsim.BaseDeviate


def _check_pickle(*args, **kwargs):
with patch.dict(sys.modules, {"galsim": jax_galsim}):
return orig_check_pickle(*args, **kwargs)


galsim.utilities.check_pickle = _check_pickle


def pytest_ignore_collect(collection_path, path, config):
"""This hook will skip collecting tests that are not
Expand Down Expand Up @@ -104,7 +123,10 @@ def pytest_pycollect_makemodule(module_path, path, parent):
if (
callable(v)
and hasattr(v, "__globals__")
and inspect.getsourcefile(v).endswith("galsim_test_helpers.py")
and (
inspect.getsourcefile(v).endswith("galsim_test_helpers.py")
or inspect.getsourcefile(v).endswith("galsim/utilities.py")
)
and _infile("def " + k, inspect.getsourcefile(v))
and "galsim" in v.__globals__
):
Expand Down
Loading

0 comments on commit 13754a6

Please sign in to comment.