From e529ef7a4cd9061db501e55a1e4a01dd8ad90145 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 07:50:53 -0500 Subject: [PATCH 01/10] latest changes --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 66092bdf..ca90d938 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 66092bdf7215983bab4d2d953a700eb8a0ddcbe4 +Subproject commit ca90d938a3b16450b84452720068e0b558842bbb From 62d146e36e422bef9824925b5cc04bc378d5a8a9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 14:43:46 -0500 Subject: [PATCH 02/10] TST fix more tests --- jax_galsim/core/utils.py | 55 ++++++- jax_galsim/moffat.py | 78 ++------- jax_galsim/transform.py | 67 +++++--- jax_galsim/wcs.py | 96 ++++++++++- pyproject.toml | 12 ++ pytest.ini | 10 -- tests/GalSim | 2 +- tests/conftest.py | 31 +++- tests/galsim_tests_config.yaml | 2 + tests/jax/galsim/test_image_jax.py | 80 ++++----- tests/jax/galsim/test_shear_jax.py | 2 +- tests/jax/galsim/test_wcs_jax.py | 238 ++++++++++++++------------- tests/jax/test_moffat_comp_galsim.py | 40 +++++ 13 files changed, 445 insertions(+), 268 deletions(-) delete mode 100644 pytest.ini create mode 100644 tests/jax/test_moffat_comp_galsim.py diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f2e82df4..872912ec 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -1,10 +1,22 @@ +from functools import partial + import jax +def convert_to_float(x): + if isinstance(x, jax.Array): + if x.shape == (): + return x.item() + else: + return x[0].astype(float).item() + else: + return float(x) + + def cast_scalar_to_float(x): """Cast the input to a float. Works on python floats and jax arrays.""" - if isinstance(x, float): - return float(x) + if isinstance(x, jax.Array): + return x.astype(float) elif hasattr(x, "astype"): return x.astype(float) else: @@ -51,3 +63,42 @@ def ensure_hashable(v): return v else: return v + + +@partial(jax.jit, static_argnames=("niter",)) +def bisect_for_root(func, low, high, niter=75): + def _func(i, args): + func, low, flow, high, fhigh = args + mid = (low + high) / 2.0 + fmid = func(mid) + return jax.lax.cond( + fmid * fhigh < 0, + lambda func, low, flow, mid, fmid, high, fhigh: ( + func, + mid, + fmid, + high, + fhigh, + ), + lambda func, low, flow, mid, fmid, high, fhigh: ( + func, + low, + flow, + mid, + fmid, + ), + func, + low, + flow, + mid, + fmid, + high, + fhigh, + ) + + low = 0.0 + high = 1e5 + flow = func(low) + fhigh = func(high) + args = (func, low, flow, high, fhigh) + return jax.lax.fori_loop(0, niter, _func, args)[-2] diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 09da2579..f9ec4839 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,7 +1,6 @@ import galsim as _galsim import jax import jax.numpy as jnp -import jax.scipy as jsc import tensorflow_probability as tfp from jax._src.numpy.util import _wraps from jax.tree_util import Partial as partial @@ -10,8 +9,9 @@ from jax_galsim.core.bessel import j0 from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import bisect_for_root, ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.position import PositionD @jax.jit @@ -269,75 +269,21 @@ def __str__(self): s += ")" return s - @property - def _maxk_untrunc(self): - """untruncated Moffat maxK - - The 2D Fourier Transform of f(r)=C (1+(r/rd)^2)^(-beta) leads - C rd^2 = Flux (beta-1)/pi (no truc) - and - f(k) = C rd^2 int_0^infty (1+x^2)^(-beta) J_0(krd x) x dx - = 2 F (k rd /2)^(\beta-1) K[beta-1, k rd]/Gamma[beta-1] - with k->infty asymptotic behavior - f(k)/F \approx sqrt(pi)/Gamma(beta-1) e^(-k') (k'/2)^(beta -3/2) with k' = k rd - So we solve f(maxk)/F = thr (aka maxk_threshold in gsparams.py) - leading to the iterative search of - let alpha = -log(thr Gamma(beta-1)/sqrt(pi)) - k = (\beta -3/2)log(k/2) + alpha - starting with k = alpha - - note : in the code "alternative code" is related to issue #1208 in GalSim github - """ - - def body(i, val): - kcur, alpha = val - knew = (self.beta - 0.5) * jnp.log(kcur) + alpha - # knew = (self.beta -1.5)* jnp.log(kcur/2) + alpha # alternative code - return knew, alpha - - # alpha = -jnp.log(self.gsparams.maxk_threshold - # * jnp.exp(jsc.special.gammaln(self._beta-1))/jnp.sqrt(jnp.pi) ) # alternative code - - alpha = -jnp.log( - self.gsparams.maxk_threshold - * jnp.power(2.0, self.beta - 0.5) - * jnp.exp(jsc.special.gammaln(self.beta - 1)) - / (2 * jnp.sqrt(jnp.pi)) - ) - - val_init = ( - alpha, - alpha, - ) - val = jax.lax.fori_loop(0, 5, body, val_init) - maxk, alpha = val - return maxk / self._r0 - @property def _prefactor(self): return 2.0 * (self.beta - 1.0) / (self._fluxFactor) - @property - def _maxk_trunc(self): - """truncated Moffat maxK""" - # a for gaussian profile... this is f(k_max)/Flux = maxk_threshold - maxk_val = self.gsparams.maxk_threshold - dk = self.gsparams.table_spacing * jnp.sqrt( - jnp.sqrt(self.gsparams.kvalue_accuracy / 10.0) + @jax.jit + def _maxk_func(self, k): + return ( + jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux) + - self.gsparams.maxk_threshold ) - # 50 is a max (GalSim) but it may be lowered if necessary - ki = jnp.arange(0.0, 50.0, dk) - quad = ClenshawCurtisQuad.init(150) - g = partial(_xMoffatIntegrant, beta=self.beta, rmax=self._maxRrD, quad=quad) - fki_1 = jax.jit(jax.vmap(g))(ki) - fki = fki_1 * self._prefactor - cond = jnp.abs(fki) > maxk_val - maxk = ki[cond][-1] - return maxk / self._r0 @property + @jax.jit def _maxk(self): - return jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc) + return bisect_for_root(partial(self._maxk_func), 0.0, 1e5, niter=75) @property def _stepk_lowbeta(self): @@ -353,12 +299,10 @@ def _stepk_highbeta(self): jnp.power(self.gsparams.folding_threshold, 0.5 / (1.0 - self.beta)) * self._r0 ) - if R > self._maxR: - R = self._maxR + R = jnp.minimum(R, self._maxR) # at least R should be 5 HLR R5hlr = self.gsparams.stepk_minimum_hlr * self.half_light_radius - if R < R5hlr: - R = R5hlr + R = jnp.maximum(R, R5hlr) return jnp.pi / R @property diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 4dc2eac8..a0144018 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -47,34 +48,38 @@ def __init__( gsparams=None, propagate_gsparams=True, ): - self._offset = PositionD(offset) - self._flux_ratio = flux_ratio self._gsparams = GSParams.check(gsparams, obj.gsparams) self._propagate_gsparams = propagate_gsparams if self._propagate_gsparams: obj = obj.withGSParams(self._gsparams) self._params = { - "obj": obj, "jac": jac, - "offset": self._offset, - "flux_ratio": self._flux_ratio, + "offset": PositionD(offset), + "flux_ratio": flux_ratio, } if isinstance(obj, Transformation): # Combine the two affine transformations into one. dx, dy = self._fwd(obj.offset.x, obj.offset.y) - self._offset.x += dx - self._offset.y += dy + self._params["offset"].x += dx + self._params["offset"].y += dy self._params["jac"] = self._jac.dot(obj.jac) - self._flux_ratio *= obj._flux_ratio + self._params["flux_ratio"] *= obj._params["flux_ratio"] self._original = obj.original else: self._original = obj @property def _jac(self): - return jnp.asarray(self._params["jac"], dtype=float).reshape(2, 2) + jac = self._params["jac"] + jac = jax.lax.cond( + jac is not None, + lambda jac: jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)), + lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), + jac, + ) + return jnp.asarray(jac, dtype=float).reshape(2, 2) @property def original(self): @@ -89,17 +94,21 @@ def jac(self): @property def offset(self): """The offset of the transformation.""" - return self._offset + return self._params["offset"] @property def flux_ratio(self): """The flux ratio of the transformation.""" - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux(self): return self._flux_scaling * self._original.flux + @property + def _offset(self): + return self._params["offset"] + def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -110,13 +119,14 @@ def withGSParams(self, gsparams=None, **kwargs): """ if gsparams == self.gsparams: return self - from copy import copy - ret = copy(self) - ret._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + chld, aux = self.tree_flatten() + aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) if self._propagate_gsparams: - ret._original = self._original.withGSParams(ret._gsparams) - return ret + new_obj = chld[0].withGSParams(aux["gsparams"]) + chld = (new_obj,) + chld[1:] + + return self.tree_unflatten(aux, chld) def __eq__(self, other): return self is other or ( @@ -149,7 +159,7 @@ def __repr__(self): "propagate_gsparams=%r)" ) % ( self.original, - ensure_hashable(self._jac), + ensure_hashable(self._jac.ravel()), self.offset, ensure_hashable(self.flux_ratio), self.gsparams, @@ -221,11 +231,11 @@ def _invjac(self): # than flux_ratio, which is really an amplitude scaling. @property def _amp_scaling(self): - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux_scaling(self): - return jnp.abs(self._det) * self._flux_ratio + return jnp.abs(self._det) * self._params["flux_ratio"] def _fwd(self, x, y): res = jnp.dot(self._jac, jnp.array([x, y])) @@ -240,8 +250,8 @@ def _inv(self, x, y): return res[0], res[1] def _kfactor(self, kx, ky): - kx *= -1j * self._offset.x - ky *= -1j * self._offset.y + kx *= -1j * self.offset.x + ky *= -1j * self.offset.y kx += ky return self._flux_scaling * jnp.exp(kx) @@ -269,7 +279,7 @@ def _stepk(self): # stepk = Pi/R # R <- R + |shift| # stepk <- Pi/(Pi/stepk + |shift|) - dr = jnp.hypot(self._offset.x, self._offset.y) + dr = jnp.hypot(self.offset.x, self.offset.y) stepk = jnp.pi / (jnp.pi / stepk + dr) return stepk @@ -283,7 +293,7 @@ def _is_axisymmetric(self): self._original.is_axisymmetric and self._jac[0, 0] == self._jac[1, 1] and self._jac[0, 1] == -self._jac[1, 0] - and self._offset == PositionD(0.0, 0.0) + and self.offset == PositionD(0.0, 0.0) ) @property @@ -314,7 +324,7 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self._offset + pos -= self.offset inv_pos = PositionD(self._inv(pos.x, pos.y)) return self._original._xValue(inv_pos) * self._amp_scaling @@ -360,10 +370,10 @@ def _drawKImage(self, image, jac=None): return image def tree_flatten(self): - """This function flattens the GSObject into a list of children + """This function flattens the Transform into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.params,) + children = (self._original, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = { "gsparams": self.gsparams, @@ -371,6 +381,11 @@ def tree_flatten(self): } return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0], **(children[1]), **aux_data) + def _Transform( obj, diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index cb7320fa..3cebeeb0 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -3,7 +3,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import convert_to_float, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.shear import Shear @@ -18,6 +18,8 @@ def toWorld(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToWorld(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToWorld(*args, **kwargs) else: return self.posToWorld(*args, **kwargs) elif len(args) == 2: @@ -52,11 +54,19 @@ def profileToWorld( image_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToWorld) + def shearToWorld(self, image_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToWorld(image_shear) + @_wraps(_galsim.BaseWCS.toImage) def toImage(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToImage(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToImage(*args, **kwargs) else: return self.posToImage(*args, **kwargs) elif len(args) == 2: @@ -94,6 +104,12 @@ def profileToImage( world_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToImage) + def shearToImage(self, world_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToImage(world_shear) + @_wraps(_galsim.BaseWCS.local) def local(self, image_pos=None, world_pos=None, color=None): if color is None: @@ -622,6 +638,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # These are trivial for PixelScale. + return image_shear + + def _shearToImage(self, world_shear): + return world_shear + def _pixelArea(self): return self._scale**2 @@ -728,6 +751,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): * flux_ratio ) + def _shearToWorld(self, image_shear): + # This isn't worth customizing. Just use the jacobian. + return self._toJacobian()._shearToWorld(image_shear) + + def _shearToImage(self, world_shear): + return self._toJacobian()._shearToImage(world_shear) + def _pixelArea(self): return self._scale**2 @@ -752,6 +782,13 @@ def _toJacobian(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self._scale, self._shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("ShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return ShearWCS(self._scale, self._shear) @@ -846,6 +883,24 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # Code from https://github.com/rmjarvis/DESWL/blob/y3a1-v23/psf/run_piff.py#L691 + e1 = image_shear.e1 + e2 = image_shear.e2 + + M = jnp.array([[1 + e1, e2], [e2, 1 - e1]]) + J = self.getMatrix() + M = J.dot(M).dot(J.T) + + e1 = (M[0, 0] - M[1, 1]) / (M[0, 0] + M[1, 1]) + e2 = (2.0 * M[0, 1]) / (M[0, 0] + M[1, 1]) + + return Shear(e1=e1, e2=e2) + + def _shearToImage(self, world_shear): + # Same as above but inverse J matrix. + return self._inverse()._shearToWorld(world_shear) + def _pixelArea(self): return abs(self._det) @@ -1096,6 +1151,17 @@ def world_origin(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self.scale, self.shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("OffsetShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + header["GS_X0"] = (self.origin.x, "GalSim image origin x coordinate") + header["GS_Y0"] = (self.origin.y, "GalSim image origin y coordinate") + header["GS_U0"] = (self.world_origin.x, "GalSim world origin u coordinate") + header["GS_V0"] = (self.world_origin.y, "GalSim world origin v coordinate") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return OffsetShearWCS(self.scale, self.shear, self.origin, self.world_origin) @@ -1173,14 +1239,26 @@ def _writeHeader(self, header, bounds): def _writeLinearWCS(self, header, bounds): header["CTYPE1"] = ("LINEAR", "name of the world coordinate axis") header["CTYPE2"] = ("LINEAR", "name of the world coordinate axis") - header["CRVAL1"] = (self.u0, "world coordinate at reference pixel = u0") - header["CRVAL2"] = (self.v0, "world coordinate at reference pixel = v0") - header["CRPIX1"] = (self.x0, "image coordinate of reference pixel = x0") - header["CRPIX2"] = (self.y0, "image coordinate of reference pixel = y0") - header["CD1_1"] = (self.dudx, "CD1_1 = dudx") - header["CD1_2"] = (self.dudy, "CD1_2 = dudy") - header["CD2_1"] = (self.dvdx, "CD2_1 = dvdx") - header["CD2_2"] = (self.dvdy, "CD2_2 = dvdy") + header["CRVAL1"] = ( + convert_to_float(self.u0), + "world coordinate at reference pixel = u0", + ) + header["CRVAL2"] = ( + convert_to_float(self.v0), + "world coordinate at reference pixel = v0", + ) + header["CRPIX1"] = ( + convert_to_float(self.x0), + "image coordinate of reference pixel = x0", + ) + header["CRPIX2"] = ( + convert_to_float(self.y0), + "image coordinate of reference pixel = y0", + ) + header["CD1_1"] = (convert_to_float(self.dudx), "CD1_1 = dudx") + header["CD1_2"] = (convert_to_float(self.dudy), "CD1_2 = dudy") + header["CD2_1"] = (convert_to_float(self.dvdx), "CD2_1 = dvdx") + header["CD2_2"] = (convert_to_float(self.dvdy), "CD2_2 = dvdy") return header @staticmethod diff --git a/pyproject.toml b/pyproject.toml index c6fe79ab..3f456117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,3 +21,15 @@ skip = [ "tests/Galsim/", "tests/Coord/", ] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q" +testpaths = [ + "tests/GalSim/tests/", + "tests/jax", + "tests/Coord/tests/", +] +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index d85bef3a..00000000 --- a/pytest.ini +++ /dev/null @@ -1,10 +0,0 @@ -# pytest.ini -[pytest] -minversion = 6.0 -addopts = -ra -q -testpaths = - tests/GalSim/tests/ - tests/jax - tests/Coord/tests/ -filterwarnings = - ignore::DeprecationWarning diff --git a/tests/GalSim b/tests/GalSim index ca90d938..81509041 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit ca90d938a3b16450b84452720068e0b558842bbb +Subproject commit 815090419d343d0e840bbc53e79c7bc4469ec79d diff --git a/tests/conftest.py b/tests/conftest.py index 8095105e..dfdc96cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,6 +68,19 @@ def _infile(val, fname): return False +def _convert_galsim_to_jax_galsim(obj): + import galsim as _galsim # noqa: F401 + from numpy import array # noqa: F401 + + import jax_galsim as galsim # noqa: F401 + + if isinstance(obj, _galsim.GSObject): + ret_obj = eval(repr(obj)) + return ret_obj + else: + return obj + + def pytest_pycollect_makemodule(module_path, path, parent): """This hook is tasked with overriding the galsim import at the top of each test file. Replaces it by jax-galsim. @@ -91,7 +104,10 @@ def pytest_pycollect_makemodule(module_path, path, parent): if ( callable(v) and hasattr(v, "__globals__") - and inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + and ( + inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + or inspect.getsourcefile(v).endswith("galsim/utilities.py") + ) and _infile("def " + k, inspect.getsourcefile(v)) and "galsim" in v.__globals__ ): @@ -111,6 +127,19 @@ def pytest_pycollect_makemodule(module_path, path, parent): v.__globals__["coord"] = __import__("jax_galsim") v.__globals__["galsim"] = __import__("jax_galsim") + # the galsim WCS tests have some items that are galsim objects that need conversions + # to jax_galsim objects + if module.name.endswith("tests/GalSim/tests/test_wcs.py"): + for k, v in module.obj.__dict__.items(): + if isinstance(v, __import__("galsim").GSObject): + module.obj.__dict__[k] = _convert_galsim_to_jax_galsim(v) + elif isinstance(v, list): + module.obj.__dict__[k] = [ + _convert_galsim_to_jax_galsim(obj) for obj in v + ] + + module.obj._convert_galsim_to_jax_galsim = _convert_galsim_to_jax_galsim + return module diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index dd761691..25ab8225 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -58,3 +58,5 @@ allowed_failures: - "TypeError not raised by __mul__" - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'TanWCS'" + - "'Image' object has no attribute 'FindAdaptiveMom'" + - " module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index 945d2da2..c14e87ca 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -509,22 +509,22 @@ def test_Image_basic(): # ------------------------- # We will not be doing pickles # Check picklability - # do_pickle(im1) - # do_pickle(im1_view) - # do_pickle(im2) - # do_pickle(im2_view) - # do_pickle(im2_cview) - # do_pickle(im3_view) - # do_pickle(im4_view) + # check_pickle(im1) + # check_pickle(im1_view) + # check_pickle(im2) + # check_pickle(im2_view) + # check_pickle(im2_cview) + # check_pickle(im3_view) + # check_pickle(im4_view) # JAX specific modification # ------------------------- # We will not be doing pickles # Also check picklability of Bounds, Position here. - # do_pickle(galsim.PositionI(2,3)) - # do_pickle(galsim.PositionD(2.2,3.3)) - # do_pickle(galsim.BoundsI(2,3,7,8)) - # do_pickle(galsim.BoundsD(2.1, 4.3, 6.5, 9.1)) + # check_pickle(galsim.PositionI(2,3)) + # check_pickle(galsim.PositionD(2.2,3.3)) + # check_pickle(galsim.BoundsI(2,3,7,8)) + # check_pickle(galsim.BoundsD(2.1, 4.3, 6.5, 9.1)) @timer @@ -632,10 +632,10 @@ def test_undefined_image(): # JAX specific modification # ------------------------- # We will not be doing pickles - # do_pickle(im1.bounds) - # do_pickle(im1) - # do_pickle(im1.view()) - # do_pickle(im1.view(make_const=True)) + # check_pickle(im1.bounds) + # check_pickle(im1) + # check_pickle(im1.view()) + # check_pickle(im1.view(make_const=True)) @timer @@ -2908,7 +2908,7 @@ def test_Image_subImage(): # JAX specific modification # ------------------------- # We won't do any pickling - # do_pickle(image) + # check_pickle(image) assert_raises(TypeError, image.subImage, bounds=None) assert_raises(TypeError, image.subImage, bounds=galsim.BoundsD(0, 4, 0, 4)) @@ -3035,9 +3035,9 @@ def test_Image_resize(): im3_full.array, 23, err_msg="im3_full changed" ) - do_pickle(im1) - do_pickle(im2) - do_pickle(im3) + check_pickle(im1) + check_pickle(im2) + check_pickle(im3) assert_raises(TypeError, im1.resize, bounds=None) assert_raises(TypeError, im1.resize, bounds=galsim.BoundsD(0, 5, 0, 5)) @@ -3083,7 +3083,7 @@ def test_Image_resize(): # assert_raises(galsim.GalSimImmutableError, image.setZero) # assert_raises(galsim.GalSimImmutableError, image.invertSelf) -# do_pickle(image) +# check_pickle(image) @timer @@ -3164,7 +3164,7 @@ def test_Image_constructor(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(test_im) + # check_pickle(test_im) # Check that some invalid sets of construction args raise the appropriate errors # Invalid args @@ -3246,7 +3246,7 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(im) + # check_pickle(im) # Test view with no arguments imv = im.view() @@ -3262,8 +3262,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(im) - # do_pickle(imv) + # check_pickle(im) + # check_pickle(imv) # Test view with new origin imv = im.view(origin=(0, 0)) @@ -3288,8 +3288,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new center imv = im.view(center=(0, 0)) @@ -3316,8 +3316,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new scale imv = im.view(scale=0.17) @@ -3342,8 +3342,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new wcs imv = im.view(wcs=galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0)) @@ -3365,8 +3365,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Go back to original value for that pixel and make sure all are still equal to 17 im.setValue(11, 19, 17) @@ -3427,7 +3427,7 @@ def test_ne(): galsim.ImageD(array1, wcs=galsim.PixelScale(0.2)), galsim.ImageD(array1, xmin=2), ] - all_obj_diff(objs) + check_all_diff(objs) @timer @@ -3653,13 +3653,13 @@ def test_complex_image(): # ------------------------- # No picklibng for JAX images # Check picklability - # do_pickle(im1) - # do_pickle(im1_view) - # do_pickle(im1_cview) - # do_pickle(im2) - # do_pickle(im2_view) - # do_pickle(im3_view) - # do_pickle(im4_view) + # check_pickle(im1) + # check_pickle(im1_view) + # check_pickle(im1_cview) + # check_pickle(im2) + # check_pickle(im2_view) + # check_pickle(im3_view) + # check_pickle(im4_view) @timer diff --git a/tests/jax/galsim/test_shear_jax.py b/tests/jax/galsim/test_shear_jax.py index 55a4fde1..890baea7 100644 --- a/tests/jax/galsim/test_shear_jax.py +++ b/tests/jax/galsim/test_shear_jax.py @@ -242,7 +242,7 @@ def test_shear_initialization(): # JAX specific modification # ------------------------- # We don't allow jax objects to be pickled. - # do_pickle(s) + # check_pickle(s) # finally check some examples of invalid initializations for Shear assert_raises(TypeError, galsim.Shear, 0.3) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index 49f6d8d4..f2b7e791 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -5,9 +5,17 @@ import time import warnings -import galsim import numpy as np -from galsim_test_helpers import * +from galsim_test_helpers import ( + Profile, + assert_raises, + assert_warns, + check_pickle, + gsobject_compare, + timer, +) + +import jax_galsim as galsim # These positions will be used a few times below, so define them here. # One of the tests requires that the last pair are integers, so don't change that. @@ -476,7 +484,9 @@ def do_wcs_image(wcs, name, approx=False): # Use the "blank" image as our test image. It's not blank in the sense of having all # zeros. Rather, there are basically random values that we can use to test that # the shifted values are correct. And it is a conveniently small-ish, non-square image. - dir = "fits_files" + dir = os.path.join( + os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" + ) file_name = "blankimg.fits" im = galsim.fits.read(file_name, dir=dir) np.testing.assert_equal(im.origin.x, 1, "initial origin is not 1,1 as expected") @@ -801,7 +811,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): do_wcs_pos(wcs, ufunc, vfunc, name) # Check picklability - do_pickle(wcs) + check_pickle(wcs) # Test the transformation of a GSObject # These only work for local WCS projections! @@ -910,7 +920,7 @@ def do_jac_decomp(wcs, name): M = scale * S.dot(R).dot(F) J = wcs.getMatrix() - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( M, J, 8, "Decomposition was inconsistent with jacobian for " + name ) @@ -964,7 +974,6 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): wcs4 = wcs.local(wcs.origin, color=color) assert wcs != wcs4, name + " is not != wcs.local()" assert wcs4 != wcs, name + " is not != wcs.local() (reverse)" - world_origin = wcs.toWorld(wcs.origin, color=color) if wcs.isUniform(): if wcs.world_origin == galsim.PositionD(0, 0): wcs2 = wcs.local(wcs.origin, color=color).withOrigin(wcs.origin) @@ -1014,7 +1023,7 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) # The GSObject transformation tests are only valid for a local WCS. # But it should work for wcs.local() @@ -1028,8 +1037,8 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) for x0, y0, u0, v0 in zip(far_x_list, far_y_list, far_u_list, far_v_list): - local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 - local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 + local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 # noqa: E731 + local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 # noqa: E731 image_pos = galsim.PositionD(x0, y0) world_pos = galsim.PositionD(u0, v0) do_wcs_pos( @@ -1204,8 +1213,6 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): "shiftOrigin(new_origin) returned wrong world position", ) - world_origin = wcs.toWorld(wcs.origin) - full_im1 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), wcs=wcs) full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) @@ -1223,7 +1230,7 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) near_ra_list = [] near_dec_list = [] @@ -1521,8 +1528,8 @@ def test_pixelscale(): # assert_raises(TypeError, galsim.PixelScale, scale=scale, origin=galsim.PositionD(0, 0)) # assert_raises(TypeError, galsim.PixelScale, scale=scale, world_origin=galsim.PositionD(0, 0)) - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "PixelScale") @@ -1593,8 +1600,8 @@ def test_pixelscale(): assert wcs != wcs3b, "OffsetWCS is not != a different one (origin)" assert wcs != wcs3c, "OffsetWCS is not != a different one (world_origin)" - ufunc = lambda x, y: scale * (x - x0) - vfunc = lambda x, y: scale * (y - y0) + ufunc = lambda x, y: scale * (x - x0) # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 1") # Add a world origin offset @@ -1602,8 +1609,8 @@ def test_pixelscale(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, world_origin=world_origin) - ufunc = lambda x, y: scale * x + u0 - vfunc = lambda x, y: scale * y + v0 + ufunc = lambda x, y: scale * x + u0 # noqa: E731 + vfunc = lambda x, y: scale * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 2") # Add both kinds of offsets @@ -1614,8 +1621,8 @@ def test_pixelscale(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: scale * (x - x0) + u0 - vfunc = lambda x, y: scale * (y - y0) + v0 + ufunc = lambda x, y: scale * (x - x0) + u0 # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1657,8 +1664,8 @@ def test_shearwcs(): assert wcs != wcs3b, "ShearWCS is not != a different one (shear)" factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "ShearWCS") @@ -1743,8 +1750,12 @@ def test_shearwcs(): assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + ufunc = ( + lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + ) # noqa: E731 + vfunc = ( + lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + ) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") # Add a world origin offset @@ -1752,8 +1763,8 @@ def test_shearwcs(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 + ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 # noqa: E731 + vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 2") # Add both kinds of offsets @@ -1764,8 +1775,12 @@ def test_shearwcs(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 + ufunc = ( + lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 + ) # noqa: E731 + vfunc = ( + lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 + ) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1825,8 +1840,8 @@ def test_affinetransform(): assert wcs != wcs3c, "JacobianWCS is not != a different one (dvdx)" assert wcs != wcs3d, "JacobianWCS is not != a different one (dvdy)" - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 1") # Check the decomposition: @@ -1882,8 +1897,8 @@ def test_affinetransform(): assert wcs != wcs3e, "AffineTransform is not != a different one (origin)" assert wcs != wcs3f, "AffineTransform is not != a different one (world_origin)" - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 1") # Next one with a flip and significant rotation and a large (u,v) offset @@ -1893,8 +1908,8 @@ def test_affinetransform(): dvdy = 0.1409 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 2") # Check the decomposition: @@ -1906,8 +1921,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) ) - ufunc = lambda x, y: dudx * x + dudy * y + u0 - vfunc = lambda x, y: dvdx * x + dvdy * y + v0 + ufunc = lambda x, y: dudx * x + dudy * y + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 2") # Finally a really crazy one that isn't remotely regular @@ -1917,8 +1932,8 @@ def test_affinetransform(): dvdy = -0.3013 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "Jacobian 3") # Check the decomposition: @@ -1937,8 +1952,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, origin=origin, world_origin=world_origin ) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 3") # Check that using a wcs in the context of an image works correctly @@ -2008,8 +2023,8 @@ def test_uvfunction(): # First make some that are identical to simpler WCS classes: # 1. Like PixelScale scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like PixelScale", test_pickle=False) assert wcs.ufunc(2.9, 3.7) == ufunc(2.9, 3.7) @@ -2022,8 +2037,8 @@ def test_uvfunction(): assert not wcs.isCelestial() # Also check with inverse functions. - xfunc = lambda u, v: u / scale - yfunc = lambda u, v: v / scale + xfunc = lambda u, v: u / scale # noqa: E731 + yfunc = lambda u, v: v / scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like PixelScale with inverse", test_pickle=False @@ -2057,14 +2072,14 @@ def test_uvfunction(): g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like ShearWCS", test_pickle=False) # Also check with inverse functions. - xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor - yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor + xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor # noqa: E731 + yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like ShearWCS with inverse", test_pickle=False @@ -2076,8 +2091,8 @@ def test_uvfunction(): dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like AffineTransform", test_pickle=False @@ -2113,7 +2128,7 @@ def test_uvfunction(): uses_color=True, ) do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True + wcsc, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True ) # 4. Next some UVFunctions with non-trivial offsets @@ -2123,8 +2138,8 @@ def test_uvfunction(): v0 = -141.9 origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) - ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 wcs = galsim.UVFunction(ufunc2, vfunc2) do_nonlocal_wcs( wcs, ufunc2, vfunc2, "UVFunction with origins in funcs", test_pickle=False @@ -2197,8 +2212,8 @@ def test_uvfunction(): "UVFunction dvdy does not match expected value.", ) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic radial UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2209,8 +2224,8 @@ def test_uvfunction(): cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") wcs = galsim.UVFunction(cubic_u, cubic_v, origin=galsim.PositionD(x0, y0)) - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic object UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2218,8 +2233,8 @@ def test_uvfunction(): # 7. Test the UVFunction that is used in demo9 to confirm that I got the # inverse function correct! - ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) - vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) + ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 # w = 0.05 (r + 2.e-6 r^3) # 0 = r^3 + 5e5 r - 1e7 w # @@ -2231,7 +2246,7 @@ def test_uvfunction(): # ( 5 sqrt( w^2 + 5.e3/27 ) - 5 w )^1/3 ) import math - xfunc = lambda u, v: ( + xfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2244,7 +2259,7 @@ def test_uvfunction(): ) ) )(math.sqrt(u**2 + v**2)) - yfunc = lambda u, v: ( + yfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2281,19 +2296,23 @@ def test_uvfunction(): # This version doesn't work with numpy arrays because of the math functions. # This provides a test of that branch of the makeSkyImage function. - ufunc = lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - vfunc = lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ufunc = ( + lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ) # noqa: E731 + vfunc = ( + lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) do_wcs_image(wcs, "UVFunction_math") # 8. A non-trivial color example - ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y - vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y - xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( + ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y # noqa: E731 + vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y # noqa: E731 + xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) - yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( + yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) @@ -2326,10 +2345,10 @@ def test_uvfunction(): ) # 9. A non-trivial color example that fails for arrays - ufunc = lambda x, y, c: math.exp(c * x) - vfunc = lambda x, y, c: math.exp(c * y / 2) - xfunc = lambda u, v, c: math.log(u) / c - yfunc = lambda u, v, c: math.log(v) * 2 / c + ufunc = lambda x, y, c: math.exp(c * x) # noqa: E731 + vfunc = lambda x, y, c: math.exp(c * y / 2) # noqa: E731 + xfunc = lambda u, v, c: math.log(u) / c # noqa: E731 + yfunc = lambda u, v, c: math.log(v) * 2 / c # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) do_nonlocal_wcs( wcs, @@ -2341,20 +2360,20 @@ def test_uvfunction(): ) # 10. One with invalid functions, which raise errors. (Just for coverage really.) - ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) - vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6)) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(3, 3)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(6, 0)) # Repeat with color - ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) - vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6), color=0.2) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2), color=0.2) @@ -2369,47 +2388,45 @@ def test_radecfunction(): funcs = [] scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 funcs.append((ufunc, vfunc, "like PixelScale")) scale = 0.23 g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 funcs.append((ufunc, vfunc, "like ShearWCS")) dudx = 0.2342 dudy = 0.1432 dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 funcs.append((ufunc, vfunc, "like JacobianWCS")) x0 = 1.3 y0 = -0.9 u0 = 124.3 v0 = -141.9 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 funcs.append((ufunc, vfunc, "like AffineTransform")) funcs.append((radial_u, radial_v, "Cubic radial")) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic radial")) cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic object")) # The last one needs to not have a lambda, since we use it for the image test, which @@ -2434,7 +2451,7 @@ def test_radecfunction(): ) scale = galsim.arcsec / galsim.radians - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 ufunc(x, y) * scale, vfunc(x, y) * scale, projection="lambert" ) wcs2 = galsim.RaDecFunction(radec_func) @@ -2447,12 +2464,12 @@ def test_radecfunction(): # code does the right thing in that case too, since local and makeSkyImage # try the numpy option first and do something else if it fails. # This also tests the alternate initialization using separate ra_func, dec_fun. - ra_func = lambda x, y: center.deproject( + ra_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", ).ra.rad - dec_func = lambda x, y: center.deproject( + dec_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", @@ -2521,7 +2538,6 @@ def test_radecfunction(): image_pos = galsim.PositionD(x, y) world_pos1 = wcs1.toWorld(image_pos) world_pos2 = test_wcs.toWorld(image_pos) - origin = test_wcs.toWorld(galsim.PositionD(0.0, 0.0)) d3 = np.sqrt(world_pos1.x**2 + world_pos1.y**2) d4 = center.distanceTo(world_pos2) d4 = 2.0 * np.sin(d4 / 2) * galsim.radians / galsim.arcsec @@ -2712,7 +2728,7 @@ def test_radecfunction(): do_wcs_image(wcs3, "RaDecFunction") # One with invalid functions, which raise errors. (Just for coverage really.) - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 math.sqrt(x), math.sqrt(y), projection="lambert" ) wcs = galsim.RaDecFunction(radec_func) @@ -2780,8 +2796,8 @@ def test_astropywcs(): """Test the AstropyWCS class""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. # These all work, but it is quite slow, so only test a few of them for the regular unit tests. # (1.8 seconds for 4 tags.) @@ -3130,9 +3146,9 @@ def test_inverseab_convergence(): [0.0003767412741890354, 0.00019733136932198898], ] ), - coord.CelestialCoord( - coord.Angle(2.171481673601117, coord.radians), - coord.Angle(-0.47508762601580773, coord.radians), + galsim.CelestialCoord( + galsim.Angle(2.171481673601117, galsim.radians), + galsim.Angle(-0.47508762601580773, galsim.radians), ), None, np.array( @@ -3320,13 +3336,13 @@ def test_fitswcs(): # mostly just tests the basic interface of the FitsWCS function. test_tags = ["TAN", "TPV"] try: - import starlink.Ast + import starlink.Ast # noqa: F401 # Useful also to test one that GSFitsWCS doesn't work on. This works on Travis at # least, and helps to cover some of the FitsWCS functionality where the first try # isn't successful. test_tags.append("HPX") - except: + except Exception: pass dir = "fits_files" @@ -3361,7 +3377,7 @@ def test_fitswcs(): # We don't really have any accuracy checks here. This really just checks that the # read function doesn't raise an exception. hdu, hdu_list, fin = galsim.fits.readFile(file_name, dir) - affine = galsim.AffineTransform._readHeader(hdu.header) + galsim.AffineTransform._readHeader(hdu.header) galsim.fits.closeHDUList(hdu_list, fin) # This does support LINEAR WCS types. @@ -3419,7 +3435,7 @@ def check_sphere(ra1, dec1, ra2, dec2, atol=1): w = dsq >= 3.99 if np.any(w): cross = np.cross(np.array([x1, y1, z1])[w], np.array([x2, y2, z2])[w]) - crosssq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 + crossq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 dist[w] = np.pi - np.arcsin(np.sqrt(crossq)) dist = np.rad2deg(dist) * 3600 np.testing.assert_allclose(dist, 0.0, rtol=0.0, atol=atol) @@ -3448,7 +3464,7 @@ def test_fittedsipwcs(): "ZTF": (0.1, 0.1), } - dir = "fits_files" + dir = os.path.join(os.path.dirname(__file__), "..", "..", "GalSim/tests/fits_files") if __name__ == "__main__": test_tags = all_tags @@ -3964,7 +3980,7 @@ def test_int_args(): # is unnecessary. dir = "des_data" file_name = "DECam_00158414_01.fits.fz" - with profile(): + with Profile(): t0 = time.time() wcs = galsim.FitsWCS(file_name, dir=dir) t1 = time.time() @@ -4008,8 +4024,8 @@ def test_razero(): # do this. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. dir = "fits_files" # This file is based in sipsample.fits, but with the CRVAL1 changed to 0.002322805429 diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py new file mode 100644 index 00000000..60f8c831 --- /dev/null +++ b/tests/jax/test_moffat_comp_galsim.py @@ -0,0 +1,40 @@ +import galsim as _galsim +import jax_galsim as galsim +import numpy as np + + +def test_moffat_comp_galsim_maxk(): + psfs = [ + # Make sure to include all the specialized betas we have in C++ layer. + # The scale_radius and flux don't matter, but vary themm too. + # Note: We also specialize beta=1, but that seems to be impossible to realize, + # even when it is trunctatd. + galsim.Moffat(beta=1.5, scale_radius=1, flux=1), + galsim.Moffat(beta=1.5001, scale_radius=1, flux=1), + galsim.Moffat(beta=2, scale_radius=0.8, flux=23), + galsim.Moffat(beta=2.5, scale_radius=1.8e-3, flux=2), + galsim.Moffat(beta=3, scale_radius=1.8e3, flux=35), + galsim.Moffat(beta=3.5, scale_radius=1.3, flux=123), + galsim.Moffat(beta=4, scale_radius=4.9, flux=23), + galsim.Moffat(beta=1.22, scale_radius=23, flux=23), + galsim.Moffat(beta=3.6, scale_radius=2, flux=23), + galsim.Moffat(beta=12.9, scale_radius=5, flux=23), + galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), + galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), + galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), + ] + threshs = [1.e-3, 1.e-4, 0.03] + print('\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk') + for psf in psfs: + for thresh in threshs: + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat(beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, trunc=psf.trunc) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + + print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}') + np.testing.assert_allclose(psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) From 425f3c3c33796b6e3747fb2394243a18efa31910 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 14:46:37 -0500 Subject: [PATCH 03/10] STY please the flake8 --- tests/jax/galsim/test_wcs_jax.py | 24 ++++++++++---------- tests/jax/test_moffat_comp_galsim.py | 34 ++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index f2b7e791..d5ab3628 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -1750,12 +1750,12 @@ def test_shearwcs(): assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor - ) # noqa: E731 + ) do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") # Add a world origin offset @@ -1775,12 +1775,12 @@ def test_shearwcs(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 - ) # noqa: E731 + ) do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") # Check that using a wcs in the context of an image works correctly @@ -2296,12 +2296,12 @@ def test_uvfunction(): # This version doesn't work with numpy arrays because of the math functions. # This provides a test of that branch of the makeSkyImage function. - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) # noqa: E731 + ) wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) do_wcs_image(wcs, "UVFunction_math") diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 60f8c831..d4549420 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -1,7 +1,8 @@ import galsim as _galsim -import jax_galsim as galsim import numpy as np +import jax_galsim as galsim + def test_moffat_comp_galsim_maxk(): psfs = [ @@ -23,18 +24,33 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), ] - threshs = [1.e-3, 1.e-4, 0.03] - print('\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk') + threshs = [1.0e-3, 1.0e-4, 0.03] + print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") for psf in psfs: for thresh in threshs: psf = psf.withGSParams(maxk_threshold=thresh) - gpsf = _galsim.Moffat(beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, trunc=psf.trunc) + gpsf = _galsim.Moffat( + beta=psf.beta, + scale_radius=psf.scale_radius, + flux=psf.flux, + trunc=psf.trunc, + ) gpsf = gpsf.withGSParams(maxk_threshold=thresh) fk = psf.kValue(psf.maxk, 0).real / psf.flux - print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}') - np.testing.assert_allclose(psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5) + print( + f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" + ) + np.testing.assert_allclose( + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5 + ) np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) From bc7f3550d8ae942fcad0649d14de358261505e7e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 17 Oct 2023 14:47:18 -0500 Subject: [PATCH 04/10] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 872912ec..c41a8a40 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -96,8 +96,6 @@ def _func(i, args): fhigh, ) - low = 0.0 - high = 1e5 flow = func(low) fhigh = func(high) args = (func, low, flow, high, fhigh) From 851b4f74022e72586974968bf5cdfa408302c619 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 23:00:15 -0500 Subject: [PATCH 05/10] TST patch galsim in check_pickle --- tests/GalSim | 2 +- tests/conftest.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/GalSim b/tests/GalSim index 81509041..b018d57f 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 815090419d343d0e840bbc53e79c7bc4469ec79d +Subproject commit b018d57fba88eabbaacf40d34d3029a77e7071f2 diff --git a/tests/conftest.py b/tests/conftest.py index dfdc96cc..4d8bab9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,41 @@ import inspect import os +import sys from functools import lru_cache +from unittest.mock import patch +import galsim import pytest import yaml # Define the accuracy for running the tests from jax.config import config +import jax_galsim + config.update("jax_enable_x64", True) # Identify the path to this current file test_directory = os.path.dirname(os.path.abspath(__file__)) - # Loading which tests to run with open(os.path.join(test_directory, "galsim_tests_config.yaml"), "r") as f: test_config = yaml.safe_load(f) +# we need to patch the galsim utilities check_pickle function +# to use jax_galsim. it has an import inside a function so +# we patch sys.modules. +# see https://stackoverflow.com/questions/34213088/mocking-a-module-imported-inside-of-a-function +orig_check_pickle = galsim.utilities.check_pickle + + +def _check_pickle(*args, **kwargs): + with patch.dict(sys.modules, {"galsim": jax_galsim}): + return orig_check_pickle(*args, **kwargs) + + +galsim.utilities.check_pickle = _check_pickle + def pytest_ignore_collect(collection_path, path, config): """This hook will skip collecting tests that are not From a1d5d9baf70627271b2485e5c437987020b1a389 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 18 Oct 2023 11:02:17 -0500 Subject: [PATCH 06/10] Update jax_galsim/transform.py --- jax_galsim/transform.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 56f3e1ee..f86e7b3e 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -105,10 +105,6 @@ def flux_ratio(self): def _flux(self): return self._flux_scaling * self._original.flux - @property - def _offset(self): - return self._params["offset"] - def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams From 0e9a4b84892eeed982da448d88687ab859bfd984 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 11:04:52 -0500 Subject: [PATCH 07/10] TST fix tests --- tests/jax/galsim/test_random_jax.py | 56 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 151c11d4..879cdad2 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -3,7 +3,7 @@ import os import galsim from galsim.utilities import single_threaded -from galsim_test_helpers import timer, do_pickle # noqa: E402 +from galsim_test_helpers import timer, check_pickle # noqa: E402 precision = 10 # decimal point at which agreement is required for all double precision tests @@ -274,10 +274,10 @@ def test_uniform(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(u, lambda x: x.serialize()) - do_pickle(u, lambda x: (x(), x(), x(), x())) - do_pickle(u) - do_pickle(rng) + check_pickle(u, lambda x: x.serialize()) + check_pickle(u, lambda x: (x(), x(), x(), x())) + check_pickle(u) + check_pickle(rng) assert "UniformDeviate" in repr(u) assert "UniformDeviate" in str(u) assert isinstance(eval(repr(u)), galsim.UniformDeviate) @@ -495,9 +495,9 @@ def test_gaussian(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) - do_pickle(g, lambda x: (x(), x(), x(), x())) - do_pickle(g) + check_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) + check_pickle(g, lambda x: (x(), x(), x(), x())) + check_pickle(g) assert 'GaussianDeviate' in repr(g) assert 'GaussianDeviate' in str(g) assert isinstance(eval(repr(g)), galsim.GaussianDeviate) @@ -666,9 +666,9 @@ def test_binomial(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(b, lambda x: (x.serialize(), x.n, x.p)) - do_pickle(b, lambda x: (x(), x(), x(), x())) - do_pickle(b) + check_pickle(b, lambda x: (x.serialize(), x.n, x.p)) + check_pickle(b, lambda x: (x(), x(), x(), x())) + check_pickle(b) assert 'BinomialDeviate' in repr(b) assert 'BinomialDeviate' in str(b) assert isinstance(eval(repr(b)), galsim.BinomialDeviate) @@ -869,9 +869,9 @@ def test_poisson(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(p, lambda x: (x.serialize(), x.mean)) - do_pickle(p, lambda x: (x(), x(), x(), x())) - do_pickle(p) + check_pickle(p, lambda x: (x.serialize(), x.mean)) + check_pickle(p, lambda x: (x(), x(), x(), x())) + check_pickle(p) assert 'PoissonDeviate' in repr(p) assert 'PoissonDeviate' in str(p) assert isinstance(eval(repr(p)), galsim.PoissonDeviate) @@ -1000,7 +1000,7 @@ def test_poisson_zeromean(): p = galsim.PoissonDeviate(testseed, mean=0) p2 = p.duplicate() p3 = galsim.PoissonDeviate(p.serialize(), mean=0) - do_pickle(p) + check_pickle(p) # Test direct draws testResult = (p(), p(), p()) @@ -1184,9 +1184,9 @@ def test_weibull(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(w, lambda x: (x.serialize(), x.a, x.b)) - do_pickle(w, lambda x: (x(), x(), x(), x())) - do_pickle(w) + check_pickle(w, lambda x: (x.serialize(), x.a, x.b)) + check_pickle(w, lambda x: (x(), x(), x(), x())) + check_pickle(w) assert 'WeibullDeviate' in repr(w) assert 'WeibullDeviate' in str(w) assert isinstance(eval(repr(w)), galsim.WeibullDeviate) @@ -1337,9 +1337,9 @@ def test_gamma(): err_msg='Wrong gamma random number sequence from generate.') # Check picklability - do_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) - do_pickle(g, lambda x: (x(), x(), x(), x())) - do_pickle(g) + check_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) + check_pickle(g, lambda x: (x(), x(), x(), x())) + check_pickle(g) assert 'GammaDeviate' in repr(g) assert 'GammaDeviate' in str(g) assert isinstance(eval(repr(g)), galsim.GammaDeviate) @@ -1489,9 +1489,9 @@ def test_chi2(): err_msg='Wrong Chi^2 random number sequence from generate.') # Check picklability - do_pickle(c, lambda x: (x.serialize(), x.n)) - do_pickle(c, lambda x: (x(), x(), x(), x())) - do_pickle(c) + check_pickle(c, lambda x: (x.serialize(), x.n)) + check_pickle(c, lambda x: (x(), x(), x(), x())) + check_pickle(c) assert 'Chi2Deviate' in repr(c) assert 'Chi2Deviate' in str(c) assert isinstance(eval(repr(c)), galsim.Chi2Deviate) @@ -1710,8 +1710,8 @@ def test_chi2(): # np.testing.assert_array_equal(v1, v2) # # Check picklability -# do_pickle(d, lambda x: (x(), x(), x(), x())) -# do_pickle(d) +# check_pickle(d, lambda x: (x(), x(), x(), x())) +# check_pickle(d) # assert 'DistDeviate' in repr(d) # assert 'DistDeviate' in str(d) # assert isinstance(eval(repr(d)), galsim.DistDeviate) @@ -1877,8 +1877,8 @@ def test_chi2(): # err_msg='Two DistDeviates with near-flat probabilities generated different values.') # # Check picklability -# do_pickle(d, lambda x: (x(), x(), x(), x())) -# do_pickle(d) +# check_pickle(d, lambda x: (x(), x(), x(), x())) +# check_pickle(d) # assert 'DistDeviate' in repr(d) # assert 'DistDeviate' in str(d) # assert isinstance(eval(repr(d)), galsim.DistDeviate) From e18bb4734989b5cf2a937435ac674942d3a084bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 11:26:05 -0500 Subject: [PATCH 08/10] ENH change api to get tests to pass without using internal attributes --- jax_galsim/core/draw.py | 4 ++-- jax_galsim/transform.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 4ccf2789..a8edfe51 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -51,12 +51,12 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) -def apply_kImage_phases(gsobject, image, jacobian=jnp.eye(2)): +def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): # Create an array of coordinates kcoords = jnp.stack(image.get_pixel_centers(), axis=-1) kcoords = kcoords * image.scale # Scale by the image pixel scale kcoords = jnp.dot(kcoords, jacobian) - cenx, ceny = gsobject._offset.x, gsobject._offset.y + cenx, ceny = offset.x, offset.y # # flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) ) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index f86e7b3e..c64863f8 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -360,7 +360,7 @@ def _drawKImage(self, image, jac=None): image = self._original._drawKImage(image, jac1) _jac = jnp.eye(2) if jac is None else jac - image = apply_kImage_phases(self, image, _jac) + image = apply_kImage_phases(self.offset, image, _jac) image = image * self._flux_scaling return image From 20cd1439983fbce08c8713a3d5fd104657eb21c2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 12:34:26 -0500 Subject: [PATCH 09/10] TST make sure to patch BaseDeviate --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4d8bab9a..5db13ef4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ # we patch sys.modules. # see https://stackoverflow.com/questions/34213088/mocking-a-module-imported-inside-of-a-function orig_check_pickle = galsim.utilities.check_pickle +orig_check_pickle.__globals__["BaseDeviate"] = jax_galsim.BaseDeviate def _check_pickle(*args, **kwargs): From ab8589e440112c3b3fbbbb929cea7019950321e9 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 19 Oct 2023 06:43:45 -0500 Subject: [PATCH 10/10] Update python_package.yaml --- .github/workflows/python_package.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index b4b6ca67..dac1d792 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -2,9 +2,9 @@ name: Python package on: push: - branches: ["main"] + branches: + - main pull_request: - branches: ["main"] jobs: build: @@ -42,4 +42,4 @@ jobs: - name: Test with pytest run: | git submodule update --init --recursive - pytest + pytest --durations=0