From 35e8fb4a09446721752169e2c1fe00cd9f1dc8bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Oct 2023 16:50:11 -0400 Subject: [PATCH 01/33] WIP add random deviate classes --- .pre-commit-config.yaml | 6 +- jax_galsim/__init__.py | 31 +- jax_galsim/core/utils.py | 14 +- jax_galsim/errors.py | 19 + jax_galsim/random.py | 830 +++++++++++ jax_galsim/transform.py | 65 +- jax_galsim/utilities.py | 2 + jax_galsim/wcs.py | 96 +- tests/GalSim | 2 +- tests/conftest.py | 26 + tests/galsim_tests_config.yaml | 7 +- tests/jax/galsim/test_random_jax.py | 2019 +++++++++++++++++++++++++++ tests/jax/galsim/test_wcs_jax.py | 13 +- 13 files changed, 3062 insertions(+), 68 deletions(-) create mode 100644 jax_galsim/errors.py create mode 100644 jax_galsim/random.py create mode 100644 tests/jax/galsim/test_random_jax.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9b1101e..1d952b2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,16 +4,16 @@ repos: hooks: - id: black language: python - exclude: tests/GalSim/|tests/Coord/ + exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/ - repo: https://github.com/pycqa/flake8 rev: 6.1.0 hooks: - id: flake8 entry: pflake8 additional_dependencies: [pyproject-flake8] - exclude: tests/ + exclude: tests/Galsim|tests/Coord/|tests/jax/galsim/ - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort - exclude: tests/Galsim|tests/Coord/ + exclude: tests/Galsim|tests/Coord/|tests/jax/galsim/ diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 4a4d2ee8..e2d91e4d 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -1,23 +1,14 @@ -# Inherit all Exception and Warning classes from galsim -from galsim.errors import ( - GalSimBoundsError, - GalSimConfigError, - GalSimConfigValueError, - GalSimDeprecationWarning, - GalSimError, - GalSimFFTSizeError, - GalSimHSMError, - GalSimImmutableError, - GalSimIncompatibleValuesError, - GalSimIndexError, - GalSimKeyError, - GalSimNotImplementedError, - GalSimRangeError, - GalSimSEDError, - GalSimUndefinedBoundsError, - GalSimValueError, - GalSimWarning, -) +# Exception and Warning classes +from .errors import GalSimError, GalSimRangeError, GalSimValueError +from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError +from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError +from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError +from .errors import GalSimFFTSizeError +from .errors import GalSimConfigError, GalSimConfigValueError +from .errors import GalSimWarning, GalSimDeprecationWarning + +# noise +from .random import BaseDeviate, UniformDeviate # Basic building blocks from .bounds import Bounds, BoundsD, BoundsI diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f2e82df4..4753ff59 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -1,10 +1,20 @@ import jax +def convert_to_float(x): + if isinstance(x, jax.Array): + if x.shape == (): + return x.item() + else: + return x[0].astype(float).item() + else: + return float(x) + + def cast_scalar_to_float(x): """Cast the input to a float. Works on python floats and jax arrays.""" - if isinstance(x, float): - return float(x) + if isinstance(x, jax.Array): + return x.astype(float) elif hasattr(x, "astype"): return x.astype(float) else: diff --git a/jax_galsim/errors.py b/jax_galsim/errors.py new file mode 100644 index 00000000..10e81995 --- /dev/null +++ b/jax_galsim/errors.py @@ -0,0 +1,19 @@ +from galsim.errors import ( # noqa: F401 + GalSimBoundsError, + GalSimConfigError, + GalSimConfigValueError, + GalSimDeprecationWarning, + GalSimError, + GalSimFFTSizeError, + GalSimHSMError, + GalSimImmutableError, + GalSimIncompatibleValuesError, + GalSimIndexError, + GalSimKeyError, + GalSimNotImplementedError, + GalSimRangeError, + GalSimSEDError, + GalSimUndefinedBoundsError, + GalSimValueError, + GalSimWarning, +) diff --git a/jax_galsim/random.py b/jax_galsim/random.py new file mode 100644 index 00000000..5755f333 --- /dev/null +++ b/jax_galsim/random.py @@ -0,0 +1,830 @@ +import secrets + +import galsim as _galsim + +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jax._src.numpy.util import _wraps +from jax.tree_util import register_pytree_node_class + +try: + from jax.extend.random import wrap_key_data +except ImportError: + from jax.random import wrap_key_data + +from jax_galsim.core.utils import ensure_hashable + + +LAX_FUNCTIONAL_RNG = ( + "The JAX version of the this class is purely function and thus cannot " + "share state with any other version of this class. Also no type checking is done on the inputs." +) + + +@_wraps( + _galsim.BaseDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class BaseDeviate: + # always the case for JAX + has_reliable_discard = True + generates_in_pairs = False + + def __init__(self, seed=None): + self.reset(seed=seed) + + @_wraps(_galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.") + def seed(self, seed=0): + self._seed(seed=seed) + + @_wraps(_galsim.BaseDeviate._seed) + def _seed(self, seed=0): + _initial_seed = seed or secrets.randbelow(2**31) + self._key = jrandom.PRNGKey(_initial_seed) + + @_wraps( + _galsim.BaseDeviate.reset, + lax_description=( + "The JAX version of this method does no type checking. Also, the JAX version of this " + "class cannot be linked to another JAX version of this class so ``reset`` is equivalent " + "to ``seed``. If another ``BaseDeviate`` is supplied, that deviates current state is used." + ) + ) + def reset(self, seed=None): + if isinstance(seed, BaseDeviate): + self._reset(seed) + elif isinstance(seed, jax.Array): + self._key = wrap_key_data(seed) + elif isinstance(seed, str): + self._key = wrap_key_data(jnp.array(eval(seed), dtype=jnp.uint32)) + elif isinstance(seed, tuple): + self._key = wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) + else: + self._seed(seed=seed) + + @_wraps(_galsim.BaseDeviate._reset) + def _reset(self, rng): + self._key = rng._key + + @property + @_wraps(_galsim.BaseDeviate.np) + def np(self): + raise NotImplementedError("The JAX galsim.BaseDeviate does not support being used as a numpy PRNG.") + + @_wraps(_galsim.BaseDeviate.as_numpy_generator) + def as_numpy_generator(self): + raise NotImplementedError("The JAX galsim.BaseDeviate does not support being used as a numpy PRNG.") + + @_wraps(_galsim.BaseDeviate.duplicate) + def duplicate(self): + ret = BaseDeviate.__new__(self.__class__) + ret._key = self._key + return ret + + def __copy__(self): + return self.duplicate() + + @_wraps(_galsim.BaseDeviate.clearCache, lax_description="This method is a no-op for the JAX version of this class.") + def clearCache(self): + pass + + @_wraps( + _galsim.BaseDeviate.discard, + lax_description=( + "The JAX version of this class has reliable discarding and uses one key per value " + "so it never generates in pairs. Thus this method will never raise an error." + ) + ) + def discard(self, n, suppress_warnings=False): + def _discard(i, key): + key, subkey = jrandom.split(key) + return key + + self._key = jax.lax.fori_loop(0, n, _discard, self._key) + + @_wraps( + _galsim.BaseDeviate.raw, + lax_description=( + "The JAX version of this class does not use the raw value to " + "generate the next value of a specific kind." + ), + ) + def raw(self): + self._key, subkey = jrandom.split(self._key) + return jrandom.bits(subkey, dtype=jnp.uint32) + + @_wraps( + _galsim.BaseDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ) + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array) + return array + + @_wraps( + _galsim.BaseDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ) + ) + def add_generate(self, array): + return self.generate(array) + array + + def __call__(self): + self._key, val = self.__class__._generate_one(self._key, None) + return val + + def __eq__(self, other): + return ( + self is other + or ( + isinstance(other, self.__class__) + and jnp.array_equal(jrandom.key_data(self._key), jrandom.key_data(other._key)) + ) + ) + + def __ne__(self, other): + return not self.__eq__(other) + + __hash__ = None + + def serialize(self): + return repr(ensure_hashable(jrandom.key_data(self._key))) + + def __repr__(self): + return "galsim.%s(%r) " % ( + self.__class__.__name__, + ensure_hashable(jrandom.key_data(self._key)), + ) + + def __str__(self): + return self.__repr__() + + def tree_flatten(self): + """This function flattens the PRNG into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + children = (jrandom.key_data(self._key),) + # Define auxiliary static data that doesn’t need to be traced + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0]) + + +@_wraps( + _galsim.UniformDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class UniformDeviate(BaseDeviate): + def _generate(key, array): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan(UniformDeviate._generate_one, key, None, length=array.ravel().shape[0]) + return key, res.reshape(array.shape) + + @jax.jit + def _generate_one(key, x): + _key, subkey = jrandom.split(key) + return _key, jrandom.uniform(subkey, dtype=float) + + +# class GaussianDeviate(BaseDeviate): +# """Pseudo-random number generator with Gaussian distribution. + +# See http://en.wikipedia.org/wiki/Gaussian_distribution for further details. + +# Successive calls to ``g()`` generate pseudo-random values distributed according to a Gaussian +# distribution with the provided ``mean``, ``sigma``:: + +# >>> g = galsim.GaussianDeviate(31415926) +# >>> g() +# 0.5533754000847082 +# >>> g() +# 1.0218588970190354 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# mean: Mean of Gaussian distribution. [default: 0.] +# sigma: Sigma of Gaussian distribution. [default: 1.; Must be > 0] +# """ +# def __init__(self, seed=None, mean=0., sigma=1.): +# if sigma < 0.: +# raise GalSimRangeError("GaussianDeviate sigma must be > 0.", sigma, 0.) +# self._rng_type = _galsim.GaussianDeviateImpl +# self._rng_args = (float(mean), float(sigma)) +# self.reset(seed) + +# @property +# def mean(self): +# """The mean of the Gaussian distribution. +# """ +# return self._rng_args[0] + +# @property +# def sigma(self): +# """The sigma of the Gaussian distribution. +# """ +# return self._rng_args[1] + +# @property +# def generates_in_pairs(self): +# return True + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Gaussian deviate with the given mean and sigma. +# """ +# return self._rng.generate1() + +# def generate_from_variance(self, array): +# """Generate many Gaussian deviate values using the existing array values as the +# variance for each. +# """ +# array_1d = np.ascontiguousarray(array.ravel(), dtype=float) +# #assert(array_1d.strides[0] == array_1d.itemsize) +# _a = array_1d.__array_interface__['data'][0] +# self._rng.generate_from_variance(len(array_1d), _a) +# if array_1d.data != array.data: +# # array_1d is not a view into the original array. Need to copy back. +# np.copyto(array, array_1d.reshape(array.shape), casting='unsafe') + +# def __repr__(self): +# return 'galsim.GaussianDeviate(seed=%r, mean=%r, sigma=%r)'%( +# self._seed_repr(), self.mean, self.sigma) +# def __str__(self): +# return 'galsim.GaussianDeviate(mean=%r, sigma=%r)'%(self.mean, self.sigma) + + +# class BinomialDeviate(BaseDeviate): +# """Pseudo-random Binomial deviate for ``N`` trials each of probability ``p``. + +# ``N`` is number of 'coin flips,' ``p`` is probability of 'heads,' and each call returns an +# integer value where 0 <= value <= N gives the number of heads. See +# http://en.wikipedia.org/wiki/Binomial_distribution for more information. + +# Successive calls to ``b()`` generate pseudo-random integer values distributed according to a +# binomial distribution with the provided ``N``, ``p``:: + +# >>> b = galsim.BinomialDeviate(31415926, N=10, p=0.3) +# >>> b() +# 2 +# >>> b() +# 3 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# N: The number of 'coin flips' per trial. [default: 1; Must be > 0] +# p: The probability of success per coin flip. [default: 0.5; Must be > 0] +# """ +# def __init__(self, seed=None, N=1, p=0.5): +# self._rng_type = _galsim.BinomialDeviateImpl +# self._rng_args = (int(N), float(p)) +# self.reset(seed) + +# @property +# def n(self): +# """The number of 'coin flips'. +# """ +# return self._rng_args[0] + +# @property +# def p(self): +# """The probability of success per 'coin flip'. +# """ +# return self._rng_args[1] + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Binomial deviate with the given n and p. +# """ +# return self._rng.generate1() + +# def __repr__(self): +# return 'galsim.BinomialDeviate(seed=%r, N=%r, p=%r)'%(self._seed_repr(), self.n, self.p) +# def __str__(self): +# return 'galsim.BinomialDeviate(N=%r, p=%r)'%(self.n, self.p) + + +# class PoissonDeviate(BaseDeviate): +# """Pseudo-random Poisson deviate with specified ``mean``. + +# The input ``mean`` sets the mean and variance of the Poisson deviate. An integer deviate with +# this distribution is returned after each call. +# See http://en.wikipedia.org/wiki/Poisson_distribution for more details. + +# Successive calls to ``p()`` generate pseudo-random integer values distributed according to a +# Poisson distribution with the specified ``mean``:: + +# >>> p = galsim.PoissonDeviate(31415926, mean=100) +# >>> p() +# 94 +# >>> p() +# 106 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# mean: Mean of the distribution. [default: 1; Must be > 0] +# """ +# def __init__(self, seed=None, mean=1.): +# if mean < 0: +# raise GalSimValueError("PoissonDeviate is only defined for mean >= 0.", mean) +# self._rng_type = _galsim.PoissonDeviateImpl +# self._rng_args = (float(mean),) +# self.reset(seed) + +# @property +# def mean(self): +# """The mean of the distribution. +# """ +# return self._rng_args[0] + +# @property +# def has_reliable_discard(self): +# return False + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Poisson deviate with the given mean. +# """ +# return self._rng.generate1() + +# def generate_from_expectation(self, array): +# """Generate many Poisson deviate values using the existing array values as the +# expectation value (aka mean) for each. +# """ +# if np.any(array < 0): +# raise GalSimValueError("Expectation array may not have values < 0.", array) +# array_1d = np.ascontiguousarray(array.ravel(), dtype=float) +# #assert(array_1d.strides[0] == array_1d.itemsize) +# _a = array_1d.__array_interface__['data'][0] +# self._rng.generate_from_expectation(len(array_1d), _a) +# if array_1d.data != array.data: +# # array_1d is not a view into the original array. Need to copy back. +# np.copyto(array, array_1d.reshape(array.shape), casting='unsafe') + +# def __repr__(self): +# return 'galsim.PoissonDeviate(seed=%r, mean=%r)'%(self._seed_repr(), self.mean) +# def __str__(self): +# return 'galsim.PoissonDeviate(mean=%r)'%(self.mean) + + +# class WeibullDeviate(BaseDeviate): +# """Pseudo-random Weibull-distributed deviate for shape parameter ``a`` and scale parameter ``b``. + +# The Weibull distribution is related to a number of other probability distributions; in +# particular, it interpolates between the exponential distribution (a=1) and the Rayleigh +# distribution (a=2). +# See http://en.wikipedia.org/wiki/Weibull_distribution (a=k and b=lambda in the notation adopted +# in the Wikipedia article) for more details. The Weibull distribution is real valued and +# produces deviates >= 0. + +# Successive calls to ``w()`` generate pseudo-random values distributed according to a Weibull +# distribution with the specified shape and scale parameters ``a`` and ``b``:: + +# >>> w = galsim.WeibullDeviate(31415926, a=1.3, b=4) +# >>> w() +# 1.1038481241018219 +# >>> w() +# 2.957052966368049 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# a: Shape parameter of the distribution. [default: 1; Must be > 0] +# b: Scale parameter of the distribution. [default: 1; Must be > 0] +# """ +# def __init__(self, seed=None, a=1., b=1.): +# self._rng_type = _galsim.WeibullDeviateImpl +# self._rng_args = (float(a), float(b)) +# self.reset(seed) + +# @property +# def a(self): +# """The shape parameter, a. +# """ +# return self._rng_args[0] + +# @property +# def b(self): +# """The scale parameter, b. +# """ +# return self._rng_args[1] + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Weibull-distributed deviate with the given shape parameters a and b. +# """ +# return self._rng.generate1() + +# def __repr__(self): +# return 'galsim.WeibullDeviate(seed=%r, a=%r, b=%r)'%(self._seed_repr(), self.a, self.b) +# def __str__(self): +# return 'galsim.WeibullDeviate(a=%r, b=%r)'%(self.a, self.b) + + +# class GammaDeviate(BaseDeviate): +# """A Gamma-distributed deviate with shape parameter ``k`` and scale parameter ``theta``. +# See http://en.wikipedia.org/wiki/Gamma_distribution. +# (Note: we use the k, theta notation. If you prefer alpha, beta, use k=alpha, theta=1/beta.) +# The Gamma distribution is a real valued distribution producing deviates >= 0. + +# Successive calls to ``g()`` generate pseudo-random values distributed according to a gamma +# distribution with the specified shape and scale parameters ``k`` and ``theta``:: + +# >>> gam = galsim.GammaDeviate(31415926, k=1, theta=2) +# >>> gam() +# 0.37508882726316 +# >>> gam() +# 1.3504199388358704 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# k: Shape parameter of the distribution. [default: 1; Must be > 0] +# theta: Scale parameter of the distribution. [default: 1; Must be > 0] +# """ +# def __init__(self, seed=None, k=1., theta=1.): +# self._rng_type = _galsim.GammaDeviateImpl +# self._rng_args = (float(k), float(theta)) +# self.reset(seed) + +# @property +# def k(self): +# """The shape parameter, k. +# """ +# return self._rng_args[0] + +# @property +# def theta(self): +# """The scale parameter, theta. +# """ +# return self._rng_args[1] + +# @property +# def has_reliable_discard(self): +# return False + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Gamma-distributed deviate with the given k and theta. +# """ +# return self._rng.generate1() + +# def __repr__(self): +# return 'galsim.GammaDeviate(seed=%r, k=%r, theta=%r)'%( +# self._seed_repr(), self.k, self.theta) +# def __str__(self): +# return 'galsim.GammaDeviate(k=%r, theta=%r)'%(self.k, self.theta) + + +# class Chi2Deviate(BaseDeviate): +# """Pseudo-random Chi^2-distributed deviate for degrees-of-freedom parameter ``n``. + +# See http://en.wikipedia.org/wiki/Chi-squared_distribution (note that k=n in the notation +# adopted in the Boost.Random routine called by this class). The Chi^2 distribution is a +# real-valued distribution producing deviates >= 0. + +# Successive calls to ``chi2()`` generate pseudo-random values distributed according to a +# chi-square distribution with the specified degrees of freedom, ``n``:: + +# >>> chi2 = galsim.Chi2Deviate(31415926, n=7) +# >>> chi2() +# 7.9182211987712385 +# >>> chi2() +# 6.644121724269535 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# n: Number of degrees of freedom for the output distribution. [default: 1; +# Must be > 0] +# """ +# def __init__(self, seed=None, n=1.): +# self._rng_type = _galsim.Chi2DeviateImpl +# self._rng_args = (float(n),) +# self.reset(seed) + +# @property +# def n(self): +# """The number of degrees of freedom. +# """ +# return self._rng_args[0] + +# @property +# def has_reliable_discard(self): +# return False + +# def __call__(self): +# """Draw a new random number from the distribution. + +# Returns a Chi2-distributed deviate with the given number of degrees of freedom. +# """ +# return self._rng.generate1() + +# def __repr__(self): +# return 'galsim.Chi2Deviate(seed=%r, n=%r)'%(self._seed_repr(), self.n) +# def __str__(self): +# return 'galsim.Chi2Deviate(n=%r)'%(self.n) + + +# class DistDeviate(BaseDeviate): +# """A class to draw random numbers from a user-defined probability distribution. + +# DistDeviate is a `BaseDeviate` class that can be used to draw from an arbitrary probability +# distribution. The probability distribution passed to DistDeviate can be given one of three +# ways: as the name of a file containing a 2d ASCII array of x and P(x), as a `LookupTable` +# mapping x to P(x), or as a callable function. + +# Once given a probability, DistDeviate creates a table of the cumulative probability and draws +# from it using a `UniformDeviate`. The precision of its outputs can be controlled with the +# keyword ``npoints``, which sets the number of points DistDeviate creates for its internal table +# of CDF(x). To prevent errors due to non-monotonicity, the interpolant for this internal table +# is always linear. + +# Two keywords, ``x_min`` and ``x_max``, define the support of the function. They must be passed +# if a callable function is given to DistDeviate, unless the function is a `LookupTable`, which +# has its own defined endpoints. If a filename or `LookupTable` is passed to DistDeviate, the +# use of ``x_min`` or ``x_max`` will result in an error. + +# If given a table in a file, DistDeviate will construct an interpolated `LookupTable` to obtain +# more finely gridded probabilities for generating the cumulative probability table. The default +# ``interpolant`` is linear, but any interpolant understood by `LookupTable` may be used. We +# caution against the use of splines because they can cause non-monotonic behavior. Passing the +# ``interpolant`` keyword next to anything but a table in a file will result in an error. + +# **Examples**: + +# Some sample initialization calls:: + +# >>> d = galsim.DistDeviate(function=f, x_min=x_min, x_max=x_max) + +# Initializes d to be a DistDeviate instance with a distribution given by the callable function +# ``f(x)`` from ``x=x_min`` to ``x=x_max`` and seeds the PRNG using current time:: + +# >>> d = galsim.DistDeviate(1062533, function=file_name, interpolant='floor') + +# Initializes d to be a DistDeviate instance with a distribution given by the data in file +# ``file_name``, which must be a 2-column ASCII table, and seeds the PRNG using the integer +# seed 1062533. It generates probabilities from ``file_name`` using the interpolant 'floor':: + +# >>> d = galsim.DistDeviate(rng, function=galsim.LookupTable(x,p)) + +# Initializes d to be a DistDeviate instance with a distribution given by P(x), defined as two +# arrays ``x`` and ``p`` which are used to make a callable `LookupTable`, and links the +# DistDeviate PRNG to the already-existing random number generator ``rng``. + +# Successive calls to ``d()`` generate pseudo-random values with the given probability +# distribution:: + +# >>> d = galsim.DistDeviate(31415926, function=lambda x: 1-abs(x), x_min=-1, x_max=1) +# >>> d() +# -0.4151921102709466 +# >>> d() +# -0.00909781188974034 + +# Parameters: +# seed: Something that can seed a `BaseDeviate`: an integer seed or another +# `BaseDeviate`. Using 0 means to generate a seed from the system. +# [default: None] +# function: A callable function giving a probability distribution or the name of a +# file containing a probability distribution as a 2-column ASCII table. +# [required] +# x_min: The minimum desired return value (required for non-`LookupTable` +# callable functions; will raise an error if not passed in that case, or if +# passed in any other case) [default: None] +# x_max: The maximum desired return value (required for non-`LookupTable` +# callable functions; will raise an error if not passed in that case, or if +# passed in any other case) [default: None] +# interpolant: Type of interpolation used for interpolating a file (causes an error if +# passed alongside a callable function). Options are given in the +# documentation for `LookupTable`. [default: 'linear'] +# npoints: Number of points DistDeviate should create for its internal interpolation +# tables. [default: 256, unless the function is a non-log `LookupTable`, in +# which case it uses the table's x values] +# """ +# def __init__(self, seed=None, function=None, x_min=None, +# x_max=None, interpolant=None, npoints=None): +# from .table import LookupTable +# from . import utilities +# from . import integ + +# # Set up the PRNG +# self._rng_type = _galsim.UniformDeviateImpl +# self._rng_args = () +# self.reset(seed) + +# # Basic input checking and setups +# if function is None: +# raise TypeError('You must pass a function to DistDeviate!') + +# self._interpolant = interpolant +# self._npoints = npoints +# self._xmin = x_min +# self._xmax = x_max + +# # Figure out if a string is a filename or something we should be using in an eval call +# if isinstance(function, str): +# self._function = function # Save the inputs to be used in repr +# import os.path +# if os.path.isfile(function): +# if interpolant is None: +# interpolant='linear' +# if x_min or x_max: +# raise GalSimIncompatibleValuesError( +# "Cannot pass x_min or x_max with a filename argument", +# function=function, x_min=x_min, x_max=x_max) +# function = LookupTable.from_file(function, interpolant=interpolant) +# x_min = function.x_min +# x_max = function.x_max +# else: +# try: +# function = utilities.math_eval('lambda x : ' + function) +# if x_min is not None: # is not None in case x_min=0. +# function(x_min) +# else: +# # Somebody would be silly to pass a string for evaluation without x_min, +# # but we'd like to throw reasonable errors in that case anyway +# function(0.6) # A value unlikely to be a singular point of a function +# except Exception as e: +# raise GalSimValueError( +# "String function must either be a valid filename or something that " +# "can eval to a function of x.\n" +# "Caught error: {0}".format(e), self._function) +# else: +# # Check that the function is actually a function +# if not hasattr(function, '__call__'): +# raise TypeError('function must be a callable function or a string') +# if interpolant: +# raise GalSimIncompatibleValuesError( +# "Cannot provide an interpolant with a callable function argument", +# interpolant=interpolant, function=function) +# if isinstance(function, LookupTable): +# if x_min or x_max: +# raise GalSimIncompatibleValuesError( +# "Cannot provide x_min or x_max with a LookupTable function", +# function=function, x_min=x_min, x_max=x_max) +# x_min = function.x_min +# x_max = function.x_max +# else: +# if x_min is None or x_max is None: +# raise GalSimIncompatibleValuesError( +# "Must provide x_min and x_max when function argument is a regular " +# "python callable function", +# function=function, x_min=x_min, x_max=x_max) + +# self._function = function # Save the inputs to be used in repr + +# # Compute the probability distribution function, pdf(x) +# if (npoints is None and isinstance(function, LookupTable) and +# not function.x_log and not function.f_log): +# xarray = np.array(function.x, dtype=float) +# pdf = np.array(function.f, dtype=float) +# # Set up pdf, so cumsum basically does a cumulative trapz integral +# # On Python 3.4, doing pdf[1:] += pdf[:-1] the last value gets messed up. +# # Writing it this way works. (Maybe slightly slower though, so if we stop +# # supporting python 3.4, consider switching to the += version.) +# pdf[1:] = pdf[1:] + pdf[:-1] +# pdf[1:] *= np.diff(xarray) +# pdf[0] = 0. +# else: +# if npoints is None: npoints = 256 +# xarray = x_min+(1.*x_max-x_min)/(npoints-1)*np.array(range(npoints),float) +# # Integrate over the range of x in case the function is doing something weird here. +# pdf = [0.] + [integ.int1d(function, xarray[i], xarray[i+1]) +# for i in range(npoints - 1)] +# pdf = np.array(pdf) + +# # Check that the probability is nonnegative +# if not np.all(pdf >= 0.): +# raise GalSimValueError('Negative probability found in DistDeviate.',function) + +# # Compute the cumulative distribution function = int(pdf(x),x) +# cdf = np.cumsum(pdf) + +# # Quietly renormalize the probability if it wasn't already normalized +# totalprobability = cdf[-1] +# cdf /= totalprobability + +# self._inverse_cdf = LookupTable(cdf, xarray, interpolant='linear') +# self.x_min = x_min +# self.x_max = x_max + +# def val(self, p): +# r""" +# Return the value :math:`x` of the input function to `DistDeviate` such that ``p`` = +# :math:`F(x)`, where :math:`F` is the cumulattive probability distribution function: + +# .. math:: + +# F(x) = \int_{-\infty}^x \mathrm{pdf}(t) dt + +# This function is typically called by `__call__`, which generates a random p +# between 0 and 1 and calls ``self.val(p)``. + +# Parameters: +# p: The desired cumulative probabilty p. + +# Returns: +# the corresponding x such that :math:`p = F(x)`. +# """ +# if p<0 or p>1: +# raise GalSimRangeError('Invalid cumulative probability for DistDeviate', p, 0., 1.) +# return self._inverse_cdf(p) + +# def __call__(self): +# """Draw a new random number from the distribution. +# """ +# return self._inverse_cdf(self._rng.generate1()) + +# def generate(self, array): +# """Generate many pseudo-random values, filling in the values of a numpy array. +# """ +# p = np.empty_like(array) +# BaseDeviate.generate(self, p) # Fill with unform deviate values +# np.copyto(array, self._inverse_cdf(p)) # Convert from p -> x + +# def add_generate(self, array): +# """Generate many pseudo-random values, adding them to the values of a numpy array. +# """ +# p = np.empty_like(array) +# BaseDeviate.generate(self, p) +# array += self._inverse_cdf(p) + +# def __repr__(self): +# return ('galsim.DistDeviate(seed=%r, function=%r, x_min=%r, x_max=%r, interpolant=%r, ' +# 'npoints=%r)')%(self._seed_repr(), self._function, self._xmin, self._xmax, +# self._interpolant, self._npoints) +# def __str__(self): +# return 'galsim.DistDeviate(function="%s", x_min=%s, x_max=%s, interpolant=%s, npoints=%s)'%( +# self._function, self._xmin, self._xmax, self._interpolant, self._npoints) + +# def __eq__(self, other): +# return (self is other or +# (isinstance(other, DistDeviate) and +# self.serialize() == other.serialize() and +# self._function == other._function and +# self._xmin == other._xmin and +# self._xmax == other._xmax and +# self._interpolant == other._interpolant and +# self._npoints == other._npoints)) + + +# class GalSimBitGenerator(np.random.BitGenerator): +# """A numpy.random.BitGenerator that uses the GalSim C++-layer random number generator +# for the random bit generation. + +# Parameters: +# rng: The galsim.BaseDeviate object to use for the underlying bit generation. +# """ +# def __init__(self, rng): +# super().__init__(0) +# self.rng = rng +# self.rng._rng.setup_bitgen(self.capsule) + +# def permute(rng, *args): +# """Randomly permute one or more lists. + +# If more than one list is given, then all lists will have the same random permutation +# applied to it. + +# Parameters: +# rng: The random number generator to use. (This will be converted to a `UniformDeviate`.) +# args: Any number of lists to be permuted. +# """ +# from .random import UniformDeviate +# ud = UniformDeviate(rng) +# if len(args) == 0: +# raise TypeError("permute called with no lists to permute") + +# # We use an algorithm called the Knuth shuffle, which is based on the Fisher-Yates shuffle. +# # See http://en.wikipedia.org/wiki/Fisher-Yates_shuffle for more information. +# n = len(args[0]) +# for i in range(n-1,1,-1): +# j = int((i+1) * ud()) +# if j == i+1: j = i # I'm not sure if this is possible, but just in case... +# for lst in args: +# lst[i], lst[j] = lst[j], lst[i] diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 4dc2eac8..34bd65c3 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -47,27 +48,29 @@ def __init__( gsparams=None, propagate_gsparams=True, ): - self._offset = PositionD(offset) - self._flux_ratio = flux_ratio self._gsparams = GSParams.check(gsparams, obj.gsparams) self._propagate_gsparams = propagate_gsparams if self._propagate_gsparams: obj = obj.withGSParams(self._gsparams) self._params = { - "obj": obj, - "jac": jac, - "offset": self._offset, - "flux_ratio": self._flux_ratio, + "jac": jax.lax.cond( + jac is not None, + lambda jac: jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)), + lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), + jac, + ), + "offset": PositionD(offset), + "flux_ratio": flux_ratio, } if isinstance(obj, Transformation): # Combine the two affine transformations into one. dx, dy = self._fwd(obj.offset.x, obj.offset.y) - self._offset.x += dx - self._offset.y += dy + self._params["offset"].x += dx + self._params["offset"].y += dy self._params["jac"] = self._jac.dot(obj.jac) - self._flux_ratio *= obj._flux_ratio + self._params["flux_ratio"] *= obj._params["flux_ratio"] self._original = obj.original else: self._original = obj @@ -89,17 +92,21 @@ def jac(self): @property def offset(self): """The offset of the transformation.""" - return self._offset + return self._params["offset"] @property def flux_ratio(self): """The flux ratio of the transformation.""" - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux(self): return self._flux_scaling * self._original.flux + @property + def _offset(self): + return self._params["offset"] + def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -110,13 +117,14 @@ def withGSParams(self, gsparams=None, **kwargs): """ if gsparams == self.gsparams: return self - from copy import copy - ret = copy(self) - ret._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + chld, aux = self.tree_flatten() + aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) if self._propagate_gsparams: - ret._original = self._original.withGSParams(ret._gsparams) - return ret + new_obj = chld[0].withGSParams(aux["gsparams"]) + chld = (new_obj,) + chld[1:] + + return self.tree_unflatten(aux, chld) def __eq__(self, other): return self is other or ( @@ -149,7 +157,7 @@ def __repr__(self): "propagate_gsparams=%r)" ) % ( self.original, - ensure_hashable(self._jac), + ensure_hashable(self._jac.ravel()), self.offset, ensure_hashable(self.flux_ratio), self.gsparams, @@ -221,11 +229,11 @@ def _invjac(self): # than flux_ratio, which is really an amplitude scaling. @property def _amp_scaling(self): - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux_scaling(self): - return jnp.abs(self._det) * self._flux_ratio + return jnp.abs(self._det) * self._params["flux_ratio"] def _fwd(self, x, y): res = jnp.dot(self._jac, jnp.array([x, y])) @@ -240,8 +248,8 @@ def _inv(self, x, y): return res[0], res[1] def _kfactor(self, kx, ky): - kx *= -1j * self._offset.x - ky *= -1j * self._offset.y + kx *= -1j * self.offset.x + ky *= -1j * self.offset.y kx += ky return self._flux_scaling * jnp.exp(kx) @@ -269,7 +277,7 @@ def _stepk(self): # stepk = Pi/R # R <- R + |shift| # stepk <- Pi/(Pi/stepk + |shift|) - dr = jnp.hypot(self._offset.x, self._offset.y) + dr = jnp.hypot(self.offset.x, self.offset.y) stepk = jnp.pi / (jnp.pi / stepk + dr) return stepk @@ -283,7 +291,7 @@ def _is_axisymmetric(self): self._original.is_axisymmetric and self._jac[0, 0] == self._jac[1, 1] and self._jac[0, 1] == -self._jac[1, 0] - and self._offset == PositionD(0.0, 0.0) + and self.offset == PositionD(0.0, 0.0) ) @property @@ -314,7 +322,7 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self._offset + pos -= self.offset inv_pos = PositionD(self._inv(pos.x, pos.y)) return self._original._xValue(inv_pos) * self._amp_scaling @@ -360,10 +368,10 @@ def _drawKImage(self, image, jac=None): return image def tree_flatten(self): - """This function flattens the GSObject into a list of children + """This function flattens the Transform into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.params,) + children = (self._original, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = { "gsparams": self.gsparams, @@ -371,6 +379,11 @@ def tree_flatten(self): } return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0], **(children[1]), **aux_data) + def _Transform( obj, diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index c8a8f155..b9401fb8 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -4,6 +4,8 @@ from jax_galsim.position import PositionD, PositionI +printoptions = _galsim.utilities.printoptions + @_wraps(_galsim.utilities.parse_pos_args) def parse_pos_args(args, kwargs, name1, name2, integer=False, others=[]): diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index cb7320fa..3cebeeb0 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -3,7 +3,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import convert_to_float, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.shear import Shear @@ -18,6 +18,8 @@ def toWorld(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToWorld(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToWorld(*args, **kwargs) else: return self.posToWorld(*args, **kwargs) elif len(args) == 2: @@ -52,11 +54,19 @@ def profileToWorld( image_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToWorld) + def shearToWorld(self, image_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToWorld(image_shear) + @_wraps(_galsim.BaseWCS.toImage) def toImage(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToImage(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToImage(*args, **kwargs) else: return self.posToImage(*args, **kwargs) elif len(args) == 2: @@ -94,6 +104,12 @@ def profileToImage( world_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToImage) + def shearToImage(self, world_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToImage(world_shear) + @_wraps(_galsim.BaseWCS.local) def local(self, image_pos=None, world_pos=None, color=None): if color is None: @@ -622,6 +638,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # These are trivial for PixelScale. + return image_shear + + def _shearToImage(self, world_shear): + return world_shear + def _pixelArea(self): return self._scale**2 @@ -728,6 +751,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): * flux_ratio ) + def _shearToWorld(self, image_shear): + # This isn't worth customizing. Just use the jacobian. + return self._toJacobian()._shearToWorld(image_shear) + + def _shearToImage(self, world_shear): + return self._toJacobian()._shearToImage(world_shear) + def _pixelArea(self): return self._scale**2 @@ -752,6 +782,13 @@ def _toJacobian(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self._scale, self._shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("ShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return ShearWCS(self._scale, self._shear) @@ -846,6 +883,24 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # Code from https://github.com/rmjarvis/DESWL/blob/y3a1-v23/psf/run_piff.py#L691 + e1 = image_shear.e1 + e2 = image_shear.e2 + + M = jnp.array([[1 + e1, e2], [e2, 1 - e1]]) + J = self.getMatrix() + M = J.dot(M).dot(J.T) + + e1 = (M[0, 0] - M[1, 1]) / (M[0, 0] + M[1, 1]) + e2 = (2.0 * M[0, 1]) / (M[0, 0] + M[1, 1]) + + return Shear(e1=e1, e2=e2) + + def _shearToImage(self, world_shear): + # Same as above but inverse J matrix. + return self._inverse()._shearToWorld(world_shear) + def _pixelArea(self): return abs(self._det) @@ -1096,6 +1151,17 @@ def world_origin(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self.scale, self.shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("OffsetShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + header["GS_X0"] = (self.origin.x, "GalSim image origin x coordinate") + header["GS_Y0"] = (self.origin.y, "GalSim image origin y coordinate") + header["GS_U0"] = (self.world_origin.x, "GalSim world origin u coordinate") + header["GS_V0"] = (self.world_origin.y, "GalSim world origin v coordinate") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return OffsetShearWCS(self.scale, self.shear, self.origin, self.world_origin) @@ -1173,14 +1239,26 @@ def _writeHeader(self, header, bounds): def _writeLinearWCS(self, header, bounds): header["CTYPE1"] = ("LINEAR", "name of the world coordinate axis") header["CTYPE2"] = ("LINEAR", "name of the world coordinate axis") - header["CRVAL1"] = (self.u0, "world coordinate at reference pixel = u0") - header["CRVAL2"] = (self.v0, "world coordinate at reference pixel = v0") - header["CRPIX1"] = (self.x0, "image coordinate of reference pixel = x0") - header["CRPIX2"] = (self.y0, "image coordinate of reference pixel = y0") - header["CD1_1"] = (self.dudx, "CD1_1 = dudx") - header["CD1_2"] = (self.dudy, "CD1_2 = dudy") - header["CD2_1"] = (self.dvdx, "CD2_1 = dvdx") - header["CD2_2"] = (self.dvdy, "CD2_2 = dvdy") + header["CRVAL1"] = ( + convert_to_float(self.u0), + "world coordinate at reference pixel = u0", + ) + header["CRVAL2"] = ( + convert_to_float(self.v0), + "world coordinate at reference pixel = v0", + ) + header["CRPIX1"] = ( + convert_to_float(self.x0), + "image coordinate of reference pixel = x0", + ) + header["CRPIX2"] = ( + convert_to_float(self.y0), + "image coordinate of reference pixel = y0", + ) + header["CD1_1"] = (convert_to_float(self.dudx), "CD1_1 = dudx") + header["CD1_2"] = (convert_to_float(self.dudy), "CD1_2 = dudy") + header["CD2_1"] = (convert_to_float(self.dvdx), "CD2_1 = dvdx") + header["CD2_2"] = (convert_to_float(self.dvdy), "CD2_2 = dvdy") return header @staticmethod diff --git a/tests/GalSim b/tests/GalSim index 66092bdf..0281f764 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 66092bdf7215983bab4d2d953a700eb8a0ddcbe4 +Subproject commit 0281f764f2f8ad3af45bcee9d171d2b48fd79a20 diff --git a/tests/conftest.py b/tests/conftest.py index 8095105e..dc4ceff9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,6 +68,19 @@ def _infile(val, fname): return False +def _convert_galsim_to_jax_galsim(obj): + import galsim as _galsim # noqa: F401 + from numpy import array # noqa: F401 + + import jax_galsim as galsim # noqa: F401 + + if isinstance(obj, _galsim.GSObject): + ret_obj = eval(repr(obj)) + return ret_obj + else: + return obj + + def pytest_pycollect_makemodule(module_path, path, parent): """This hook is tasked with overriding the galsim import at the top of each test file. Replaces it by jax-galsim. @@ -111,6 +124,19 @@ def pytest_pycollect_makemodule(module_path, path, parent): v.__globals__["coord"] = __import__("jax_galsim") v.__globals__["galsim"] = __import__("jax_galsim") + # the galsim WCS tests have some items that are galsim objects that need conversions + # to jax_galsim objects + if module.name.endswith("tests/GalSim/tests/test_wcs.py"): + for k, v in module.obj.__dict__.items(): + if isinstance(v, __import__("galsim").GSObject): + module.obj.__dict__[k] = _convert_galsim_to_jax_galsim(v) + elif isinstance(v, list): + module.obj.__dict__[k] = [ + _convert_galsim_to_jax_galsim(obj) for obj in v + ] + + module.obj._convert_galsim_to_jax_galsim = _convert_galsim_to_jax_galsim + return module diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index dd761691..e9aa68ef 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -24,7 +24,6 @@ enabled_tests: # in jax_galsim allowed_failures: - "NotImplementedError" - - "module 'jax_galsim' has no attribute 'BaseDeviate" - "module 'jax_galsim' has no attribute 'Airy'" - "module 'jax_galsim' has no attribute 'Kolmogorov'" - "module 'jax_galsim' has no attribute 'Sersic'" @@ -35,7 +34,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'UVFunction'" - "module 'jax_galsim' has no attribute 'FitsWCS'" - "module 'jax_galsim' has no attribute 'FitsHeader'" - - "module 'jax_galsim' has no attribute 'UniformDeviate'" - "module 'jax_galsim' has no attribute 'AstropyWCS'" - "module 'jax_galsim' has no attribute 'GSFitsWCS'" - "module 'jax_galsim' has no attribute 'WcsToolsWCS'" @@ -58,3 +56,8 @@ allowed_failures: - "TypeError not raised by __mul__" - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'TanWCS'" + - "has no attribute 'drawPhot'" + - "'jax_galsim.utilities' has no attribute 'horner'" + - "module 'jax_galsim.utilities' has no attribute 'horner2d'" + - "'Image' object has no attribute 'FindAdaptiveMom'" + - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py new file mode 100644 index 00000000..ed057b51 --- /dev/null +++ b/tests/jax/galsim/test_random_jax.py @@ -0,0 +1,2019 @@ +import numpy as np +import os +import galsim +from galsim.utilities import single_threaded +from galsim_test_helpers import timer, do_pickle # noqa: E402 + +precision = 10 +# decimal point at which agreement is required for all double precision tests + +precisionD = precision +precisionF = 5 # precision=10 does not make sense at single precision +precisionS = 1 # "precision" also a silly concept for ints, but allows all 4 tests to run in one go +precisionI = 1 + +# The number of values to generate when checking the mean and variance calculations. +# This is currently low enough to not dominate the time of the unit tests, but when changing +# something, it may be useful to add a couple zeros while testing. +nvals = 100000 + +testseed = 1000 # seed used for UniformDeviate for all tests +# Warning! If you change testseed, then all of the *Result variables below must change as well. + +# the right answer for the first three uniform deviates produced from testseed +uResult = (0.0160653916, 0.228817832, 0.1609966951) + +# mean, sigma to use for Gaussian tests +gMean = 4.7 +gSigma = 3.2 +# the right answer for the first three Gaussian deviates produced from testseed +gResult = (6.3344979808161215, 6.2082355273987861, -0.069894693358302007) + +# N, p to use for binomial tests +bN = 10 +bp = 0.7 +# the right answer for the first three binomial deviates produced from testseed +bResult = (9, 8, 7) + +# mean to use for Poisson tests +pMean = 7 +# the right answer for the first three Poisson deviates produced from testseed +pResult = (4, 5, 6) + +# a & b to use for Weibull tests +wA = 4.0 +wB = 9.0 +# Tabulated results for Weibull +wResult = (5.3648053017485591, 6.3093033550873878, 7.7982696798921074) + +# k & theta to use for Gamma tests +gammaK = 1.5 +gammaTheta = 4.5 +# Tabulated results for Gamma +gammaResult = (4.7375613139927157, 15.272973580418618, 21.485016362839747) + +# n to use for Chi2 tests +chi2N = 30 +# Tabulated results for Chi2 +chi2Result = (32.209933900954049, 50.040002656028513, 24.301442486313896) + +# function and min&max to use for DistDeviate function call tests +dmin = 0.0 +dmax = 2.0 + + +def dfunction(x): + return x * x + + +# Tabulated results for DistDeviate function call +dFunctionResult = (0.9826461346196363, 1.1973307331701328, 1.5105900949284945) + +# x and p arrays and interpolant to use for DistDeviate LookupTable tests +dx = [0.0, 1.0, 1.000000001, 2.999999999, 3.0, 4.0] +dp = [0.1, 0.1, 0.0, 0.0, 0.1, 0.1] +dLookupTable = galsim.LookupTable(x=dx, f=dp, interpolant="linear") +# Tabulated results for DistDeviate LookupTable call +dLookupTableResult = (0.23721845680847731, 0.42913599265739233, 0.86176396813243539) +# File with the same values +dLookupTableFile = os.path.join("random_data", "dLookupTable.dat") + + +@timer +def test_uniform(): + """Test uniform random number generator""" + u = galsim.UniformDeviate(testseed) + u2 = u.duplicate() + u3 = galsim.UniformDeviate(u.serialize()) + testResult = (u(), u(), u()) + np.testing.assert_array_almost_equal( + np.array(testResult), + np.array(uResult), + precision, + err_msg="Wrong uniform random number sequence generated", + ) + testResult = (u2(), u2(), u2()) + np.testing.assert_array_almost_equal( + np.array(testResult), + np.array(uResult), + precision, + err_msg="Wrong uniform random number sequence generated with duplicate", + ) + testResult = (u3(), u3(), u3()) + np.testing.assert_array_almost_equal( + np.array(testResult), + np.array(uResult), + precision, + err_msg="Wrong uniform random number sequence generated from serialize", + ) + + # Check that the mean and variance come out right + u = galsim.UniformDeviate(testseed) + vals = [u() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = 1.0 / 2.0 + v = 1.0 / 12.0 + print("mean = ", mean, " true mean = ", mu) + print("var = ", var, " true var = ", v) + np.testing.assert_almost_equal( + mean, mu, 1, err_msg="Wrong mean from UniformDeviate" + ) + np.testing.assert_almost_equal( + var, v, 1, err_msg="Wrong variance from UniformDeviate" + ) + + # Check discard + u2 = galsim.UniformDeviate(testseed) + u2.discard(nvals) + v1, v2 = u(), u2() + print("after %d vals, next one is %s, %s" % (nvals, v1, v2)) + assert v1 == v2 + assert u.has_reliable_discard + assert not u.generates_in_pairs + + # Check seed, reset + u.seed(testseed) + testResult2 = (u(), u(), u()) + np.testing.assert_array_equal( + np.array(testResult), + np.array(testResult2), + err_msg="Wrong uniform random number sequence generated after seed", + ) + + u.reset(testseed) + testResult2 = (u(), u(), u()) + np.testing.assert_array_equal( + np.array(testResult), + np.array(testResult2), + err_msg="Wrong uniform random number sequence generated after reset(seed)", + ) + + rng = galsim.BaseDeviate(testseed) + u.reset(rng) + testResult2 = (u(), u(), u()) + np.testing.assert_array_equal( + np.array(testResult), + np.array(testResult2), + err_msg="Wrong uniform random number sequence generated after reset(rng)", + ) + + # Check raw + u2.reset(testseed) + u2.discard(3) + np.testing.assert_equal( + u.raw(), u2.raw(), err_msg="Uniform deviates generate different raw values" + ) + + # NOTE: these tests differ from galsim since we cannot connect + # generators + rng2 = galsim.BaseDeviate(testseed) + rng2.discard(4) + rng.discard(4) # new line relative to galsim + np.testing.assert_equal( + rng.raw(), + rng2.raw(), + err_msg="BaseDeviates generate different raw values after discard", + ) + + # NOTE: these tests will never work in galsim since we cannot + # connect RNGs in JAX + # # Check that two connected uniform deviates work correctly together. + # u2 = galsim.UniformDeviate(testseed) + # u.reset(u2) + # testResult2 = (u(), u2(), u()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong uniform random number sequence generated using two uds') + # u.seed(testseed) + # testResult2 = (u2(), u(), u2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong uniform random number sequence generated using two uds after seed') + + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. + u.seed() + testResult2 = (u(), u(), u()) + assert testResult2 != testResult + u.reset() + testResult3 = (u(), u(), u()) + assert testResult3 != testResult + assert testResult3 != testResult2 + u.reset() + testResult4 = (u(), u(), u()) + assert testResult4 != testResult + assert testResult4 != testResult2 + assert testResult4 != testResult3 + u = galsim.UniformDeviate() + testResult5 = (u(), u(), u()) + assert testResult5 != testResult + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 + + # NOTE: these tests differ since we cannot edit arrays in-place in JAX + # Test generate + u.seed(testseed) + test_array = np.empty(3) + test_array = u.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, + np.array(uResult), + precision, + err_msg="Wrong uniform random number sequence from generate.", + ) + + # Test add_generate + u.seed(testseed) + test_array = u.add_generate(test_array) + np.testing.assert_array_almost_equal( + test_array, + 2.0 * np.array(uResult), + precision, + err_msg="Wrong uniform random number sequence from generate.", + ) + + # Test generate with a float32 array + u.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array = u.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, + np.array(uResult), + precisionF, + err_msg="Wrong uniform random number sequence from generate.", + ) + + # Test add_generate + u.seed(testseed) + test_array = u.add_generate(test_array) + np.testing.assert_array_almost_equal( + test_array, + 2.0 * np.array(uResult), + precisionF, + err_msg="Wrong uniform random number sequence from generate.", + ) + + # Check that generated values are independent of number of threads. + u1 = galsim.UniformDeviate(testseed) + u2 = galsim.UniformDeviate(testseed) + v1 = np.empty(555) + v2 = np.empty(555) + with single_threaded(): + v1 = u1.generate(v1) + with single_threaded(num_threads=10): + v2 = u2.generate(v2) + np.testing.assert_array_equal(v1, v2) + with single_threaded(): + v1 = u1.add_generate(v1) + with single_threaded(num_threads=10): + v2 = u2.add_generate(v2) + np.testing.assert_array_equal(v1, v2) + + # Check picklability + do_pickle(u, lambda x: x.serialize()) + do_pickle(u, lambda x: (x(), x(), x(), x())) + do_pickle(u) + do_pickle(rng) + assert "UniformDeviate" in repr(u) + assert "UniformDeviate" in str(u) + assert isinstance(eval(repr(u)), galsim.UniformDeviate) + assert isinstance(eval(str(u)), galsim.UniformDeviate) + assert isinstance(eval(repr(rng)), galsim.BaseDeviate) + assert isinstance(eval(str(rng)), galsim.BaseDeviate) + + # Check that we can construct a UniformDeviate from None, and that it depends on dev/random. + u1 = galsim.UniformDeviate(None) + u2 = galsim.UniformDeviate(None) + assert u1 != u2, "Consecutive UniformDeviate(None) compared equal!" + + # NOTE: We do not test for these since we do no type checking in JAX + # # We shouldn't be able to construct a UniformDeviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.UniformDeviate, dict()) + # assert_raises(TypeError, galsim.UniformDeviate, list()) + # assert_raises(TypeError, galsim.UniformDeviate, set()) + + # assert_raises(TypeError, u.seed, '123') + # assert_raises(TypeError, u.seed, 12.3) + + +# @timer +# def test_gaussian(): +# """Test Gaussian random number generator +# """ +# g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# g2 = g.duplicate() +# g3 = galsim.GaussianDeviate(g.serialize(), mean=gMean, sigma=gSigma) +# testResult = (g(), g(), g()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gResult), precision, +# err_msg='Wrong Gaussian random number sequence generated') +# testResult = (g2(), g2(), g2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gResult), precision, +# err_msg='Wrong Gaussian random number sequence generated with duplicate') +# testResult = (g3(), g3(), g3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gResult), precision, +# err_msg='Wrong Gaussian random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# vals = [g() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = gMean +# v = gSigma**2 +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from GaussianDeviate') +# np.testing.assert_almost_equal(var, v, 0, +# err_msg='Wrong variance from GaussianDeviate') + +# # Check discard +# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# g2.discard(nvals) +# v1,v2 = g(),g2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# assert v1 == v2 +# # Note: For Gaussian, this only works if nvals is even. +# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# g2.discard(nvals+1, suppress_warnings=True) +# v1,v2 = g(),g2() +# print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) +# assert v1 != v2 +# assert g.has_reliable_discard +# assert g.generates_in_pairs + +# # If don't explicitly suppress the warning, then a warning is emitted when n is odd. +# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# with assert_warns(galsim.GalSimWarning): +# g2.discard(nvals+1) + +# # Check seed, reset +# g.seed(testseed) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated after seed') + +# g.reset(testseed) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# g.reset(rng) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# g.reset(ud) +# testResult = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated after reset(ud)') + +# # Check that two connected Gaussian deviates work correctly together. +# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) +# g.reset(g2) +# # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. +# testResult2 = (g(), g(), g2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated using two gds') +# g.seed(testseed) +# # For the same reason, after seeding one, we need to manually clear the other's cache: +# g2.clearCache() +# testResult2 = (g2(), g2(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Gaussian random number sequence generated using two gds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. +# g.seed() +# testResult2 = (g(), g(), g()) +# assert testResult2 != testResult +# g.reset() +# testResult3 = (g(), g(), g()) +# assert testResult3 != testResult +# assert testResult3 != testResult2 +# g.reset() +# testResult4 = (g(), g(), g()) +# assert testResult4 != testResult +# assert testResult4 != testResult2 +# assert testResult4 != testResult3 +# g = galsim.GaussianDeviate(mean=gMean, sigma=gSigma) +# testResult5 = (g(), g(), g()) +# assert testResult5 != testResult +# assert testResult5 != testResult2 +# assert testResult5 != testResult3 +# assert testResult5 != testResult4 + +# # Test generate +# g.seed(testseed) +# test_array = np.empty(3) +# g.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gResult), precision, +# err_msg='Wrong Gaussian random number sequence from generate.') + +# # Test generate_from_variance. +# g2 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) +# g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) +# test_array.fill(gSigma**2) +# g2.generate_from_variance(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gResult)-gMean, precision, +# err_msg='Wrong Gaussian random number sequence from generate_from_variance.') +# # After running generate_from_variance, it should be back to using the specified mean, sigma. +# # Note: need to round up to even number for discard, since gd generates 2 at a time. +# g3.discard((len(test_array)+1)//2 * 2) +# print('g2,g3 = ',g2(),g3()) +# assert g2() == g3() + +# # Test generate with a float32 array. +# g.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# g.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gResult), precisionF, +# err_msg='Wrong Gaussian random number sequence from generate.') + +# # Test generate_from_variance. +# g2.seed(testseed) +# test_array.fill(gSigma**2) +# g2.generate_from_variance(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gResult)-gMean, precisionF, +# err_msg='Wrong Gaussian random number sequence from generate_from_variance.') + +# # Check that generated values are independent of number of threads. +# g1 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) +# g2 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) +# v1 = np.empty(555) +# v2 = np.empty(555) +# with single_threaded(): +# g1.generate(v1) +# with single_threaded(num_threads=10): +# g2.generate(v2) +# np.testing.assert_array_equal(v1, v2) +# with single_threaded(): +# g1.add_generate(v1) +# with single_threaded(num_threads=10): +# g2.add_generate(v2) +# np.testing.assert_array_equal(v1, v2) +# ud = galsim.UniformDeviate(testseed + 3) +# ud.generate(v1) +# v1 += 6.7 +# v2[:] = v1 +# with single_threaded(): +# g1.generate_from_variance(v1) +# with single_threaded(num_threads=10): +# g2.generate_from_variance(v2) +# np.testing.assert_array_equal(v1, v2) + +# # Check picklability +# do_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) +# do_pickle(g, lambda x: (x(), x(), x(), x())) +# do_pickle(g) +# assert 'GaussianDeviate' in repr(g) +# assert 'GaussianDeviate' in str(g) +# assert isinstance(eval(repr(g)), galsim.GaussianDeviate) +# assert isinstance(eval(str(g)), galsim.GaussianDeviate) + +# # Check that we can construct a GaussianDeviate from None, and that it depends on dev/random. +# g1 = galsim.GaussianDeviate(None) +# g2 = galsim.GaussianDeviate(None) +# assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" +# # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.GaussianDeviate, dict()) +# assert_raises(TypeError, galsim.GaussianDeviate, list()) +# assert_raises(TypeError, galsim.GaussianDeviate, set()) + +# assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) + + +# @timer +# def test_binomial(): +# """Test binomial random number generator +# """ +# b = galsim.BinomialDeviate(testseed, N=bN, p=bp) +# b2 = b.duplicate() +# b3 = galsim.BinomialDeviate(b.serialize(), N=bN, p=bp) +# testResult = (b(), b(), b()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(bResult), precision, +# err_msg='Wrong binomial random number sequence generated') +# testResult = (b2(), b2(), b2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(bResult), precision, +# err_msg='Wrong binomial random number sequence generated with duplicate') +# testResult = (b3(), b3(), b3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(bResult), precision, +# err_msg='Wrong binomial random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# b = galsim.BinomialDeviate(testseed, N=bN, p=bp) +# vals = [b() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = bN*bp +# v = bN*bp*(1.-bp) +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from BinomialDeviate') +# np.testing.assert_almost_equal(var, v, 1, +# err_msg='Wrong variance from BinomialDeviate') + +# # Check discard +# b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) +# b2.discard(nvals) +# v1,v2 = b(),b2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# assert v1 == v2 +# assert b.has_reliable_discard +# assert not b.generates_in_pairs + +# # Check seed, reset +# b.seed(testseed) +# testResult2 = (b(), b(), b()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated after seed') + +# b.reset(testseed) +# testResult2 = (b(), b(), b()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# b.reset(rng) +# testResult2 = (b(), b(), b()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# b.reset(ud) +# testResult = (b(), b(), b()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated after reset(ud)') + +# # Check that two connected binomial deviates work correctly together. +# b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) +# b.reset(b2) +# testResult2 = (b(), b2(), b()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated using two bds') +# b.seed(testseed) +# testResult2 = (b2(), b(), b2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong binomial random number sequence generated using two bds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. However, in this case, there are few enough options +# # for the output that occasionally two of these match. So we don't do the normal +# # testResult2 != testResult, etc. +# b.seed() +# testResult2 = (b(), b(), b()) +# #assert testResult2 != testResult +# b.reset() +# testResult3 = (b(), b(), b()) +# #assert testResult3 != testResult +# #assert testResult3 != testResult2 +# b.reset() +# testResult4 = (b(), b(), b()) +# #assert testResult4 != testResult +# #assert testResult4 != testResult2 +# #assert testResult4 != testResult3 +# b = galsim.BinomialDeviate(N=bN, p=bp) +# testResult5 = (b(), b(), b()) +# #assert testResult5 != testResult +# #assert testResult5 != testResult2 +# #assert testResult5 != testResult3 +# #assert testResult5 != testResult4 + +# # Test generate +# b.seed(testseed) +# test_array = np.empty(3) +# b.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(bResult), precision, +# err_msg='Wrong binomial random number sequence from generate.') + +# # Test generate with an int array +# b.seed(testseed) +# test_array = np.empty(3, dtype=int) +# b.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(bResult), precisionI, +# err_msg='Wrong binomial random number sequence from generate.') + +# # Check that generated values are independent of number of threads. +# b1 = galsim.BinomialDeviate(testseed, N=17, p=0.7) +# b2 = galsim.BinomialDeviate(testseed, N=17, p=0.7) +# v1 = np.empty(555) +# v2 = np.empty(555) +# with single_threaded(): +# b1.generate(v1) +# with single_threaded(num_threads=10): +# b2.generate(v2) +# np.testing.assert_array_equal(v1, v2) +# with single_threaded(): +# b1.add_generate(v1) +# with single_threaded(num_threads=10): +# b2.add_generate(v2) +# np.testing.assert_array_equal(v1, v2) + +# # Check picklability +# do_pickle(b, lambda x: (x.serialize(), x.n, x.p)) +# do_pickle(b, lambda x: (x(), x(), x(), x())) +# do_pickle(b) +# assert 'BinomialDeviate' in repr(b) +# assert 'BinomialDeviate' in str(b) +# assert isinstance(eval(repr(b)), galsim.BinomialDeviate) +# assert isinstance(eval(str(b)), galsim.BinomialDeviate) + +# # Check that we can construct a BinomialDeviate from None, and that it depends on dev/random. +# b1 = galsim.BinomialDeviate(None) +# b2 = galsim.BinomialDeviate(None) +# assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" +# # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.BinomialDeviate, dict()) +# assert_raises(TypeError, galsim.BinomialDeviate, list()) +# assert_raises(TypeError, galsim.BinomialDeviate, set()) + + +# @timer +# def test_poisson(): +# """Test Poisson random number generator +# """ +# p = galsim.PoissonDeviate(testseed, mean=pMean) +# p2 = p.duplicate() +# p3 = galsim.PoissonDeviate(p.serialize(), mean=pMean) +# testResult = (p(), p(), p()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(pResult), precision, +# err_msg='Wrong Poisson random number sequence generated') +# testResult = (p2(), p2(), p2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(pResult), precision, +# err_msg='Wrong Poisson random number sequence generated with duplicate') +# testResult = (p3(), p3(), p3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(pResult), precision, +# err_msg='Wrong Poisson random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# p = galsim.PoissonDeviate(testseed, mean=pMean) +# vals = [p() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = pMean +# v = pMean +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from PoissonDeviate') +# np.testing.assert_almost_equal(var, v, 1, +# err_msg='Wrong variance from PoissonDeviate') + +# # Check discard +# p2 = galsim.PoissonDeviate(testseed, mean=pMean) +# p2.discard(nvals, suppress_warnings=True) +# v1,v2 = p(),p2() +# print('With mean = %d, after %d vals, next one is %s, %s'%(pMean,nvals,v1,v2)) +# assert v1 == v2 + +# # With a very small mean value, Poisson reliably only uses 1 rng per value. +# # But at only slightly larger means, it sometimes uses two rngs for a single value. +# # Basically anything >= 10 causes this next test to have v1 != v2 +# high_mean = 10 +# p = galsim.PoissonDeviate(testseed, mean=high_mean) +# p2 = galsim.PoissonDeviate(testseed, mean=high_mean) +# vals = [p() for i in range(nvals)] +# p2.discard(nvals, suppress_warnings=True) +# v1,v2 = p(),p2() +# print('With mean = %d, after %d vals, next one is %s, %s'%(high_mean,nvals,v1,v2)) +# assert v1 != v2 +# assert not p.has_reliable_discard +# assert not p.generates_in_pairs + +# # Discard normally emits a warning for Poisson +# p2 = galsim.PoissonDeviate(testseed, mean=pMean) +# with assert_warns(galsim.GalSimWarning): +# p2.discard(nvals) + +# # Check seed, reset +# p = galsim.PoissonDeviate(testseed, mean=pMean) +# p.seed(testseed) +# testResult2 = (p(), p(), p()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated after seed') + +# p.reset(testseed) +# testResult2 = (p(), p(), p()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# p.reset(rng) +# testResult2 = (p(), p(), p()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# p.reset(ud) +# testResult = (p(), p(), p()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated after reset(ud)') + +# # Check that two connected poisson deviates work correctly together. +# p2 = galsim.PoissonDeviate(testseed, mean=pMean) +# p.reset(p2) +# testResult2 = (p(), p2(), p()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated using two pds') +# p.seed(testseed) +# testResult2 = (p2(), p(), p2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong poisson random number sequence generated using two pds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. However, in this case, there are few enough options +# # for the output that occasionally two of these match. So we don't do the normal +# # testResult2 != testResult, etc. +# p.seed() +# testResult2 = (p(), p(), p()) +# #assert testResult2 != testResult +# p.reset() +# testResult3 = (p(), p(), p()) +# #assert testResult3 != testResult +# #assert testResult3 != testResult2 +# p.reset() +# testResult4 = (p(), p(), p()) +# #assert testResult4 != testResult +# #assert testResult4 != testResult2 +# #assert testResult4 != testResult3 +# p = galsim.PoissonDeviate(mean=pMean) +# testResult5 = (p(), p(), p()) +# #assert testResult5 != testResult +# #assert testResult5 != testResult2 +# #assert testResult5 != testResult3 +# #assert testResult5 != testResult4 + +# # Test generate +# p.seed(testseed) +# test_array = np.empty(3) +# p.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(pResult), precision, +# err_msg='Wrong poisson random number sequence from generate.') + +# # Test generate with an int array +# p.seed(testseed) +# test_array = np.empty(3, dtype=int) +# p.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(pResult), precisionI, +# err_msg='Wrong poisson random number sequence from generate.') + +# # Test generate_from_expectation +# p2 = galsim.PoissonDeviate(testseed, mean=77) +# test_array = np.array([pMean]*3, dtype=int) +# p2.generate_from_expectation(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(pResult), precisionI, +# err_msg='Wrong poisson random number sequence from generate_from_expectation.') +# # After generating, it should be back to mean=77 +# test_array2 = np.array([p2() for i in range(100)]) +# print('test_array2 = ',test_array2) +# print('mean = ',test_array2.mean()) +# assert np.isclose(test_array2.mean(), 77, atol=2) + +# # Check that generated values are independent of number of threads. +# # This should be trivial, since Poisson disables multi-threading, but check anyway. +# p1 = galsim.PoissonDeviate(testseed, mean=77) +# p2 = galsim.PoissonDeviate(testseed, mean=77) +# v1 = np.empty(555) +# v2 = np.empty(555) +# with single_threaded(): +# p1.generate(v1) +# with single_threaded(num_threads=10): +# p2.generate(v2) +# np.testing.assert_array_equal(v1, v2) +# with single_threaded(): +# p1.add_generate(v1) +# with single_threaded(num_threads=10): +# p2.add_generate(v2) +# np.testing.assert_array_equal(v1, v2) + +# # Check picklability +# do_pickle(p, lambda x: (x.serialize(), x.mean)) +# do_pickle(p, lambda x: (x(), x(), x(), x())) +# do_pickle(p) +# assert 'PoissonDeviate' in repr(p) +# assert 'PoissonDeviate' in str(p) +# assert isinstance(eval(repr(p)), galsim.PoissonDeviate) +# assert isinstance(eval(str(p)), galsim.PoissonDeviate) + +# # Check that we can construct a PoissonDeviate from None, and that it depends on dev/random. +# p1 = galsim.PoissonDeviate(None) +# p2 = galsim.PoissonDeviate(None) +# assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" +# # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.PoissonDeviate, dict()) +# assert_raises(TypeError, galsim.PoissonDeviate, list()) +# assert_raises(TypeError, galsim.PoissonDeviate, set()) + + +# @timer +# def test_poisson_highmean(): +# """Test Poisson random number generator with high (>2^30) mean (cf. Issue #881) + +# It turns out that the boost poisson deviate class that we use maxes out at 2^31 and wraps +# around to -2^31. We have code to automatically switch over to using a Gaussian deviate +# instead if the mean > 2^30 (factor of 2 from the problem to be safe). Check that this +# works properly. +# """ +# mean_vals =[ 2**30 + 50, # Uses Gaussian +# 2**30 - 50, # Uses Poisson +# 2**30, # Uses Poisson (highest value of mean that does) +# 2**31, # This is where problems happen if not using Gaussian +# 5.e20, # Definitely would have problems with normal implementation. +# ] + +# if __name__ == '__main__': +# nvals = 10000000 +# rtol_var = 1.e-3 +# else: +# nvals = 100000 +# rtol_var = 1.e-2 + +# for mean in mean_vals: +# print('Test PoissonDeviate with mean = ',mean) +# p = galsim.PoissonDeviate(testseed, mean=mean) +# p2 = p.duplicate() +# p3 = galsim.PoissonDeviate(p.serialize(), mean=mean) +# testResult = (p(), p(), p()) +# testResult2 = (p2(), p2(), p2()) +# testResult3 = (p3(), p3(), p3()) +# np.testing.assert_allclose( +# testResult2, testResult, rtol=1.e-8, +# err_msg='PoissonDeviate.duplicate not equivalent for mean=%s'%mean) +# np.testing.assert_allclose( +# testResult3, testResult, rtol=1.e-8, +# err_msg='PoissonDeviate from serialize not equivalent for mean=%s'%mean) + +# # Check that the mean and variance come out right +# p = galsim.PoissonDeviate(testseed, mean=mean) +# vals = [p() for i in range(nvals)] +# mu = np.mean(vals) +# var = np.var(vals) +# print('mean = ',mu,' true mean = ',mean) +# print('var = ',var,' true var = ',mean) +# np.testing.assert_allclose(mu, mean, rtol=1.e-5, +# err_msg='Wrong mean from PoissonDeviate with mean=%s'%mean) +# np.testing.assert_allclose(var, mean, rtol=rtol_var, +# err_msg='Wrong variance from PoissonDeviate with mean=%s'%mean) + +# # Check discard +# p2 = galsim.PoissonDeviate(testseed, mean=mean) +# p2.discard(nvals, suppress_warnings=True) +# v1,v2 = p(),p2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# if mean > 2**30: +# # Poisson doesn't have a reliable rng count (unless the mean is vv small). +# # But above 2**30 we're back to Gaussian, which is reliable. +# assert v1 == v2 + +# # Check that two connected poisson deviates work correctly together. +# p2 = galsim.PoissonDeviate(testseed, mean=mean) +# p.reset(p2) +# testResult2 = (p(), p(), p2()) +# np.testing.assert_array_equal( +# testResult2, testResult, +# err_msg='Wrong poisson random number sequence generated using two pds') +# p.seed(testseed) +# p2.clearCache() +# testResult2 = (p2(), p2(), p()) +# np.testing.assert_array_equal( +# testResult2, testResult, +# err_msg='Wrong poisson random number sequence generated using two pds after seed') + +# # Test filling an image +# p.seed(testseed) +# testimage = galsim.ImageD(np.zeros((3, 1))) +# testimage.addNoise(galsim.DeviateNoise(p)) +# np.testing.assert_array_equal( +# testimage.array.flatten(), testResult, +# err_msg='Wrong poisson random number sequence generated when applied to image.') + +# # The PoissonNoise version also subtracts off the mean value +# rng = galsim.BaseDeviate(testseed) +# pn = galsim.PoissonNoise(rng, sky_level=mean) +# testimage.fill(0) +# testimage.addNoise(pn) +# np.testing.assert_array_equal( +# testimage.array.flatten(), np.array(testResult)-mean, +# err_msg='Wrong poisson random number sequence generated using PoissonNoise') + +# # Check PoissonNoise variance: +# np.testing.assert_allclose( +# pn.getVariance(), mean, rtol=1.e-8, +# err_msg="PoissonNoise getVariance returns wrong variance") +# np.testing.assert_allclose( +# pn.sky_level, mean, rtol=1.e-8, +# err_msg="PoissonNoise sky_level returns wrong value") + +# # Check that the noise model really does produce this variance. +# big_im = galsim.Image(2048,2048,dtype=float) +# big_im.addNoise(pn) +# var = np.var(big_im.array) +# print('variance = ',var) +# print('getVar = ',pn.getVariance()) +# np.testing.assert_allclose( +# var, pn.getVariance(), rtol=rtol_var, +# err_msg='Realized variance for PoissonNoise did not match getVariance()') + + +# @timer +# def test_poisson_zeromean(): +# """Make sure Poisson Deviate behaves sensibly when mean=0. +# """ +# p = galsim.PoissonDeviate(testseed, mean=0) +# p2 = p.duplicate() +# p3 = galsim.PoissonDeviate(p.serialize(), mean=0) +# do_pickle(p) + +# # Test direct draws +# testResult = (p(), p(), p()) +# testResult2 = (p2(), p2(), p2()) +# testResult3 = (p3(), p3(), p3()) +# np.testing.assert_array_equal(testResult, 0) +# np.testing.assert_array_equal(testResult2, 0) +# np.testing.assert_array_equal(testResult3, 0) + +# # Test generate +# test_array = np.empty(3, dtype=int) +# p.generate(test_array) +# np.testing.assert_array_equal(test_array, 0) +# p2.generate(test_array) +# np.testing.assert_array_equal(test_array, 0) +# p3.generate(test_array) +# np.testing.assert_array_equal(test_array, 0) + +# # Test generate_from_expectation +# test_array = np.array([0,0,0]) +# np.testing.assert_allclose(test_array, 0) +# test_array = np.array([1,0,4]) +# assert test_array[0] != 0 +# assert test_array[1] == 0 +# assert test_array[2] != 0 + +# # Error raised if mean<0 +# with assert_raises(ValueError): +# p = galsim.PoissonDeviate(testseed, mean=-0.1) +# with assert_raises(ValueError): +# p = galsim.PoissonDeviate(testseed, mean=-10) +# test_array = np.array([-1,1,4]) +# with assert_raises(ValueError): +# p.generate_from_expectation(test_array) +# test_array = np.array([1,-1,-4]) +# with assert_raises(ValueError): +# p.generate_from_expectation(test_array) + +# @timer +# def test_weibull(): +# """Test Weibull random number generator +# """ +# w = galsim.WeibullDeviate(testseed, a=wA, b=wB) +# w2 = w.duplicate() +# w3 = galsim.WeibullDeviate(w.serialize(), a=wA, b=wB) +# testResult = (w(), w(), w()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(wResult), precision, +# err_msg='Wrong Weibull random number sequence generated') +# testResult = (w2(), w2(), w2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(wResult), precision, +# err_msg='Wrong Weibull random number sequence generated with duplicate') +# testResult = (w3(), w3(), w3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(wResult), precision, +# err_msg='Wrong Weibull random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# w = galsim.WeibullDeviate(testseed, a=wA, b=wB) +# vals = [w() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# gammaFactor1 = math.gamma(1.+1./wA) +# gammaFactor2 = math.gamma(1.+2./wA) +# mu = wB * gammaFactor1 +# v = wB**2 * gammaFactor2 - mu**2 +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from WeibullDeviate') +# np.testing.assert_almost_equal(var, v, 1, +# err_msg='Wrong variance from WeibullDeviate') + +# # Check discard +# w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) +# w2.discard(nvals) +# v1,v2 = w(),w2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# assert v1 == v2 +# assert w.has_reliable_discard +# assert not w.generates_in_pairs + +# # Check seed, reset +# w.seed(testseed) +# testResult2 = (w(), w(), w()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated after seed') + +# w.reset(testseed) +# testResult2 = (w(), w(), w()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# w.reset(rng) +# testResult2 = (w(), w(), w()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# w.reset(ud) +# testResult = (w(), w(), w()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated after reset(ud)') + +# # Check that two connected weibull deviates work correctly together. +# w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) +# w.reset(w2) +# testResult2 = (w(), w2(), w()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated using two wds') +# w.seed(testseed) +# testResult2 = (w2(), w(), w2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong weibull random number sequence generated using two wds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. +# w.seed() +# testResult2 = (w(), w(), w()) +# assert testResult2 != testResult +# w.reset() +# testResult3 = (w(), w(), w()) +# assert testResult3 != testResult +# assert testResult3 != testResult2 +# w.reset() +# testResult4 = (w(), w(), w()) +# assert testResult4 != testResult +# assert testResult4 != testResult2 +# assert testResult4 != testResult3 +# w = galsim.WeibullDeviate(a=wA, b=wB) +# testResult5 = (w(), w(), w()) +# assert testResult5 != testResult +# assert testResult5 != testResult2 +# assert testResult5 != testResult3 +# assert testResult5 != testResult4 + +# # Test generate +# w.seed(testseed) +# test_array = np.empty(3) +# w.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(wResult), precision, +# err_msg='Wrong weibull random number sequence from generate.') + +# # Test generate with a float32 array +# w.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# w.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(wResult), precisionF, +# err_msg='Wrong weibull random number sequence from generate.') + +# # Check that generated values are independent of number of threads. +# w1 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) +# w2 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) +# v1 = np.empty(555) +# v2 = np.empty(555) +# with single_threaded(): +# w1.generate(v1) +# with single_threaded(num_threads=10): +# w2.generate(v2) +# np.testing.assert_array_equal(v1, v2) +# with single_threaded(): +# w1.add_generate(v1) +# with single_threaded(num_threads=10): +# w2.add_generate(v2) +# np.testing.assert_array_equal(v1, v2) + +# # Check picklability +# do_pickle(w, lambda x: (x.serialize(), x.a, x.b)) +# do_pickle(w, lambda x: (x(), x(), x(), x())) +# do_pickle(w) +# assert 'WeibullDeviate' in repr(w) +# assert 'WeibullDeviate' in str(w) +# assert isinstance(eval(repr(w)), galsim.WeibullDeviate) +# assert isinstance(eval(str(w)), galsim.WeibullDeviate) + +# # Check that we can construct a WeibullDeviate from None, and that it depends on dev/random. +# w1 = galsim.WeibullDeviate(None) +# w2 = galsim.WeibullDeviate(None) +# assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" +# # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.WeibullDeviate, dict()) +# assert_raises(TypeError, galsim.WeibullDeviate, list()) +# assert_raises(TypeError, galsim.WeibullDeviate, set()) + + +# @timer +# def test_gamma(): +# """Test Gamma random number generator +# """ +# g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) +# g2 = g.duplicate() +# g3 = galsim.GammaDeviate(g.serialize(), k=gammaK, theta=gammaTheta) +# testResult = (g(), g(), g()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gammaResult), precision, +# err_msg='Wrong Gamma random number sequence generated') +# testResult = (g2(), g2(), g2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gammaResult), precision, +# err_msg='Wrong Gamma random number sequence generated with duplicate') +# testResult = (g3(), g3(), g3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(gammaResult), precision, +# err_msg='Wrong Gamma random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) +# vals = [g() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = gammaK*gammaTheta +# v = gammaK*gammaTheta**2 +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from GammaDeviate') +# np.testing.assert_almost_equal(var, v, 0, +# err_msg='Wrong variance from GammaDeviate') + +# # Check discard +# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) +# g2.discard(nvals, suppress_warnings=True) +# v1,v2 = g(),g2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. +# assert v1 != v2 +# assert not g.has_reliable_discard +# assert not g.generates_in_pairs + +# # Discard normally emits a warning for Gamma +# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) +# with assert_warns(galsim.GalSimWarning): +# g2.discard(nvals) + +# # Check seed, reset +# g.seed(testseed) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated after seed') + +# g.reset(testseed) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# g.reset(rng) +# testResult2 = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# g.reset(ud) +# testResult = (g(), g(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated after reset(ud)') + +# # Check that two connected gamma deviates work correctly together. +# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) +# g.reset(g2) +# testResult2 = (g(), g2(), g()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated using two gds') +# g.seed(testseed) +# testResult2 = (g2(), g(), g2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong gamma random number sequence generated using two gds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. +# g.seed() +# testResult2 = (g(), g(), g()) +# assert testResult2 != testResult +# g.reset() +# testResult3 = (g(), g(), g()) +# assert testResult3 != testResult +# assert testResult3 != testResult2 +# g.reset() +# testResult4 = (g(), g(), g()) +# assert testResult4 != testResult +# assert testResult4 != testResult2 +# assert testResult4 != testResult3 +# g = galsim.GammaDeviate(k=gammaK, theta=gammaTheta) +# testResult5 = (g(), g(), g()) +# assert testResult5 != testResult +# assert testResult5 != testResult2 +# assert testResult5 != testResult3 +# assert testResult5 != testResult4 + +# # Test generate +# g.seed(testseed) +# test_array = np.empty(3) +# g.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gammaResult), precision, +# err_msg='Wrong gamma random number sequence from generate.') + +# # Test generate with a float32 array +# g.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# g.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(gammaResult), precisionF, +# err_msg='Wrong gamma random number sequence from generate.') + +# # Check picklability +# do_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) +# do_pickle(g, lambda x: (x(), x(), x(), x())) +# do_pickle(g) +# assert 'GammaDeviate' in repr(g) +# assert 'GammaDeviate' in str(g) +# assert isinstance(eval(repr(g)), galsim.GammaDeviate) +# assert isinstance(eval(str(g)), galsim.GammaDeviate) + +# # Check that we can construct a GammaDeviate from None, and that it depends on dev/random. +# g1 = galsim.GammaDeviate(None) +# g2 = galsim.GammaDeviate(None) +# assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" +# # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.GammaDeviate, dict()) +# assert_raises(TypeError, galsim.GammaDeviate, list()) +# assert_raises(TypeError, galsim.GammaDeviate, set()) + + +# @timer +# def test_chi2(): +# """Test Chi^2 random number generator +# """ +# c = galsim.Chi2Deviate(testseed, n=chi2N) +# c2 = c.duplicate() +# c3 = galsim.Chi2Deviate(c.serialize(), n=chi2N) +# testResult = (c(), c(), c()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(chi2Result), precision, +# err_msg='Wrong Chi^2 random number sequence generated') +# testResult = (c2(), c2(), c2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(chi2Result), precision, +# err_msg='Wrong Chi^2 random number sequence generated with duplicate') +# testResult = (c3(), c3(), c3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(chi2Result), precision, +# err_msg='Wrong Chi^2 random number sequence generated from serialize') + +# # Check that the mean and variance come out right +# c = galsim.Chi2Deviate(testseed, n=chi2N) +# vals = [c() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = chi2N +# v = 2.*chi2N +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from Chi2Deviate') +# np.testing.assert_almost_equal(var, v, 0, +# err_msg='Wrong variance from Chi2Deviate') + +# # Check discard +# c2 = galsim.Chi2Deviate(testseed, n=chi2N) +# c2.discard(nvals, suppress_warnings=True) +# v1,v2 = c(),c2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# # Chi2 uses at least 2 rngs per value, but can use arbitrarily more than this. +# assert v1 != v2 +# assert not c.has_reliable_discard +# assert not c.generates_in_pairs + +# # Discard normally emits a warning for Chi2 +# c2 = galsim.Chi2Deviate(testseed, n=chi2N) +# with assert_warns(galsim.GalSimWarning): +# c2.discard(nvals) + +# # Check seed, reset +# c.seed(testseed) +# testResult2 = (c(), c(), c()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated after seed') + +# c.reset(testseed) +# testResult2 = (c(), c(), c()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# c.reset(rng) +# testResult2 = (c(), c(), c()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# c.reset(ud) +# testResult = (c(), c(), c()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated after reset(ud)') + +# # Check that two connected Chi^2 deviates work correctly together. +# c2 = galsim.Chi2Deviate(testseed, n=chi2N) +# c.reset(c2) +# testResult2 = (c(), c2(), c()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated using two cds') +# c.seed(testseed) +# testResult2 = (c2(), c(), c2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong Chi^2 random number sequence generated using two cds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. +# c.seed() +# testResult2 = (c(), c(), c()) +# assert testResult2 != testResult +# c.reset() +# testResult3 = (c(), c(), c()) +# assert testResult3 != testResult +# assert testResult3 != testResult2 +# c.reset() +# testResult4 = (c(), c(), c()) +# assert testResult4 != testResult +# assert testResult4 != testResult2 +# assert testResult4 != testResult3 +# c = galsim.Chi2Deviate(n=chi2N) +# testResult5 = (c(), c(), c()) +# assert testResult5 != testResult +# assert testResult5 != testResult2 +# assert testResult5 != testResult3 +# assert testResult5 != testResult4 + +# # Test generate +# c.seed(testseed) +# test_array = np.empty(3) +# c.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(chi2Result), precision, +# err_msg='Wrong Chi^2 random number sequence from generate.') + +# # Test generate with a float32 array +# c.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# c.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(chi2Result), precisionF, +# err_msg='Wrong Chi^2 random number sequence from generate.') + +# # Check picklability +# do_pickle(c, lambda x: (x.serialize(), x.n)) +# do_pickle(c, lambda x: (x(), x(), x(), x())) +# do_pickle(c) +# assert 'Chi2Deviate' in repr(c) +# assert 'Chi2Deviate' in str(c) +# assert isinstance(eval(repr(c)), galsim.Chi2Deviate) +# assert isinstance(eval(str(c)), galsim.Chi2Deviate) + +# # Check that we can construct a Chi2Deviate from None, and that it depends on dev/random. +# c1 = galsim.Chi2Deviate(None) +# c2 = galsim.Chi2Deviate(None) +# assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" +# # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.Chi2Deviate, dict()) +# assert_raises(TypeError, galsim.Chi2Deviate, list()) +# assert_raises(TypeError, galsim.Chi2Deviate, set()) + + +# @timer +# def test_distfunction(): +# """Test distribution-defined random number generator with a function +# """ +# # Make sure it requires an input function in order to work. +# assert_raises(TypeError, galsim.DistDeviate) +# # Make sure it does appropriate input sanity checks. +# assert_raises(TypeError, galsim.DistDeviate, +# function='../examples/data/cosmo-fid.zmed1.00_smoothed.out', +# x_min=1.) +# assert_raises(TypeError, galsim.DistDeviate, function=1.0) +# assert_raises(ValueError, galsim.DistDeviate, function='foo.dat') +# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x, interpolant='linear') +# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x) +# assert_raises(TypeError, galsim.DistDeviate, function = lambda x : x*x, x_min=1.) +# test_vals = range(10) +# assert_raises(TypeError, galsim.DistDeviate, +# function=galsim.LookupTable(test_vals, test_vals), +# x_min = 1.) +# foo = galsim.DistDeviate(10, galsim.LookupTable(test_vals, test_vals)) +# assert_raises(ValueError, foo.val, -1.) +# assert_raises(ValueError, galsim.DistDeviate, function = lambda x : -1, x_min=dmin, x_max=dmax) +# assert_raises(ValueError, galsim.DistDeviate, function = lambda x : x**2-1, x_min=dmin, x_max=dmax) + +# d = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) +# d2 = d.duplicate() +# d3 = galsim.DistDeviate(d.serialize(), function=dfunction, x_min=dmin, x_max=dmax) +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated') +# testResult = (d2(), d2(), d2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated with duplicate') +# testResult = (d3(), d3(), d3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated from serialize') + +# # Check val() method +# # pdf(x) = x^2 +# # cdf(x) = (x/2)^3 +# # val(y) = 2 y^(1/3) +# np.testing.assert_almost_equal(d.val(0), 0, 4) +# np.testing.assert_almost_equal(d.val(1), 2, 4) +# np.testing.assert_almost_equal(d.val(0.125), 1, 4) +# np.testing.assert_almost_equal(d.val(0.027), 0.6, 4) +# np.testing.assert_almost_equal(d.val(0.512), 1.6, 4) +# u = galsim.UniformDeviate(testseed) +# testResult = (d.val(u()), d.val(u()), d.val(u())) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate sequence using d.val(u())') + +# # Check that the mean and variance come out right +# d = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) +# vals = [d() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = 3./2. +# v = 3./20. +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from DistDeviate random numbers using function') +# np.testing.assert_almost_equal(var, v, 1, +# err_msg='Wrong variance from DistDeviate random numbers using function') + +# # Check discard +# d2 = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) +# d2.discard(nvals) +# v1,v2 = d(),d2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# assert v1 == v2 +# assert d.has_reliable_discard +# assert not d.generates_in_pairs + +# # Check seed, reset +# d.seed(testseed) +# testResult2 = (d(), d(), d()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated after seed') + +# d.reset(testseed) +# testResult2 = (d(), d(), d()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated after reset(seed)') + +# rng = galsim.BaseDeviate(testseed) +# d.reset(rng) +# testResult2 = (d(), d(), d()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated after reset(rng)') + +# ud = galsim.UniformDeviate(testseed) +# d.reset(ud) +# testResult = (d(), d(), d()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated after reset(ud)') + +# # Check that two connected DistDeviate deviates work correctly together. +# d2 = galsim.DistDeviate(testseed, function=dfunction, x_min=dmin, x_max=dmax) +# d.reset(d2) +# testResult2 = (d(), d2(), d()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated using two dds') +# d.seed(testseed) +# testResult2 = (d2(), d(), d2()) +# np.testing.assert_array_equal( +# np.array(testResult), np.array(testResult2), +# err_msg='Wrong DistDeviate random number sequence generated using two dds after seed') + +# # Check that seeding with the time works (although we cannot check the output). +# # We're mostly just checking that this doesn't raise an exception. +# # The output could be anything. +# d.seed() +# testResult2 = (d(), d(), d()) +# assert testResult2 != testResult +# d.reset() +# testResult3 = (d(), d(), d()) +# assert testResult3 != testResult +# assert testResult3 != testResult2 +# d.reset() +# testResult4 = (d(), d(), d()) +# assert testResult4 != testResult +# assert testResult4 != testResult2 +# assert testResult4 != testResult3 +# d = galsim.DistDeviate(function=dfunction, x_min=dmin, x_max=dmax) +# testResult5 = (d(), d(), d()) +# assert testResult5 != testResult +# assert testResult5 != testResult2 +# assert testResult5 != testResult3 +# assert testResult5 != testResult4 + +# # Check with lambda function +# d = galsim.DistDeviate(testseed, function=lambda x: x*x, x_min=dmin, x_max=dmax) +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated with lambda function') + +# # Check auto-generated lambda function +# d = galsim.DistDeviate(testseed, function='x*x', x_min=dmin, x_max=dmax) +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated with auto-lambda function') + +# # Test generate +# d.seed(testseed) +# test_array = np.empty(3) +# d.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence from generate.') + +# # Test add_generate +# d.seed(testseed) +# d.add_generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, 2*np.array(dFunctionResult), precision, +# err_msg='Wrong DistDeviate random number sequence from add_generate.') + +# # Test generate with a float32 array +# d.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# d.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(dFunctionResult), precisionF, +# err_msg='Wrong DistDeviate random number sequence from generate.') + +# # Test add_generate with a float32 array +# d.seed(testseed) +# d.add_generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, 2*np.array(dFunctionResult), precisionF, +# err_msg='Wrong DistDeviate random number sequence from add_generate.') + +# # Check that generated values are independent of number of threads. +# d1 = galsim.DistDeviate(testseed, function=lambda x: np.exp(-x**3), x_min=0, x_max=2) +# d2 = galsim.DistDeviate(testseed, function=lambda x: np.exp(-x**3), x_min=0, x_max=2) +# v1 = np.empty(555) +# v2 = np.empty(555) +# with single_threaded(): +# d1.generate(v1) +# with single_threaded(num_threads=10): +# d2.generate(v2) +# np.testing.assert_array_equal(v1, v2) +# with single_threaded(): +# d1.add_generate(v1) +# with single_threaded(num_threads=10): +# d2.add_generate(v2) +# np.testing.assert_array_equal(v1, v2) + +# # Check picklability +# do_pickle(d, lambda x: (x(), x(), x(), x())) +# do_pickle(d) +# assert 'DistDeviate' in repr(d) +# assert 'DistDeviate' in str(d) +# assert isinstance(eval(repr(d)), galsim.DistDeviate) +# assert isinstance(eval(str(d)), galsim.DistDeviate) + +# # Check that we can construct a DistDeviate from None, and that it depends on dev/random. +# c1 = galsim.DistDeviate(None, lambda x:1, 0, 1) +# c2 = galsim.DistDeviate(None, lambda x:1, 0, 1) +# assert c1 != c2, "Consecutive DistDeviate(None) compared equal!" +# # We shouldn't be able to construct a DistDeviate from anything but a BaseDeviate, int, str, +# # or None. +# assert_raises(TypeError, galsim.DistDeviate, dict(), lambda x:1, 0, 1) +# assert_raises(TypeError, galsim.DistDeviate, list(), lambda x:1, 0, 1) +# assert_raises(TypeError, galsim.DistDeviate, set(), lambda x:1, 0, 1) + + +# @timer +# def test_distLookupTable(): +# """Test distribution-defined random number generator with a LookupTable +# """ +# precision = 9 +# # Note: 256 used to be the default, so this is a regression test +# # We check below that it works with the default npoints=None +# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) +# d2 = d.duplicate() +# d3 = galsim.DistDeviate(d.serialize(), function=dLookupTable, npoints=256) +# np.testing.assert_equal( +# d.x_min, dLookupTable.x_min, +# err_msg='DistDeviate and the LookupTable passed to it have different lower bounds') +# np.testing.assert_equal( +# d.x_max, dLookupTable.x_max, +# err_msg='DistDeviate and the LookupTable passed to it have different upper bounds') + +# testResult = (d(), d(), d()) +# print('testResult = ',testResult) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence using LookupTable') +# testResult = (d2(), d2(), d2()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence using LookupTable with duplicate') +# testResult = (d3(), d3(), d3()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence using LookupTable from serialize') + +# # Check that the mean and variance come out right +# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) +# vals = [d() for i in range(nvals)] +# mean = np.mean(vals) +# var = np.var(vals) +# mu = 2. +# v = 7./3. +# print('mean = ',mean,' true mean = ',mu) +# print('var = ',var,' true var = ',v) +# np.testing.assert_almost_equal(mean, mu, 1, +# err_msg='Wrong mean from DistDeviate random numbers using LookupTable') +# np.testing.assert_almost_equal(var, v, 1, +# err_msg='Wrong variance from DistDeviate random numbers using LookupTable') + +# # Check discard +# d2 = galsim.DistDeviate(testseed, function=dLookupTable, npoints=256) +# d2.discard(nvals) +# v1,v2 = d(),d2() +# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) +# assert v1 == v2 +# assert d.has_reliable_discard +# assert not d.generates_in_pairs + +# # This should give the same values with only 5 points because of the particular nature +# # of these arrays. +# d = galsim.DistDeviate(testseed, function=dLookupTable, npoints=5) +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence for LookupTable with 5 points') + +# # And it should also work if npoints is None +# d = galsim.DistDeviate(testseed, function=dLookupTable) +# testResult = (d(), d(), d()) +# assert len(dLookupTable.x) == len(d._inverse_cdf.x) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence for LookupTable with npoints=None') + +# # Also read these values from a file +# d = galsim.DistDeviate(testseed, function=dLookupTableFile, interpolant='linear') +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence for LookupTable from file') + +# d = galsim.DistDeviate(testseed, function=dLookupTableFile) +# testResult = (d(), d(), d()) +# np.testing.assert_array_almost_equal( +# np.array(testResult), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence for LookupTable with default ' +# 'interpolant') + +# # Test generate +# d.seed(testseed) +# test_array = np.empty(3) +# d.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence from generate.') + +# # Test filling an image +# d.seed(testseed) +# testimage = galsim.ImageD(np.zeros((3, 1))) +# testimage.addNoise(galsim.DeviateNoise(d)) +# np.testing.assert_array_almost_equal( +# testimage.array.flatten(), np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence generated when applied to image.') + +# # Test generate +# d.seed(testseed) +# test_array = np.empty(3) +# d.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(dLookupTableResult), precision, +# err_msg='Wrong DistDeviate random number sequence from generate.') + +# # Test generate with a float32 array +# d.seed(testseed) +# test_array = np.empty(3, dtype=np.float32) +# d.generate(test_array) +# np.testing.assert_array_almost_equal( +# test_array, np.array(dLookupTableResult), precisionF, +# err_msg='Wrong DistDeviate random number sequence from generate.') + +# # Test a case with nearly flat probabilities +# # x and p arrays with and without a small (epsilon) step +# dx_eps = np.arange(6) +# dp1_eps = np.zeros(dx_eps.shape) +# dp2_eps = np.zeros(dx_eps.shape) +# eps = np.finfo(dp1_eps[0].dtype).eps +# dp1_eps[0] = 0.5 +# dp2_eps[0] = 0.5 +# dp1_eps[-1] = 0.5 +# dp2_eps[-2] = eps +# dp2_eps[-1] = 0.5-eps +# dLookupTableEps1 = galsim.LookupTable(x=dx_eps, f=dp1_eps, interpolant='linear') +# dLookupTableEps2 = galsim.LookupTable(x=dx_eps, f=dp2_eps, interpolant='linear') +# d1 = galsim.DistDeviate(testseed, function=dLookupTableEps1, npoints=len(dx_eps)) +# d2 = galsim.DistDeviate(testseed, function=dLookupTableEps2, npoints=len(dx_eps)) +# # If these were successfully created everything is probably fine, but check they create the same +# # internal LookupTable +# np.testing.assert_array_almost_equal( +# d1._inverse_cdf.getArgs(), d2._inverse_cdf.getArgs(), precision, +# err_msg='DistDeviate with near-flat probabilities incorrectly created ' +# 'a monotonic version of the CDF') +# np.testing.assert_array_almost_equal( +# d1._inverse_cdf.getVals(), d2._inverse_cdf.getVals(), precision, +# err_msg='DistDeviate with near-flat probabilities incorrectly created ' +# 'a monotonic version of the CDF') + +# # And that they generate the same values +# ar1 = np.empty(100); d1.generate(ar1) +# ar2 = np.empty(100); d2.generate(ar2) +# np.testing.assert_array_almost_equal(ar1, ar2, precision, +# err_msg='Two DistDeviates with near-flat probabilities generated different values.') + +# # Check picklability +# do_pickle(d, lambda x: (x(), x(), x(), x())) +# do_pickle(d) +# assert 'DistDeviate' in repr(d) +# assert 'DistDeviate' in str(d) +# assert isinstance(eval(repr(d)), galsim.DistDeviate) +# assert isinstance(eval(str(d)), galsim.DistDeviate) + + +# @timer +# def test_multiprocess(): +# """Test that the same random numbers are generated in single-process and multi-process modes. +# """ +# from multiprocessing import current_process +# from multiprocessing import get_context +# ctx = get_context('fork') +# Process = ctx.Process +# Queue = ctx.Queue + +# def generate_list(seed): +# """Given a particular seed value, generate a list of random numbers. +# Should be deterministic given the input seed value. +# """ +# rng = galsim.UniformDeviate(seed) +# out = [] +# for i in range(20): +# out.append(rng()) +# return out + +# def worker(input, output): +# """input is a queue with seed values +# output is a queue storing the results of the tasks along with the process name, +# and which args the result is for. +# """ +# for args in iter(input.get, 'STOP'): +# result = generate_list(*args) +# output.put( (result, current_process().name, args) ) + +# # Use sequential numbers. +# # On inspection, can see that even the first value in each list is random with +# # respect to the other lists. i.e. "nearby" inputs do not produce nearby outputs. +# # I don't know of an actual assert to do for this, but it is clearly true. +# seeds = [ 1532424 + i for i in range(16) ] + +# nproc = 4 # Each process will do 4 lists (typically) + +# # First make lists in the single process: +# ref_lists = dict() +# for seed in seeds: +# list = generate_list(seed) +# ref_lists[seed] = list + +# # Now do this with multiprocessing +# # Put the seeds in a queue +# task_queue = Queue() +# for seed in seeds: +# task_queue.put( [seed] ) + +# # Run the tasks: +# done_queue = Queue() +# for k in range(nproc): +# Process(target=worker, args=(task_queue, done_queue)).start() + +# # Check the results in the order they finished +# for i in range(len(seeds)): +# list, proc, args = done_queue.get() +# seed = args[0] +# np.testing.assert_array_equal( +# list, ref_lists[seed], +# err_msg="Random numbers are different when using multiprocessing") + +# # Stop the processes: +# for k in range(nproc): +# task_queue.put('STOP') + + +# @timer +# def test_permute(): +# """Simple tests of the permute() function.""" +# # Make a fake list, and another list consisting of indices. +# my_list = [3.7, 4.1, 1.9, 11.1, 378.3, 100.0] +# import copy +# my_list_copy = copy.deepcopy(my_list) +# n_list = len(my_list) +# ind_list = list(range(n_list)) + +# # Permute both at the same time. +# galsim.random.permute(312, my_list, ind_list) + +# # Make sure that everything is sensible +# for ind in range(n_list): +# assert my_list_copy[ind_list[ind]] == my_list[ind] + +# # Repeat with same seed, should do same permutation. +# my_list = copy.deepcopy(my_list_copy) +# galsim.random.permute(312, my_list) +# for ind in range(n_list): +# assert my_list_copy[ind_list[ind]] == my_list[ind] + +# # permute with no lists should raise TypeError +# with assert_raises(TypeError): +# galsim.random.permute(312) + + +# @timer +# def test_ne(): +# """ Check that inequality works as expected for corner cases where the reprs of two +# unequal BaseDeviates may be the same due to truncation. +# """ +# a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') +# b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') +# assert repr(a) == repr(b) +# assert a != b + +# # Check DistDeviate separately, since it overrides __repr__ and __eq__ +# d1 = galsim.DistDeviate(seed=a, function=galsim.LookupTable([1, 2, 3], [4, 5, 6])) +# d2 = galsim.DistDeviate(seed=b, function=galsim.LookupTable([1, 2, 3], [4, 5, 6])) +# assert repr(d1) == repr(d2) +# assert d1 != d2 + +# @timer +# def test_int64(): +# # cf. #1009 +# # Check that various possible integer types work as seeds. + +# rng1 = galsim.BaseDeviate(int(123)) +# # cf. https://www.numpy.org/devdocs/user/basics.types.html +# ivalues =[np.int8(123), # Note this one requires i < 128 +# np.int16(123), +# np.int32(123), +# np.int64(123), +# np.uint8(123), +# np.uint16(123), +# np.uint32(123), +# np.uint64(123), +# np.short(123), +# np.ushort(123), +# np.intc(123), +# np.uintc(123), +# np.intp(123), +# np.uintp(123), +# np.int_(123), +# np.longlong(123), +# np.ulonglong(123), +# np.array(123).astype(np.int64)] + +# for i in ivalues: +# rng2 = galsim.BaseDeviate(i) +# assert rng2 == rng1 + +# @timer +# def test_numpy_generator(): +# rng = galsim.BaseDeviate(1234) +# gen = galsim.BaseDeviate(1234).as_numpy_generator() + +# # The regular (and somewhat cumbersome) GalSim way: +# a1 = np.empty(10, dtype=float) +# galsim.UniformDeviate(rng).generate(a1) +# a1 *= 9. +# a1 += 1. + +# # The nicer numpy syntax +# a2 = gen.uniform(1.,10., size=10) +# print('a1 = ',a1) +# print('a2 = ',a2) +# np.testing.assert_array_equal(a1, a2) + +# # Can also use the np property as a quick shorthand +# a1 = rng.np.normal(0, 10, size=20) +# a2 = gen.normal(0, 10, size=20) +# print('a1 = ',a1) +# print('a2 = ',a2) +# np.testing.assert_array_equal(a1, a2) + +# # Check that normal gives statistically the right mean/var. +# # (Numpy's normal uses the next_uint64 function, so this is a non-trivial test of that +# # code, which I originally got wrong.) +# a3 = gen.normal(17, 23, size=1_000_000) +# print('mean = ',np.mean(a3)) +# print('std = ',np.std(a3)) +# assert np.isclose(np.mean(a3), 17, rtol=1.e-3) +# assert np.isclose(np.std(a3), 23, rtol=3.e-3) + +# if __name__ == "__main__": +# testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)] +# for testfn in testfns: +# testfn() diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index 49f6d8d4..59cd392c 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -5,9 +5,10 @@ import time import warnings -import galsim import numpy as np -from galsim_test_helpers import * +from galsim_test_helpers import assert_raises, do_pickle, gsobject_compare, timer + +import jax_galsim as galsim # These positions will be used a few times below, so define them here. # One of the tests requires that the last pair are integers, so don't change that. @@ -476,7 +477,9 @@ def do_wcs_image(wcs, name, approx=False): # Use the "blank" image as our test image. It's not blank in the sense of having all # zeros. Rather, there are basically random values that we can use to test that # the shifted values are correct. And it is a conveniently small-ish, non-square image. - dir = "fits_files" + dir = os.path.join( + os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" + ) file_name = "blankimg.fits" im = galsim.fits.read(file_name, dir=dir) np.testing.assert_equal(im.origin.x, 1, "initial origin is not 1,1 as expected") @@ -910,7 +913,7 @@ def do_jac_decomp(wcs, name): M = scale * S.dot(R).dot(F) J = wcs.getMatrix() - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( M, J, 8, "Decomposition was inconsistent with jacobian for " + name ) @@ -3448,7 +3451,7 @@ def test_fittedsipwcs(): "ZTF": (0.1, 0.1), } - dir = "fits_files" + dir = os.path.join(os.path.dirname(__file__), "..", "..", "GalSim/tests/fits_files") if __name__ == "__main__": test_tags = all_tags From da19b771b013bc4942ebe543cfb9aeb2505c3e9f Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Oct 2023 16:54:23 -0400 Subject: [PATCH 02/33] STY blacken --- jax_galsim/random.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 5755f333..b61c2862 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -1,7 +1,6 @@ import secrets import galsim as _galsim - import jax import jax.numpy as jnp import jax.random as jrandom @@ -15,7 +14,6 @@ from jax_galsim.core.utils import ensure_hashable - LAX_FUNCTIONAL_RNG = ( "The JAX version of the this class is purely function and thus cannot " "share state with any other version of this class. Also no type checking is done on the inputs." @@ -35,7 +33,10 @@ class BaseDeviate: def __init__(self, seed=None): self.reset(seed=seed) - @_wraps(_galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.") + @_wraps( + _galsim.BaseDeviate.seed, + lax_description="The JAX version of this method does no type checking.", + ) def seed(self, seed=0): self._seed(seed=seed) @@ -50,7 +51,7 @@ def _seed(self, seed=0): "The JAX version of this method does no type checking. Also, the JAX version of this " "class cannot be linked to another JAX version of this class so ``reset`` is equivalent " "to ``seed``. If another ``BaseDeviate`` is supplied, that deviates current state is used." - ) + ), ) def reset(self, seed=None): if isinstance(seed, BaseDeviate): @@ -71,11 +72,15 @@ def _reset(self, rng): @property @_wraps(_galsim.BaseDeviate.np) def np(self): - raise NotImplementedError("The JAX galsim.BaseDeviate does not support being used as a numpy PRNG.") + raise NotImplementedError( + "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." + ) @_wraps(_galsim.BaseDeviate.as_numpy_generator) def as_numpy_generator(self): - raise NotImplementedError("The JAX galsim.BaseDeviate does not support being used as a numpy PRNG.") + raise NotImplementedError( + "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." + ) @_wraps(_galsim.BaseDeviate.duplicate) def duplicate(self): @@ -86,7 +91,10 @@ def duplicate(self): def __copy__(self): return self.duplicate() - @_wraps(_galsim.BaseDeviate.clearCache, lax_description="This method is a no-op for the JAX version of this class.") + @_wraps( + _galsim.BaseDeviate.clearCache, + lax_description="This method is a no-op for the JAX version of this class.", + ) def clearCache(self): pass @@ -95,7 +103,7 @@ def clearCache(self): lax_description=( "The JAX version of this class has reliable discarding and uses one key per value " "so it never generates in pairs. Thus this method will never raise an error." - ) + ), ) def discard(self, n, suppress_warnings=False): def _discard(i, key): @@ -120,7 +128,7 @@ def raw(self): lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." - ) + ), ) def generate(self, array): self._key, array = self.__class__._generate(self._key, array) @@ -131,7 +139,7 @@ def generate(self, array): lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." - ) + ), ) def add_generate(self, array): return self.generate(array) + array @@ -141,11 +149,10 @@ def __call__(self): return val def __eq__(self, other): - return ( - self is other - or ( - isinstance(other, self.__class__) - and jnp.array_equal(jrandom.key_data(self._key), jrandom.key_data(other._key)) + return self is other or ( + isinstance(other, self.__class__) + and jnp.array_equal( + jrandom.key_data(self._key), jrandom.key_data(other._key) ) ) @@ -189,7 +196,9 @@ def tree_unflatten(cls, aux_data, children): class UniformDeviate(BaseDeviate): def _generate(key, array): # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn - key, res = jax.lax.scan(UniformDeviate._generate_one, key, None, length=array.ravel().shape[0]) + key, res = jax.lax.scan( + UniformDeviate._generate_one, key, None, length=array.ravel().shape[0] + ) return key, res.reshape(array.shape) @jax.jit From 9f7be37740db21fb9abd1fd04b0fdd1ee4bc4531 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Oct 2023 16:59:22 -0400 Subject: [PATCH 03/33] bug in pre-commit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d952b2e..f91dc766 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,9 +11,9 @@ repos: - id: flake8 entry: pflake8 additional_dependencies: [pyproject-flake8] - exclude: tests/Galsim|tests/Coord/|tests/jax/galsim/ + exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/ - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort - exclude: tests/Galsim|tests/Coord/|tests/jax/galsim/ + exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/ From f11fde71b8be2e79d704aca1769b1e1885389690 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 13:26:02 -0500 Subject: [PATCH 04/33] ENH working gaussian random deviates --- jax_galsim/__init__.py | 2 +- jax_galsim/random.py | 177 +++++++------ tests/jax/galsim/test_random_jax.py | 392 ++++++++++++++-------------- 3 files changed, 297 insertions(+), 274 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index e2d91e4d..b90a11d1 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -8,7 +8,7 @@ from .errors import GalSimWarning, GalSimDeprecationWarning # noise -from .random import BaseDeviate, UniformDeviate +from .random import BaseDeviate, UniformDeviate, GaussianDeviate # Basic building blocks from .bounds import Bounds, BoundsD, BoundsI diff --git a/jax_galsim/random.py b/jax_galsim/random.py index b61c2862..1b362097 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -32,6 +32,7 @@ class BaseDeviate: def __init__(self, seed=None): self.reset(seed=seed) + self._params = {} @_wraps( _galsim.BaseDeviate.seed, @@ -50,7 +51,7 @@ def _seed(self, seed=0): lax_description=( "The JAX version of this method does no type checking. Also, the JAX version of this " "class cannot be linked to another JAX version of this class so ``reset`` is equivalent " - "to ``seed``. If another ``BaseDeviate`` is supplied, that deviates current state is used." + "to ``seed``. If another ``BaseDeviate`` is supplied, that deviate's current state is used." ), ) def reset(self, seed=None): @@ -69,6 +70,9 @@ def reset(self, seed=None): def _reset(self, rng): self._key = rng._key + def serialize(self): + return repr(ensure_hashable(jrandom.key_data(self._key))) + @property @_wraps(_galsim.BaseDeviate.np) def np(self): @@ -82,15 +86,6 @@ def as_numpy_generator(self): "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." ) - @_wraps(_galsim.BaseDeviate.duplicate) - def duplicate(self): - ret = BaseDeviate.__new__(self.__class__) - ret._key = self._key - return ret - - def __copy__(self): - return self.duplicate() - @_wraps( _galsim.BaseDeviate.clearCache, lax_description="This method is a no-op for the JAX version of this class.", @@ -106,11 +101,15 @@ def clearCache(self): ), ) def discard(self, n, suppress_warnings=False): - def _discard(i, key): + self._key = self.__class__._discard(self._key, n) + + @jax.jit + def _discard(key, n): + def __discard(i, key): key, subkey = jrandom.split(key) return key - self._key = jax.lax.fori_loop(0, n, _discard, self._key) + return jax.lax.fori_loop(0, n, __discard, key) @_wraps( _galsim.BaseDeviate.raw, @@ -148,12 +147,23 @@ def __call__(self): self._key, val = self.__class__._generate_one(self._key, None) return val + @_wraps(_galsim.BaseDeviate.duplicate) + def duplicate(self): + ret = self.__class__.__new__(self.__class__) + ret._key = self._key + ret._params = self._params.copy() + return ret + + def __copy__(self): + return self.duplicate() + def __eq__(self, other): return self is other or ( isinstance(other, self.__class__) and jnp.array_equal( jrandom.key_data(self._key), jrandom.key_data(other._key) ) + and self._params == other._params ) def __ne__(self, other): @@ -161,23 +171,11 @@ def __ne__(self, other): __hash__ = None - def serialize(self): - return repr(ensure_hashable(jrandom.key_data(self._key))) - - def __repr__(self): - return "galsim.%s(%r) " % ( - self.__class__.__name__, - ensure_hashable(jrandom.key_data(self._key)), - ) - - def __str__(self): - return self.__repr__() - def tree_flatten(self): """This function flattens the PRNG into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (jrandom.key_data(self._key),) + children = (jrandom.key_data(self._key), self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = {} return (children, aux_data) @@ -185,7 +183,15 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(children[0]) + return cls(children[0], **(children[1])) + + def __repr__(self): + return "galsim.BaseDeviate(seed=%r) " % ( + ensure_hashable(jrandom.key_data(self._key)), + ) + + def __str__(self): + return self.__repr__() @_wraps( @@ -206,75 +212,80 @@ def _generate_one(key, x): _key, subkey = jrandom.split(key) return _key, jrandom.uniform(subkey, dtype=float) + def __repr__(self): + return "galsim.UniformDeviate(seed=%r) " % ( + ensure_hashable(jrandom.key_data(self._key)), + ) -# class GaussianDeviate(BaseDeviate): -# """Pseudo-random number generator with Gaussian distribution. + def __str__(self): + return "galsim.UniformDeviate()" -# See http://en.wikipedia.org/wiki/Gaussian_distribution for further details. -# Successive calls to ``g()`` generate pseudo-random values distributed according to a Gaussian -# distribution with the provided ``mean``, ``sigma``:: +@_wraps( + _galsim.GaussianDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class GaussianDeviate(BaseDeviate): + def __init__(self, seed=None, mean=0.0, sigma=1.0): + super().__init__(seed=seed) + self._params["mean"] = mean + self._params["sigma"] = sigma -# >>> g = galsim.GaussianDeviate(31415926) -# >>> g() -# 0.5533754000847082 -# >>> g() -# 1.0218588970190354 + @property + def mean(self): + """The mean of the Gaussian distribution.""" + return self._params["mean"] -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# mean: Mean of Gaussian distribution. [default: 0.] -# sigma: Sigma of Gaussian distribution. [default: 1.; Must be > 0] -# """ -# def __init__(self, seed=None, mean=0., sigma=1.): -# if sigma < 0.: -# raise GalSimRangeError("GaussianDeviate sigma must be > 0.", sigma, 0.) -# self._rng_type = _galsim.GaussianDeviateImpl -# self._rng_args = (float(mean), float(sigma)) -# self.reset(seed) + @property + def sigma(self): + """The sigma of the Gaussian distribution.""" + return self._params["sigma"] -# @property -# def mean(self): -# """The mean of the Gaussian distribution. -# """ -# return self._rng_args[0] + @_wraps( + _galsim.BaseDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array) + return array * self.sigma + self.mean -# @property -# def sigma(self): -# """The sigma of the Gaussian distribution. -# """ -# return self._rng_args[1] + def _generate(key, array): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan( + GaussianDeviate._generate_one, key, None, length=array.ravel().shape[0] + ) + return key, res.reshape(array.shape) -# @property -# def generates_in_pairs(self): -# return True + def __call__(self): + self._key, val = self.__class__._generate_one(self._key, None) + return val * self.sigma + self.mean -# def __call__(self): -# """Draw a new random number from the distribution. + @jax.jit + def _generate_one(key, x): + _key, subkey = jrandom.split(key) + return _key, jrandom.normal(subkey, dtype=float) -# Returns a Gaussian deviate with the given mean and sigma. -# """ -# return self._rng.generate1() + @_wraps(_galsim.GaussianDeviate.generate_from_variance) + def generate_from_variance(self, array): + self._key, _array = self.__class__._generate(self._key, array) + return _array * jnp.sqrt(array) -# def generate_from_variance(self, array): -# """Generate many Gaussian deviate values using the existing array values as the -# variance for each. -# """ -# array_1d = np.ascontiguousarray(array.ravel(), dtype=float) -# #assert(array_1d.strides[0] == array_1d.itemsize) -# _a = array_1d.__array_interface__['data'][0] -# self._rng.generate_from_variance(len(array_1d), _a) -# if array_1d.data != array.data: -# # array_1d is not a view into the original array. Need to copy back. -# np.copyto(array, array_1d.reshape(array.shape), casting='unsafe') + def __repr__(self): + return "galsim.GaussianDeviate(seed=%r, mean=%r, sigma=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.mean), + ensure_hashable(self.sigma), + ) -# def __repr__(self): -# return 'galsim.GaussianDeviate(seed=%r, mean=%r, sigma=%r)'%( -# self._seed_repr(), self.mean, self.sigma) -# def __str__(self): -# return 'galsim.GaussianDeviate(mean=%r, sigma=%r)'%(self.mean, self.sigma) + def __str__(self): + return "galsim.GaussianDeviate(mean=%r, sigma=%r)" % ( + ensure_hashable(self.mean), + ensure_hashable(self.sigma), + ) # class BinomialDeviate(BaseDeviate): diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index ed057b51..bcfccda1 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -27,7 +27,7 @@ gMean = 4.7 gSigma = 3.2 # the right answer for the first three Gaussian deviates produced from testseed -gResult = (6.3344979808161215, 6.2082355273987861, -0.069894693358302007) +gResult = (-2.1568953985, 2.3232138032, 1.5308165692) # N, p to use for binomial tests bN = 10 @@ -300,208 +300,220 @@ def test_uniform(): # assert_raises(TypeError, u.seed, 12.3) -# @timer -# def test_gaussian(): -# """Test Gaussian random number generator -# """ -# g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# g2 = g.duplicate() -# g3 = galsim.GaussianDeviate(g.serialize(), mean=gMean, sigma=gSigma) -# testResult = (g(), g(), g()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gResult), precision, -# err_msg='Wrong Gaussian random number sequence generated') -# testResult = (g2(), g2(), g2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gResult), precision, -# err_msg='Wrong Gaussian random number sequence generated with duplicate') -# testResult = (g3(), g3(), g3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gResult), precision, -# err_msg='Wrong Gaussian random number sequence generated from serialize') - -# # Check that the mean and variance come out right -# g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# vals = [g() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = gMean -# v = gSigma**2 -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from GaussianDeviate') -# np.testing.assert_almost_equal(var, v, 0, -# err_msg='Wrong variance from GaussianDeviate') +@timer +def test_gaussian(): + """Test Gaussian random number generator + """ + g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + g2 = g.duplicate() + g3 = galsim.GaussianDeviate(g.serialize(), mean=gMean, sigma=gSigma) + testResult = (g(), g(), g()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gResult), precision, + err_msg='Wrong Gaussian random number sequence generated') + testResult = (g2(), g2(), g2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gResult), precision, + err_msg='Wrong Gaussian random number sequence generated with duplicate') + testResult = (g3(), g3(), g3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gResult), precision, + err_msg='Wrong Gaussian random number sequence generated from serialize') -# # Check discard -# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# g2.discard(nvals) -# v1,v2 = g(),g2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# assert v1 == v2 -# # Note: For Gaussian, this only works if nvals is even. -# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# g2.discard(nvals+1, suppress_warnings=True) -# v1,v2 = g(),g2() -# print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) -# assert v1 != v2 -# assert g.has_reliable_discard -# assert g.generates_in_pairs + # Check that the mean and variance come out right + g = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + vals = [g() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = gMean + v = gSigma**2 + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from GaussianDeviate') + np.testing.assert_almost_equal( + var, v, 0, + err_msg='Wrong variance from GaussianDeviate') -# # If don't explicitly suppress the warning, then a warning is emitted when n is odd. -# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# with assert_warns(galsim.GalSimWarning): -# g2.discard(nvals+1) + # Check discard + g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + g2.discard(nvals) + v1, v2 = g(), g2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + assert v1 == v2 + # NOTE: JAX doesn't appear to have this issue + # Note: For Gaussian, this only works if nvals is even. + # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + # g2.discard(nvals+1, suppress_warnings=True) + # v1,v2 = g(),g2() + # print('after %d vals, next one is %s, %s'%(nvals+1,v1,v2)) + # assert v1 != v2 + assert g.has_reliable_discard + # NOTE changed to NOT here for JAX + assert not g.generates_in_pairs + + # NOTE: JAX doesn't warn for this + # If don't explicitly suppress the warning, then a warning is emitted when n is odd. + # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + # with assert_warns(galsim.GalSimWarning): + # g2.discard(nvals+1) -# # Check seed, reset -# g.seed(testseed) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated after seed') + # Check seed, reset + g.seed(testseed) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated after seed') -# g.reset(testseed) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated after reset(seed)') + g.reset(testseed) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated after reset(seed)') -# rng = galsim.BaseDeviate(testseed) -# g.reset(rng) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated after reset(rng)') + rng = galsim.BaseDeviate(testseed) + g.reset(rng) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated after reset(rng)') -# ud = galsim.UniformDeviate(testseed) -# g.reset(ud) -# testResult = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated after reset(ud)') + ud = galsim.UniformDeviate(testseed) + g.reset(ud) + testResult = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Gaussian random number sequence generated after reset(ud)') + + # NOTE jax doesn't allow connected RNGs + # # Check that two connected Gaussian deviates work correctly together. + # g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) + # g.reset(g2) + # # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. + # testResult2 = (g(), g(), g2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong Gaussian random number sequence generated using two gds') + # g.seed(testseed) + # # For the same reason, after seeding one, we need to manually clear the other's cache: + # g2.clearCache() + # testResult2 = (g2(), g2(), g()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong Gaussian random number sequence generated using two gds after seed') -# # Check that two connected Gaussian deviates work correctly together. -# g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) -# g.reset(g2) -# # Note: GaussianDeviate generates two values at a time, so we have to compare them in pairs. -# testResult2 = (g(), g(), g2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated using two gds') -# g.seed(testseed) -# # For the same reason, after seeding one, we need to manually clear the other's cache: -# g2.clearCache() -# testResult2 = (g2(), g2(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Gaussian random number sequence generated using two gds after seed') + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. + g.seed() + testResult2 = (g(), g(), g()) + assert testResult2 != testResult + g.reset() + testResult3 = (g(), g(), g()) + assert testResult3 != testResult + assert testResult3 != testResult2 + g.reset() + testResult4 = (g(), g(), g()) + assert testResult4 != testResult + assert testResult4 != testResult2 + assert testResult4 != testResult3 + g = galsim.GaussianDeviate(mean=gMean, sigma=gSigma) + testResult5 = (g(), g(), g()) + assert testResult5 != testResult + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. -# g.seed() -# testResult2 = (g(), g(), g()) -# assert testResult2 != testResult -# g.reset() -# testResult3 = (g(), g(), g()) -# assert testResult3 != testResult -# assert testResult3 != testResult2 -# g.reset() -# testResult4 = (g(), g(), g()) -# assert testResult4 != testResult -# assert testResult4 != testResult2 -# assert testResult4 != testResult3 -# g = galsim.GaussianDeviate(mean=gMean, sigma=gSigma) -# testResult5 = (g(), g(), g()) -# assert testResult5 != testResult -# assert testResult5 != testResult2 -# assert testResult5 != testResult3 -# assert testResult5 != testResult4 + # Test generate + g.seed(testseed) + test_array = np.empty(3) + test_array.fill(np.nan) + test_array = g.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gResult), precision, + err_msg='Wrong Gaussian random number sequence from generate.') -# # Test generate -# g.seed(testseed) -# test_array = np.empty(3) -# g.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gResult), precision, -# err_msg='Wrong Gaussian random number sequence from generate.') - -# # Test generate_from_variance. -# g2 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) -# g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) -# test_array.fill(gSigma**2) -# g2.generate_from_variance(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gResult)-gMean, precision, -# err_msg='Wrong Gaussian random number sequence from generate_from_variance.') -# # After running generate_from_variance, it should be back to using the specified mean, sigma. -# # Note: need to round up to even number for discard, since gd generates 2 at a time. -# g3.discard((len(test_array)+1)//2 * 2) -# print('g2,g3 = ',g2(),g3()) -# assert g2() == g3() - -# # Test generate with a float32 array. -# g.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# g.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gResult), precisionF, -# err_msg='Wrong Gaussian random number sequence from generate.') + # Test generate_from_variance. + g2 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) + g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) + test_array = np.empty(3) + test_array.fill(gSigma**2) + test_array = g2.generate_from_variance(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gResult) - gMean, precision, + err_msg='Wrong Gaussian random number sequence from generate_from_variance.') + # NOTE JAX can use the array shape here + # After running generate_from_variance, it should be back to using the specified mean, sigma. + # Note: need to round up to even number for discard, since gd generates 2 at a time. + g3.discard(len(test_array)) + print('g2,g3 = ', g2(), g3()) + assert g2() == g3() + + # Test generate with a float32 array. + g.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array.fill(np.nan) + test_array = g.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gResult), precisionF, + err_msg='Wrong Gaussian random number sequence from generate.') -# # Test generate_from_variance. -# g2.seed(testseed) -# test_array.fill(gSigma**2) -# g2.generate_from_variance(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gResult)-gMean, precisionF, -# err_msg='Wrong Gaussian random number sequence from generate_from_variance.') + # Test generate_from_variance. + g2.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array.fill(gSigma**2) + test_array = g2.generate_from_variance(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gResult) - gMean, precisionF, + err_msg='Wrong Gaussian random number sequence from generate_from_variance.') -# # Check that generated values are independent of number of threads. -# g1 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) -# g2 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) -# v1 = np.empty(555) -# v2 = np.empty(555) -# with single_threaded(): -# g1.generate(v1) -# with single_threaded(num_threads=10): -# g2.generate(v2) -# np.testing.assert_array_equal(v1, v2) -# with single_threaded(): -# g1.add_generate(v1) -# with single_threaded(num_threads=10): -# g2.add_generate(v2) -# np.testing.assert_array_equal(v1, v2) -# ud = galsim.UniformDeviate(testseed + 3) -# ud.generate(v1) -# v1 += 6.7 -# v2[:] = v1 -# with single_threaded(): -# g1.generate_from_variance(v1) -# with single_threaded(num_threads=10): -# g2.generate_from_variance(v2) -# np.testing.assert_array_equal(v1, v2) + # Check that generated values are independent of number of threads. + g1 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) + g2 = galsim.GaussianDeviate(testseed, mean=53, sigma=1.3) + v1 = np.empty(555) + v2 = np.empty(555) + with single_threaded(): + v1 = g1.generate(v1) + with single_threaded(num_threads=10): + v2 = g2.generate(v2) + np.testing.assert_array_equal(v1, v2) + with single_threaded(): + v1 = g1.add_generate(v1) + with single_threaded(num_threads=10): + v2 = g2.add_generate(v2) + np.testing.assert_array_equal(v1, v2) + ud = galsim.UniformDeviate(testseed + 3) + ud.generate(v1) + v1 += 6.7 + v2 = v1.copy() + with single_threaded(): + v1 = g1.generate_from_variance(v1) + with single_threaded(num_threads=10): + v2 = g2.generate_from_variance(v2) + np.testing.assert_array_equal(v1, v2) -# # Check picklability -# do_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) -# do_pickle(g, lambda x: (x(), x(), x(), x())) -# do_pickle(g) -# assert 'GaussianDeviate' in repr(g) -# assert 'GaussianDeviate' in str(g) -# assert isinstance(eval(repr(g)), galsim.GaussianDeviate) -# assert isinstance(eval(str(g)), galsim.GaussianDeviate) - -# # Check that we can construct a GaussianDeviate from None, and that it depends on dev/random. -# g1 = galsim.GaussianDeviate(None) -# g2 = galsim.GaussianDeviate(None) -# assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" -# # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.GaussianDeviate, dict()) -# assert_raises(TypeError, galsim.GaussianDeviate, list()) -# assert_raises(TypeError, galsim.GaussianDeviate, set()) + # Check picklability + do_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) + do_pickle(g, lambda x: (x(), x(), x(), x())) + do_pickle(g) + assert 'GaussianDeviate' in repr(g) + assert 'GaussianDeviate' in str(g) + assert isinstance(eval(repr(g)), galsim.GaussianDeviate) + assert isinstance(eval(str(g)), galsim.GaussianDeviate) + + # Check that we can construct a GaussianDeviate from None, and that it depends on dev/random. + g1 = galsim.GaussianDeviate(None) + g2 = galsim.GaussianDeviate(None) + assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" -# assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) + # NOTE: We do not test for these since we do no type checking in JAX + # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, + # or None. + # assert_raises(TypeError, galsim.GaussianDeviate, dict()) + # assert_raises(TypeError, galsim.GaussianDeviate, list()) + # assert_raises(TypeError, galsim.GaussianDeviate, set()) + # assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) # @timer From 1233f0d9934dd5972aaa12bb65025f9136d6a5bc Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:15:23 -0500 Subject: [PATCH 05/33] ENH add poisson --- jax_galsim/__init__.py | 2 +- jax_galsim/random.py | 118 ++++----- tests/jax/galsim/test_random_jax.py | 362 ++++++++++++++-------------- 3 files changed, 251 insertions(+), 231 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index b90a11d1..e77d507e 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -8,7 +8,7 @@ from .errors import GalSimWarning, GalSimDeprecationWarning # noise -from .random import BaseDeviate, UniformDeviate, GaussianDeviate +from .random import BaseDeviate, UniformDeviate, GaussianDeviate, PoissonDeviate # Basic building blocks from .bounds import Bounds, BoundsD, BoundsI diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 1b362097..72855a19 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -200,6 +200,7 @@ def __str__(self): ) @register_pytree_node_class class UniformDeviate(BaseDeviate): + @jax.jit def _generate(key, array): # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn key, res = jax.lax.scan( @@ -253,6 +254,7 @@ def generate(self, array): self._key, array = self.__class__._generate(self._key, array) return array * self.sigma + self.mean + @jax.jit def _generate(key, array): # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn key, res = jax.lax.scan( @@ -341,70 +343,76 @@ def __str__(self): # return 'galsim.BinomialDeviate(N=%r, p=%r)'%(self.n, self.p) -# class PoissonDeviate(BaseDeviate): -# """Pseudo-random Poisson deviate with specified ``mean``. - -# The input ``mean`` sets the mean and variance of the Poisson deviate. An integer deviate with -# this distribution is returned after each call. -# See http://en.wikipedia.org/wiki/Poisson_distribution for more details. +@_wraps( + _galsim.PoissonDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class PoissonDeviate(BaseDeviate): + def __init__(self, seed=None, mean=1.0): + super().__init__(seed=seed) + self._params["mean"] = mean -# Successive calls to ``p()`` generate pseudo-random integer values distributed according to a -# Poisson distribution with the specified ``mean``:: + @property + def mean(self): + """The mean of the Gaussian distribution.""" + return self._params["mean"] -# >>> p = galsim.PoissonDeviate(31415926, mean=100) -# >>> p() -# 94 -# >>> p() -# 106 + @_wraps( + _galsim.BaseDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array, self.mean) + return array -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# mean: Mean of the distribution. [default: 1; Must be > 0] -# """ -# def __init__(self, seed=None, mean=1.): -# if mean < 0: -# raise GalSimValueError("PoissonDeviate is only defined for mean >= 0.", mean) -# self._rng_type = _galsim.PoissonDeviateImpl -# self._rng_args = (float(mean),) -# self.reset(seed) + @jax.jit + def _generate(key, array, mean): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan( + PoissonDeviate._generate_one, + key, + jnp.broadcast_to(mean, array.ravel().shape), + length=array.ravel().shape[0], + ) + return key, res.reshape(array.shape) -# @property -# def mean(self): -# """The mean of the distribution. -# """ -# return self._rng_args[0] + def __call__(self): + self._key, val = self.__class__._generate_one(self._key, self.mean) + return val -# @property -# def has_reliable_discard(self): -# return False + @jax.jit + def _generate_one(key, mean): + _key, subkey = jrandom.split(key) + return _key, jrandom.poisson(subkey, mean, dtype=int) -# def __call__(self): -# """Draw a new random number from the distribution. + @_wraps(_galsim.PoissonDeviate.generate_from_expectation) + def generate_from_expectation(self, array): + self._key, _array = self.__class__._generate_from_exp(self._key, array) + return _array -# Returns a Poisson deviate with the given mean. -# """ -# return self._rng.generate1() + @jax.jit + def _generate_from_exp(key, array): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan( + PoissonDeviate._generate_one, + key, + array.ravel(), + length=array.ravel().shape[0], + ) + return key, res.reshape(array.shape) -# def generate_from_expectation(self, array): -# """Generate many Poisson deviate values using the existing array values as the -# expectation value (aka mean) for each. -# """ -# if np.any(array < 0): -# raise GalSimValueError("Expectation array may not have values < 0.", array) -# array_1d = np.ascontiguousarray(array.ravel(), dtype=float) -# #assert(array_1d.strides[0] == array_1d.itemsize) -# _a = array_1d.__array_interface__['data'][0] -# self._rng.generate_from_expectation(len(array_1d), _a) -# if array_1d.data != array.data: -# # array_1d is not a view into the original array. Need to copy back. -# np.copyto(array, array_1d.reshape(array.shape), casting='unsafe') + def __repr__(self): + return "galsim.PoissonDeviate(seed=%r, mean=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.mean), + ) -# def __repr__(self): -# return 'galsim.PoissonDeviate(seed=%r, mean=%r)'%(self._seed_repr(), self.mean) -# def __str__(self): -# return 'galsim.PoissonDeviate(mean=%r)'%(self.mean) + def __str__(self): + return "galsim.PoissonDeviate(mean=%r)" % (ensure_hashable(self.mean),) # class WeibullDeviate(BaseDeviate): diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index bcfccda1..ae98b496 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -38,7 +38,7 @@ # mean to use for Poisson tests pMean = 7 # the right answer for the first three Poisson deviates produced from testseed -pResult = (4, 5, 6) +pResult = (6, 11, 4) # a & b to use for Weibull tests wA = 4.0 @@ -675,195 +675,207 @@ def test_gaussian(): # assert_raises(TypeError, galsim.BinomialDeviate, set()) -# @timer -# def test_poisson(): -# """Test Poisson random number generator -# """ -# p = galsim.PoissonDeviate(testseed, mean=pMean) -# p2 = p.duplicate() -# p3 = galsim.PoissonDeviate(p.serialize(), mean=pMean) -# testResult = (p(), p(), p()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(pResult), precision, -# err_msg='Wrong Poisson random number sequence generated') -# testResult = (p2(), p2(), p2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(pResult), precision, -# err_msg='Wrong Poisson random number sequence generated with duplicate') -# testResult = (p3(), p3(), p3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(pResult), precision, -# err_msg='Wrong Poisson random number sequence generated from serialize') +@timer +def test_poisson(): + """Test Poisson random number generator + """ + p = galsim.PoissonDeviate(testseed, mean=pMean) + p2 = p.duplicate() + p3 = galsim.PoissonDeviate(p.serialize(), mean=pMean) + testResult = (p(), p(), p()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(pResult), precision, + err_msg='Wrong Poisson random number sequence generated') + testResult = (p2(), p2(), p2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(pResult), precision, + err_msg='Wrong Poisson random number sequence generated with duplicate') + testResult = (p3(), p3(), p3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(pResult), precision, + err_msg='Wrong Poisson random number sequence generated from serialize') -# # Check that the mean and variance come out right -# p = galsim.PoissonDeviate(testseed, mean=pMean) -# vals = [p() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = pMean -# v = pMean -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from PoissonDeviate') -# np.testing.assert_almost_equal(var, v, 1, -# err_msg='Wrong variance from PoissonDeviate') + # Check that the mean and variance come out right + p = galsim.PoissonDeviate(testseed, mean=pMean) + vals = [p() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = pMean + v = pMean + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from PoissonDeviate') + np.testing.assert_almost_equal( + var, v, 1, + err_msg='Wrong variance from PoissonDeviate') -# # Check discard -# p2 = galsim.PoissonDeviate(testseed, mean=pMean) -# p2.discard(nvals, suppress_warnings=True) -# v1,v2 = p(),p2() -# print('With mean = %d, after %d vals, next one is %s, %s'%(pMean,nvals,v1,v2)) -# assert v1 == v2 + # Check discard + p2 = galsim.PoissonDeviate(testseed, mean=pMean) + p2.discard(nvals, suppress_warnings=True) + v1, v2 = p(), p2() + print('With mean = %d, after %d vals, next one is %s, %s' % (pMean, nvals, v1, v2)) + assert v1 == v2 -# # With a very small mean value, Poisson reliably only uses 1 rng per value. -# # But at only slightly larger means, it sometimes uses two rngs for a single value. -# # Basically anything >= 10 causes this next test to have v1 != v2 -# high_mean = 10 -# p = galsim.PoissonDeviate(testseed, mean=high_mean) -# p2 = galsim.PoissonDeviate(testseed, mean=high_mean) -# vals = [p() for i in range(nvals)] -# p2.discard(nvals, suppress_warnings=True) -# v1,v2 = p(),p2() -# print('With mean = %d, after %d vals, next one is %s, %s'%(high_mean,nvals,v1,v2)) -# assert v1 != v2 -# assert not p.has_reliable_discard -# assert not p.generates_in_pairs + # NOTE: the JAX RNGs have reliabel discard + # With a very small mean value, Poisson reliably only uses 1 rng per value. + # But at only slightly larger means, it sometimes uses two rngs for a single value. + # Basically anything >= 10 causes this next test to have v1 != v2 + high_mean = 10 + p = galsim.PoissonDeviate(testseed, mean=high_mean) + p2 = galsim.PoissonDeviate(testseed, mean=high_mean) + vals = [p() for i in range(nvals)] + p2.discard(nvals, suppress_warnings=True) + v1, v2 = p(), p2() + print('With mean = %d, after %d vals, next one is %s, %s' % (high_mean, nvals, v1, v2)) + assert v1 == v2 + assert p.has_reliable_discard + assert not p.generates_in_pairs -# # Discard normally emits a warning for Poisson -# p2 = galsim.PoissonDeviate(testseed, mean=pMean) -# with assert_warns(galsim.GalSimWarning): -# p2.discard(nvals) + # NOTE jax-galsim doesn't do this + # Discard normally emits a warning for Poisson + # p2 = galsim.PoissonDeviate(testseed, mean=pMean) + # with assert_warns(galsim.GalSimWarning): + # p2.discard(nvals) -# # Check seed, reset -# p = galsim.PoissonDeviate(testseed, mean=pMean) -# p.seed(testseed) -# testResult2 = (p(), p(), p()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated after seed') + # Check seed, reset + p = galsim.PoissonDeviate(testseed, mean=pMean) + p.seed(testseed) + testResult2 = (p(), p(), p()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated after seed') -# p.reset(testseed) -# testResult2 = (p(), p(), p()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated after reset(seed)') + p.reset(testseed) + testResult2 = (p(), p(), p()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated after reset(seed)') -# rng = galsim.BaseDeviate(testseed) -# p.reset(rng) -# testResult2 = (p(), p(), p()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated after reset(rng)') + rng = galsim.BaseDeviate(testseed) + p.reset(rng) + testResult2 = (p(), p(), p()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated after reset(rng)') -# ud = galsim.UniformDeviate(testseed) -# p.reset(ud) -# testResult = (p(), p(), p()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated after reset(ud)') + ud = galsim.UniformDeviate(testseed) + p.reset(ud) + testResult = (p(), p(), p()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong poisson random number sequence generated after reset(ud)') -# # Check that two connected poisson deviates work correctly together. -# p2 = galsim.PoissonDeviate(testseed, mean=pMean) -# p.reset(p2) -# testResult2 = (p(), p2(), p()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated using two pds') -# p.seed(testseed) -# testResult2 = (p2(), p(), p2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong poisson random number sequence generated using two pds after seed') + # NOTE: jax-galsim doesn't have connected RNGs + # # Check that two connected poisson deviates work correctly together. + # p2 = galsim.PoissonDeviate(testseed, mean=pMean) + # p.reset(p2) + # testResult2 = (p(), p2(), p()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong poisson random number sequence generated using two pds') + # p.seed(testseed) + # testResult2 = (p2(), p(), p2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong poisson random number sequence generated using two pds after seed') -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. However, in this case, there are few enough options -# # for the output that occasionally two of these match. So we don't do the normal -# # testResult2 != testResult, etc. -# p.seed() -# testResult2 = (p(), p(), p()) -# #assert testResult2 != testResult -# p.reset() -# testResult3 = (p(), p(), p()) -# #assert testResult3 != testResult -# #assert testResult3 != testResult2 -# p.reset() -# testResult4 = (p(), p(), p()) -# #assert testResult4 != testResult -# #assert testResult4 != testResult2 -# #assert testResult4 != testResult3 -# p = galsim.PoissonDeviate(mean=pMean) -# testResult5 = (p(), p(), p()) -# #assert testResult5 != testResult -# #assert testResult5 != testResult2 -# #assert testResult5 != testResult3 -# #assert testResult5 != testResult4 + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. However, in this case, there are few enough options + # for the output that occasionally two of these match. So we don't do the normal + # testResult2 != testResult, etc. + p.seed() + testResult2 = (p(), p(), p()) + p.reset() + testResult3 = (p(), p(), p()) + p.reset() + testResult4 = (p(), p(), p()) + p = galsim.PoissonDeviate(mean=pMean) + testResult5 = (p(), p(), p()) + assert ( + (testResult2 != testResult) + or (testResult3 != testResult) + or (testResult4 != testResult) + or (testResult5 != testResult) + ) + try: + assert testResult3 != testResult2 + assert testResult4 != testResult2 + assert testResult4 != testResult3 + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 + except AssertionError: + print("one of the poisson results was equal but this can happen occasionally") -# # Test generate -# p.seed(testseed) -# test_array = np.empty(3) -# p.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(pResult), precision, -# err_msg='Wrong poisson random number sequence from generate.') + # Test generate + p.seed(testseed) + test_array = np.empty(3) + test_array.fill(np.nan) + test_array = p.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(pResult), precision, + err_msg='Wrong poisson random number sequence from generate.') -# # Test generate with an int array -# p.seed(testseed) -# test_array = np.empty(3, dtype=int) -# p.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(pResult), precisionI, -# err_msg='Wrong poisson random number sequence from generate.') + # Test generate with an int array + p.seed(testseed) + test_array = np.empty(3, dtype=int) + test_array = p.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(pResult), precisionI, + err_msg='Wrong poisson random number sequence from generate.') -# # Test generate_from_expectation -# p2 = galsim.PoissonDeviate(testseed, mean=77) -# test_array = np.array([pMean]*3, dtype=int) -# p2.generate_from_expectation(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(pResult), precisionI, -# err_msg='Wrong poisson random number sequence from generate_from_expectation.') -# # After generating, it should be back to mean=77 -# test_array2 = np.array([p2() for i in range(100)]) -# print('test_array2 = ',test_array2) -# print('mean = ',test_array2.mean()) -# assert np.isclose(test_array2.mean(), 77, atol=2) + # Test generate_from_expectation + p2 = galsim.PoissonDeviate(testseed, mean=77) + test_array = np.array([pMean] * 3, dtype=int) + test_array = p2.generate_from_expectation(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(pResult), precisionI, + err_msg='Wrong poisson random number sequence from generate_from_expectation.') + # After generating, it should be back to mean=77 + test_array2 = np.array([p2() for i in range(100)]) + print('test_array2 = ', test_array2) + print('mean = ', test_array2.mean()) + assert np.isclose(test_array2.mean(), 77, atol=2) -# # Check that generated values are independent of number of threads. -# # This should be trivial, since Poisson disables multi-threading, but check anyway. -# p1 = galsim.PoissonDeviate(testseed, mean=77) -# p2 = galsim.PoissonDeviate(testseed, mean=77) -# v1 = np.empty(555) -# v2 = np.empty(555) -# with single_threaded(): -# p1.generate(v1) -# with single_threaded(num_threads=10): -# p2.generate(v2) -# np.testing.assert_array_equal(v1, v2) -# with single_threaded(): -# p1.add_generate(v1) -# with single_threaded(num_threads=10): -# p2.add_generate(v2) -# np.testing.assert_array_equal(v1, v2) + # Check that generated values are independent of number of threads. + # This should be trivial, since Poisson disables multi-threading, but check anyway. + p1 = galsim.PoissonDeviate(testseed, mean=77) + p2 = galsim.PoissonDeviate(testseed, mean=77) + v1 = np.empty(555) + v2 = np.empty(555) + with single_threaded(): + v1 = p1.generate(v1) + with single_threaded(num_threads=10): + v2 = p2.generate(v2) + np.testing.assert_array_equal(v1, v2) + with single_threaded(): + v1 = p1.add_generate(v1) + with single_threaded(num_threads=10): + v2 = p2.add_generate(v2) + np.testing.assert_array_equal(v1, v2) -# # Check picklability -# do_pickle(p, lambda x: (x.serialize(), x.mean)) -# do_pickle(p, lambda x: (x(), x(), x(), x())) -# do_pickle(p) -# assert 'PoissonDeviate' in repr(p) -# assert 'PoissonDeviate' in str(p) -# assert isinstance(eval(repr(p)), galsim.PoissonDeviate) -# assert isinstance(eval(str(p)), galsim.PoissonDeviate) - -# # Check that we can construct a PoissonDeviate from None, and that it depends on dev/random. -# p1 = galsim.PoissonDeviate(None) -# p2 = galsim.PoissonDeviate(None) -# assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" -# # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.PoissonDeviate, dict()) -# assert_raises(TypeError, galsim.PoissonDeviate, list()) -# assert_raises(TypeError, galsim.PoissonDeviate, set()) + # Check picklability + do_pickle(p, lambda x: (x.serialize(), x.mean)) + do_pickle(p, lambda x: (x(), x(), x(), x())) + do_pickle(p) + assert 'PoissonDeviate' in repr(p) + assert 'PoissonDeviate' in str(p) + assert isinstance(eval(repr(p)), galsim.PoissonDeviate) + assert isinstance(eval(str(p)), galsim.PoissonDeviate) + + # Check that we can construct a PoissonDeviate from None, and that it depends on dev/random. + p1 = galsim.PoissonDeviate(None) + p2 = galsim.PoissonDeviate(None) + assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" + # NOTE we do not use type checking in JAX + # # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.PoissonDeviate, dict()) + # assert_raises(TypeError, galsim.PoissonDeviate, list()) + # assert_raises(TypeError, galsim.PoissonDeviate, set()) # @timer From f7d4070600bb9bc6e72ed4cafe140b4683ab4b9a Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:21:02 -0500 Subject: [PATCH 06/33] PROD make black checks match --- .github/workflows/python_package.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 1e0cacae..2e6c35b9 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -27,7 +27,7 @@ jobs: python -m pip install . - name: Ensure black formatting run: | - black --check jax_galsim/ tests/ --exclude tests/GalSim/ + black --check jax_galsim/ tests/ --exclude "tests/GalSim/|tests/Coord/|tests/jax/galsim/" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names From ffad7b9fecc625287528fbbf09efeaabe47548ce Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:26:41 -0500 Subject: [PATCH 07/33] STY please the flake8 --- tests/jax/galsim/test_wcs_jax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index 59cd392c..aca741e8 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -6,7 +6,7 @@ import warnings import numpy as np -from galsim_test_helpers import assert_raises, do_pickle, gsobject_compare, timer +from galsim_test_helpers import assert_raises, do_pickle, gsobject_compare, timer, assert_warns, profile import jax_galsim as galsim @@ -3133,9 +3133,9 @@ def test_inverseab_convergence(): [0.0003767412741890354, 0.00019733136932198898], ] ), - coord.CelestialCoord( - coord.Angle(2.171481673601117, coord.radians), - coord.Angle(-0.47508762601580773, coord.radians), + galsim.CelestialCoord( + galsim.Angle(2.171481673601117, galsim.radians), + galsim.Angle(-0.47508762601580773, galsim.radians), ), None, np.array( From 7a157e6857fc323ca2bc72039d598f0f54924b80 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:29:24 -0500 Subject: [PATCH 08/33] STY please the flake8 --- tests/jax/galsim/test_wcs_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index aca741e8..c758f9c6 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -3422,7 +3422,7 @@ def check_sphere(ra1, dec1, ra2, dec2, atol=1): w = dsq >= 3.99 if np.any(w): cross = np.cross(np.array([x1, y1, z1])[w], np.array([x2, y2, z2])[w]) - crosssq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 + crossq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 dist[w] = np.pi - np.arcsin(np.sqrt(crossq)) dist = np.rad2deg(dist) * 3600 np.testing.assert_allclose(dist, 0.0, rtol=0.0, atol=atol) From 8d278bc62b7bfca33a1880368266d8c38731fd1b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:39:57 -0500 Subject: [PATCH 09/33] STY please the flake8 --- tests/jax/galsim/test_wcs_jax.py | 190 +++++++++++++++---------------- 1 file changed, 92 insertions(+), 98 deletions(-) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index c758f9c6..9383f5a8 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -967,7 +967,6 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): wcs4 = wcs.local(wcs.origin, color=color) assert wcs != wcs4, name + " is not != wcs.local()" assert wcs4 != wcs, name + " is not != wcs.local() (reverse)" - world_origin = wcs.toWorld(wcs.origin, color=color) if wcs.isUniform(): if wcs.world_origin == galsim.PositionD(0, 0): wcs2 = wcs.local(wcs.origin, color=color).withOrigin(wcs.origin) @@ -1031,8 +1030,8 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) for x0, y0, u0, v0 in zip(far_x_list, far_y_list, far_u_list, far_v_list): - local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 - local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 + local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 # noqa: E731 + local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 # noqa: E731 image_pos = galsim.PositionD(x0, y0) world_pos = galsim.PositionD(u0, v0) do_wcs_pos( @@ -1207,8 +1206,6 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): "shiftOrigin(new_origin) returned wrong world position", ) - world_origin = wcs.toWorld(wcs.origin) - full_im1 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), wcs=wcs) full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) @@ -1524,8 +1521,8 @@ def test_pixelscale(): # assert_raises(TypeError, galsim.PixelScale, scale=scale, origin=galsim.PositionD(0, 0)) # assert_raises(TypeError, galsim.PixelScale, scale=scale, world_origin=galsim.PositionD(0, 0)) - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "PixelScale") @@ -1596,8 +1593,8 @@ def test_pixelscale(): assert wcs != wcs3b, "OffsetWCS is not != a different one (origin)" assert wcs != wcs3c, "OffsetWCS is not != a different one (world_origin)" - ufunc = lambda x, y: scale * (x - x0) - vfunc = lambda x, y: scale * (y - y0) + ufunc = lambda x, y: scale * (x - x0) # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 1") # Add a world origin offset @@ -1605,8 +1602,8 @@ def test_pixelscale(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, world_origin=world_origin) - ufunc = lambda x, y: scale * x + u0 - vfunc = lambda x, y: scale * y + v0 + ufunc = lambda x, y: scale * x + u0 # noqa: E731 + vfunc = lambda x, y: scale * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 2") # Add both kinds of offsets @@ -1617,8 +1614,8 @@ def test_pixelscale(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: scale * (x - x0) + u0 - vfunc = lambda x, y: scale * (y - y0) + v0 + ufunc = lambda x, y: scale * (x - x0) + u0 # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1660,8 +1657,8 @@ def test_shearwcs(): assert wcs != wcs3b, "ShearWCS is not != a different one (shear)" factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "ShearWCS") @@ -1746,8 +1743,8 @@ def test_shearwcs(): assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor # noqa: E731 + vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") # Add a world origin offset @@ -1755,8 +1752,8 @@ def test_shearwcs(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 + ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 # noqa: E731 + vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 2") # Add both kinds of offsets @@ -1767,8 +1764,8 @@ def test_shearwcs(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 + ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 # noqa: E731 + vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1828,8 +1825,8 @@ def test_affinetransform(): assert wcs != wcs3c, "JacobianWCS is not != a different one (dvdx)" assert wcs != wcs3d, "JacobianWCS is not != a different one (dvdy)" - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 1") # Check the decomposition: @@ -1885,8 +1882,8 @@ def test_affinetransform(): assert wcs != wcs3e, "AffineTransform is not != a different one (origin)" assert wcs != wcs3f, "AffineTransform is not != a different one (world_origin)" - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 1") # Next one with a flip and significant rotation and a large (u,v) offset @@ -1896,8 +1893,8 @@ def test_affinetransform(): dvdy = 0.1409 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 2") # Check the decomposition: @@ -1909,8 +1906,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) ) - ufunc = lambda x, y: dudx * x + dudy * y + u0 - vfunc = lambda x, y: dvdx * x + dvdy * y + v0 + ufunc = lambda x, y: dudx * x + dudy * y + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 2") # Finally a really crazy one that isn't remotely regular @@ -1920,8 +1917,8 @@ def test_affinetransform(): dvdy = -0.3013 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "Jacobian 3") # Check the decomposition: @@ -1940,8 +1937,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, origin=origin, world_origin=world_origin ) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 3") # Check that using a wcs in the context of an image works correctly @@ -2011,8 +2008,8 @@ def test_uvfunction(): # First make some that are identical to simpler WCS classes: # 1. Like PixelScale scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like PixelScale", test_pickle=False) assert wcs.ufunc(2.9, 3.7) == ufunc(2.9, 3.7) @@ -2025,8 +2022,8 @@ def test_uvfunction(): assert not wcs.isCelestial() # Also check with inverse functions. - xfunc = lambda u, v: u / scale - yfunc = lambda u, v: v / scale + xfunc = lambda u, v: u / scale # noqa: E731 + yfunc = lambda u, v: v / scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like PixelScale with inverse", test_pickle=False @@ -2060,14 +2057,14 @@ def test_uvfunction(): g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like ShearWCS", test_pickle=False) # Also check with inverse functions. - xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor - yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor + xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor # noqa: E731 + yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like ShearWCS with inverse", test_pickle=False @@ -2079,8 +2076,8 @@ def test_uvfunction(): dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like AffineTransform", test_pickle=False @@ -2116,7 +2113,7 @@ def test_uvfunction(): uses_color=True, ) do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True + wcsc, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True ) # 4. Next some UVFunctions with non-trivial offsets @@ -2126,8 +2123,8 @@ def test_uvfunction(): v0 = -141.9 origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) - ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 wcs = galsim.UVFunction(ufunc2, vfunc2) do_nonlocal_wcs( wcs, ufunc2, vfunc2, "UVFunction with origins in funcs", test_pickle=False @@ -2200,8 +2197,8 @@ def test_uvfunction(): "UVFunction dvdy does not match expected value.", ) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic radial UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2212,8 +2209,8 @@ def test_uvfunction(): cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") wcs = galsim.UVFunction(cubic_u, cubic_v, origin=galsim.PositionD(x0, y0)) - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic object UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2221,8 +2218,8 @@ def test_uvfunction(): # 7. Test the UVFunction that is used in demo9 to confirm that I got the # inverse function correct! - ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) - vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) + ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 # w = 0.05 (r + 2.e-6 r^3) # 0 = r^3 + 5e5 r - 1e7 w # @@ -2234,7 +2231,7 @@ def test_uvfunction(): # ( 5 sqrt( w^2 + 5.e3/27 ) - 5 w )^1/3 ) import math - xfunc = lambda u, v: ( + xfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2247,7 +2244,7 @@ def test_uvfunction(): ) ) )(math.sqrt(u**2 + v**2)) - yfunc = lambda u, v: ( + yfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2284,19 +2281,19 @@ def test_uvfunction(): # This version doesn't work with numpy arrays because of the math functions. # This provides a test of that branch of the makeSkyImage function. - ufunc = lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - vfunc = lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ufunc = lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) # noqa: E731 + vfunc = lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) do_wcs_image(wcs, "UVFunction_math") # 8. A non-trivial color example - ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y - vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y - xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( + ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y # noqa: E731 + vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y # noqa: E731 + xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) - yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( + yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) @@ -2329,10 +2326,10 @@ def test_uvfunction(): ) # 9. A non-trivial color example that fails for arrays - ufunc = lambda x, y, c: math.exp(c * x) - vfunc = lambda x, y, c: math.exp(c * y / 2) - xfunc = lambda u, v, c: math.log(u) / c - yfunc = lambda u, v, c: math.log(v) * 2 / c + ufunc = lambda x, y, c: math.exp(c * x) # noqa: E731 + vfunc = lambda x, y, c: math.exp(c * y / 2) # noqa: E731 + xfunc = lambda u, v, c: math.log(u) / c # noqa: E731 + yfunc = lambda u, v, c: math.log(v) * 2 / c # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) do_nonlocal_wcs( wcs, @@ -2344,20 +2341,20 @@ def test_uvfunction(): ) # 10. One with invalid functions, which raise errors. (Just for coverage really.) - ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) - vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6)) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(3, 3)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(6, 0)) # Repeat with color - ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) - vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6), color=0.2) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2), color=0.2) @@ -2372,47 +2369,45 @@ def test_radecfunction(): funcs = [] scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 funcs.append((ufunc, vfunc, "like PixelScale")) scale = 0.23 g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 funcs.append((ufunc, vfunc, "like ShearWCS")) dudx = 0.2342 dudy = 0.1432 dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 funcs.append((ufunc, vfunc, "like JacobianWCS")) x0 = 1.3 y0 = -0.9 u0 = 124.3 v0 = -141.9 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 funcs.append((ufunc, vfunc, "like AffineTransform")) funcs.append((radial_u, radial_v, "Cubic radial")) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic radial")) cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic object")) # The last one needs to not have a lambda, since we use it for the image test, which @@ -2437,7 +2432,7 @@ def test_radecfunction(): ) scale = galsim.arcsec / galsim.radians - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 ufunc(x, y) * scale, vfunc(x, y) * scale, projection="lambert" ) wcs2 = galsim.RaDecFunction(radec_func) @@ -2450,12 +2445,12 @@ def test_radecfunction(): # code does the right thing in that case too, since local and makeSkyImage # try the numpy option first and do something else if it fails. # This also tests the alternate initialization using separate ra_func, dec_fun. - ra_func = lambda x, y: center.deproject( + ra_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", ).ra.rad - dec_func = lambda x, y: center.deproject( + dec_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", @@ -2524,7 +2519,6 @@ def test_radecfunction(): image_pos = galsim.PositionD(x, y) world_pos1 = wcs1.toWorld(image_pos) world_pos2 = test_wcs.toWorld(image_pos) - origin = test_wcs.toWorld(galsim.PositionD(0.0, 0.0)) d3 = np.sqrt(world_pos1.x**2 + world_pos1.y**2) d4 = center.distanceTo(world_pos2) d4 = 2.0 * np.sin(d4 / 2) * galsim.radians / galsim.arcsec @@ -2715,7 +2709,7 @@ def test_radecfunction(): do_wcs_image(wcs3, "RaDecFunction") # One with invalid functions, which raise errors. (Just for coverage really.) - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 math.sqrt(x), math.sqrt(y), projection="lambert" ) wcs = galsim.RaDecFunction(radec_func) @@ -2783,8 +2777,8 @@ def test_astropywcs(): """Test the AstropyWCS class""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. # These all work, but it is quite slow, so only test a few of them for the regular unit tests. # (1.8 seconds for 4 tags.) @@ -3323,13 +3317,13 @@ def test_fitswcs(): # mostly just tests the basic interface of the FitsWCS function. test_tags = ["TAN", "TPV"] try: - import starlink.Ast + import starlink.Ast # noqa: F401 # Useful also to test one that GSFitsWCS doesn't work on. This works on Travis at # least, and helps to cover some of the FitsWCS functionality where the first try # isn't successful. test_tags.append("HPX") - except: + except Exception: pass dir = "fits_files" @@ -3364,7 +3358,7 @@ def test_fitswcs(): # We don't really have any accuracy checks here. This really just checks that the # read function doesn't raise an exception. hdu, hdu_list, fin = galsim.fits.readFile(file_name, dir) - affine = galsim.AffineTransform._readHeader(hdu.header) + galsim.AffineTransform._readHeader(hdu.header) galsim.fits.closeHDUList(hdu_list, fin) # This does support LINEAR WCS types. @@ -4011,8 +4005,8 @@ def test_razero(): # do this. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. dir = "fits_files" # This file is based in sipsample.fits, but with the CRVAL1 changed to 0.002322805429 From 1d53e18dc4dda8852582fe0931abdf8b8b133a13 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:42:33 -0500 Subject: [PATCH 10/33] BUG wrong error --- jax_galsim/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 72855a19..affd7bf0 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ try: from jax.extend.random import wrap_key_data -except ImportError: +except ModuleNotFoundError: from jax.random import wrap_key_data from jax_galsim.core.utils import ensure_hashable From 347b86045b72e90a5d53f42a9b348fa93c860c90 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Oct 2023 20:45:05 -0500 Subject: [PATCH 11/33] ENH update python build matrix --- .github/workflows/python_package.yaml | 2 +- jax_galsim/random.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 2e6c35b9..b4b6ca67 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/jax_galsim/random.py b/jax_galsim/random.py index affd7bf0..15bc4f56 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ try: from jax.extend.random import wrap_key_data -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): from jax.random import wrap_key_data from jax_galsim.core.utils import ensure_hashable From 474e2edeeca9146459743c1cc5dd42a6e36a6082 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 09:43:53 -0500 Subject: [PATCH 12/33] ENH add chi2 --- jax_galsim/__init__.py | 8 +- jax_galsim/random.py | 89 ++++----- tests/jax/galsim/test_random_jax.py | 273 ++++++++++++++-------------- 3 files changed, 193 insertions(+), 177 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index e77d507e..2a7ec231 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -8,7 +8,13 @@ from .errors import GalSimWarning, GalSimDeprecationWarning # noise -from .random import BaseDeviate, UniformDeviate, GaussianDeviate, PoissonDeviate +from .random import ( + BaseDeviate, + UniformDeviate, + GaussianDeviate, + PoissonDeviate, + Chi2Deviate, +) # Basic building blocks from .bounds import Bounds, BoundsD, BoundsI diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 15bc4f56..ca9dcdab 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -528,55 +528,60 @@ def __str__(self): # return 'galsim.GammaDeviate(k=%r, theta=%r)'%(self.k, self.theta) -# class Chi2Deviate(BaseDeviate): -# """Pseudo-random Chi^2-distributed deviate for degrees-of-freedom parameter ``n``. - -# See http://en.wikipedia.org/wiki/Chi-squared_distribution (note that k=n in the notation -# adopted in the Boost.Random routine called by this class). The Chi^2 distribution is a -# real-valued distribution producing deviates >= 0. - -# Successive calls to ``chi2()`` generate pseudo-random values distributed according to a -# chi-square distribution with the specified degrees of freedom, ``n``:: +@_wraps( + _galsim.Chi2Deviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class Chi2Deviate(BaseDeviate): + def __init__(self, seed=None, n=1.0): + super().__init__(seed=seed) + self._params["n"] = n -# >>> chi2 = galsim.Chi2Deviate(31415926, n=7) -# >>> chi2() -# 7.9182211987712385 -# >>> chi2() -# 6.644121724269535 + @property + def n(self): + """The number of degrees of freedom.""" + return self._params["n"] -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# n: Number of degrees of freedom for the output distribution. [default: 1; -# Must be > 0] -# """ -# def __init__(self, seed=None, n=1.): -# self._rng_type = _galsim.Chi2DeviateImpl -# self._rng_args = (float(n),) -# self.reset(seed) + @_wraps( + _galsim.BaseDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array, self.n) + return array -# @property -# def n(self): -# """The number of degrees of freedom. -# """ -# return self._rng_args[0] + @jax.jit + def _generate(key, array, n): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan( + Chi2Deviate._generate_one, + key, + jnp.broadcast_to(n, array.ravel().shape), + length=array.ravel().shape[0], + ) + return key, res.reshape(array.shape) -# @property -# def has_reliable_discard(self): -# return False + def __call__(self): + self._key, val = self.__class__._generate_one(self._key, self.n) + return val -# def __call__(self): -# """Draw a new random number from the distribution. + @jax.jit + def _generate_one(key, n): + _key, subkey = jrandom.split(key) + return _key, jrandom.chisquare(subkey, n, dtype=float) -# Returns a Chi2-distributed deviate with the given number of degrees of freedom. -# """ -# return self._rng.generate1() + def __repr__(self): + return "galsim.Chi2Deviate(seed=%r, n=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.n), + ) -# def __repr__(self): -# return 'galsim.Chi2Deviate(seed=%r, n=%r)'%(self._seed_repr(), self.n) -# def __str__(self): -# return 'galsim.Chi2Deviate(n=%r)'%(self.n) + def __str__(self): + return "galsim.Chi2Deviate(n=%r)" % (ensure_hashable(self.n),) # class DistDeviate(BaseDeviate): diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index ae98b496..8c2ed7d1 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -55,7 +55,7 @@ # n to use for Chi2 tests chi2N = 30 # Tabulated results for Chi2 -chi2Result = (32.209933900954049, 50.040002656028513, 24.301442486313896) +chi2Result = (36.7583415337, 32.7223187231, 23.1555198334) # function and min&max to use for DistDeviate function call tests dmin = 0.0 @@ -1340,151 +1340,156 @@ def test_poisson(): # assert_raises(TypeError, galsim.GammaDeviate, set()) -# @timer -# def test_chi2(): -# """Test Chi^2 random number generator -# """ -# c = galsim.Chi2Deviate(testseed, n=chi2N) -# c2 = c.duplicate() -# c3 = galsim.Chi2Deviate(c.serialize(), n=chi2N) -# testResult = (c(), c(), c()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(chi2Result), precision, -# err_msg='Wrong Chi^2 random number sequence generated') -# testResult = (c2(), c2(), c2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(chi2Result), precision, -# err_msg='Wrong Chi^2 random number sequence generated with duplicate') -# testResult = (c3(), c3(), c3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(chi2Result), precision, -# err_msg='Wrong Chi^2 random number sequence generated from serialize') +@timer +def test_chi2(): + """Test Chi^2 random number generator + """ + c = galsim.Chi2Deviate(testseed, n=chi2N) + c2 = c.duplicate() + c3 = galsim.Chi2Deviate(c.serialize(), n=chi2N) + testResult = (c(), c(), c()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(chi2Result), precision, + err_msg='Wrong Chi^2 random number sequence generated') + testResult = (c2(), c2(), c2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(chi2Result), precision, + err_msg='Wrong Chi^2 random number sequence generated with duplicate') + testResult = (c3(), c3(), c3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(chi2Result), precision, + err_msg='Wrong Chi^2 random number sequence generated from serialize') -# # Check that the mean and variance come out right -# c = galsim.Chi2Deviate(testseed, n=chi2N) -# vals = [c() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = chi2N -# v = 2.*chi2N -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from Chi2Deviate') -# np.testing.assert_almost_equal(var, v, 0, -# err_msg='Wrong variance from Chi2Deviate') + # Check that the mean and variance come out right + c = galsim.Chi2Deviate(testseed, n=chi2N) + vals = [c() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = chi2N + v = 2. * chi2N + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from Chi2Deviate') + np.testing.assert_almost_equal( + var, v, 0, + err_msg='Wrong variance from Chi2Deviate') -# # Check discard -# c2 = galsim.Chi2Deviate(testseed, n=chi2N) -# c2.discard(nvals, suppress_warnings=True) -# v1,v2 = c(),c2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# # Chi2 uses at least 2 rngs per value, but can use arbitrarily more than this. -# assert v1 != v2 -# assert not c.has_reliable_discard -# assert not c.generates_in_pairs + # NOTE: + # Check discard + c2 = galsim.Chi2Deviate(testseed, n=chi2N) + c2.discard(nvals) + v1, v2 = c(), c2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + assert v1 == v2 + assert c.has_reliable_discard + assert not c.generates_in_pairs -# # Discard normally emits a warning for Chi2 -# c2 = galsim.Chi2Deviate(testseed, n=chi2N) -# with assert_warns(galsim.GalSimWarning): -# c2.discard(nvals) + # NOTE jax has reliable discard + # # Discard normally emits a warning for Chi2 + # c2 = galsim.Chi2Deviate(testseed, n=chi2N) + # with assert_warns(galsim.GalSimWarning): + # c2.discard(nvals) -# # Check seed, reset -# c.seed(testseed) -# testResult2 = (c(), c(), c()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated after seed') + # Check seed, reset + c.seed(testseed) + testResult2 = (c(), c(), c()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Chi^2 random number sequence generated after seed') -# c.reset(testseed) -# testResult2 = (c(), c(), c()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated after reset(seed)') + c.reset(testseed) + testResult2 = (c(), c(), c()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Chi^2 random number sequence generated after reset(seed)') -# rng = galsim.BaseDeviate(testseed) -# c.reset(rng) -# testResult2 = (c(), c(), c()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated after reset(rng)') + rng = galsim.BaseDeviate(testseed) + c.reset(rng) + testResult2 = (c(), c(), c()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Chi^2 random number sequence generated after reset(rng)') -# ud = galsim.UniformDeviate(testseed) -# c.reset(ud) -# testResult = (c(), c(), c()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated after reset(ud)') + ud = galsim.UniformDeviate(testseed) + c.reset(ud) + testResult = (c(), c(), c()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong Chi^2 random number sequence generated after reset(ud)') -# # Check that two connected Chi^2 deviates work correctly together. -# c2 = galsim.Chi2Deviate(testseed, n=chi2N) -# c.reset(c2) -# testResult2 = (c(), c2(), c()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated using two cds') -# c.seed(testseed) -# testResult2 = (c2(), c(), c2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong Chi^2 random number sequence generated using two cds after seed') + # NOTE we cannot connect two deviates in JAX + # # Check that two connected Chi^2 deviates work correctly together. + # c2 = galsim.Chi2Deviate(testseed, n=chi2N) + # c.reset(c2) + # testResult2 = (c(), c2(), c()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong Chi^2 random number sequence generated using two cds') + # c.seed(testseed) + # testResult2 = (c2(), c(), c2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong Chi^2 random number sequence generated using two cds after seed') -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. -# c.seed() -# testResult2 = (c(), c(), c()) -# assert testResult2 != testResult -# c.reset() -# testResult3 = (c(), c(), c()) -# assert testResult3 != testResult -# assert testResult3 != testResult2 -# c.reset() -# testResult4 = (c(), c(), c()) -# assert testResult4 != testResult -# assert testResult4 != testResult2 -# assert testResult4 != testResult3 -# c = galsim.Chi2Deviate(n=chi2N) -# testResult5 = (c(), c(), c()) -# assert testResult5 != testResult -# assert testResult5 != testResult2 -# assert testResult5 != testResult3 -# assert testResult5 != testResult4 + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. + c.seed() + testResult2 = (c(), c(), c()) + assert testResult2 != testResult + c.reset() + testResult3 = (c(), c(), c()) + assert testResult3 != testResult + assert testResult3 != testResult2 + c.reset() + testResult4 = (c(), c(), c()) + assert testResult4 != testResult + assert testResult4 != testResult2 + assert testResult4 != testResult3 + c = galsim.Chi2Deviate(n=chi2N) + testResult5 = (c(), c(), c()) + assert testResult5 != testResult + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 -# # Test generate -# c.seed(testseed) -# test_array = np.empty(3) -# c.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(chi2Result), precision, -# err_msg='Wrong Chi^2 random number sequence from generate.') + # Test generate + c.seed(testseed) + test_array = np.empty(3) + test_array = c.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(chi2Result), precision, + err_msg='Wrong Chi^2 random number sequence from generate.') -# # Test generate with a float32 array -# c.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# c.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(chi2Result), precisionF, -# err_msg='Wrong Chi^2 random number sequence from generate.') + # Test generate with a float32 array + c.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array = c.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(chi2Result), precisionF, + err_msg='Wrong Chi^2 random number sequence from generate.') -# # Check picklability -# do_pickle(c, lambda x: (x.serialize(), x.n)) -# do_pickle(c, lambda x: (x(), x(), x(), x())) -# do_pickle(c) -# assert 'Chi2Deviate' in repr(c) -# assert 'Chi2Deviate' in str(c) -# assert isinstance(eval(repr(c)), galsim.Chi2Deviate) -# assert isinstance(eval(str(c)), galsim.Chi2Deviate) - -# # Check that we can construct a Chi2Deviate from None, and that it depends on dev/random. -# c1 = galsim.Chi2Deviate(None) -# c2 = galsim.Chi2Deviate(None) -# assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" -# # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.Chi2Deviate, dict()) -# assert_raises(TypeError, galsim.Chi2Deviate, list()) -# assert_raises(TypeError, galsim.Chi2Deviate, set()) + # Check picklability + do_pickle(c, lambda x: (x.serialize(), x.n)) + do_pickle(c, lambda x: (x(), x(), x(), x())) + do_pickle(c) + assert 'Chi2Deviate' in repr(c) + assert 'Chi2Deviate' in str(c) + assert isinstance(eval(repr(c)), galsim.Chi2Deviate) + assert isinstance(eval(str(c)), galsim.Chi2Deviate) + + # Check that we can construct a Chi2Deviate from None, and that it depends on dev/random. + c1 = galsim.Chi2Deviate(None) + c2 = galsim.Chi2Deviate(None) + assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" + # NOTE jax does not raise + # # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.Chi2Deviate, dict()) + # assert_raises(TypeError, galsim.Chi2Deviate, list()) + # assert_raises(TypeError, galsim.Chi2Deviate, set()) # @timer From 74771a6ceb088336a0a82d3b1abe2ae84a1b5c3a Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 10:25:52 -0500 Subject: [PATCH 13/33] ENH add gamma --- jax_galsim/__init__.py | 1 + jax_galsim/random.py | 109 ++++++----- tests/jax/galsim/test_random_jax.py | 274 ++++++++++++++-------------- 3 files changed, 200 insertions(+), 184 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 2a7ec231..f7b85f48 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -14,6 +14,7 @@ GaussianDeviate, PoissonDeviate, Chi2Deviate, + GammaDeviate, ) # Basic building blocks diff --git a/jax_galsim/random.py b/jax_galsim/random.py index ca9dcdab..3f8dbccc 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -244,7 +244,7 @@ def sigma(self): return self._params["sigma"] @_wraps( - _galsim.BaseDeviate.generate, + _galsim.GaussianDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." @@ -359,7 +359,7 @@ def mean(self): return self._params["mean"] @_wraps( - _galsim.BaseDeviate.generate, + _galsim.PoissonDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." @@ -471,61 +471,70 @@ def __str__(self): # return 'galsim.WeibullDeviate(a=%r, b=%r)'%(self.a, self.b) -# class GammaDeviate(BaseDeviate): -# """A Gamma-distributed deviate with shape parameter ``k`` and scale parameter ``theta``. -# See http://en.wikipedia.org/wiki/Gamma_distribution. -# (Note: we use the k, theta notation. If you prefer alpha, beta, use k=alpha, theta=1/beta.) -# The Gamma distribution is a real valued distribution producing deviates >= 0. - -# Successive calls to ``g()`` generate pseudo-random values distributed according to a gamma -# distribution with the specified shape and scale parameters ``k`` and ``theta``:: +@_wraps( + _galsim.GammaDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class GammaDeviate(BaseDeviate): + def __init__(self, seed=None, k=1.0, theta=1.0): + super().__init__(seed=seed) + self._params["k"] = k + self._params["theta"] = theta -# >>> gam = galsim.GammaDeviate(31415926, k=1, theta=2) -# >>> gam() -# 0.37508882726316 -# >>> gam() -# 1.3504199388358704 + @property + def k(self): + """The shape parameter, k.""" + return self._params["k"] -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# k: Shape parameter of the distribution. [default: 1; Must be > 0] -# theta: Scale parameter of the distribution. [default: 1; Must be > 0] -# """ -# def __init__(self, seed=None, k=1., theta=1.): -# self._rng_type = _galsim.GammaDeviateImpl -# self._rng_args = (float(k), float(theta)) -# self.reset(seed) + @property + def theta(self): + """The scale parameter, theta.""" + return self._params["theta"] -# @property -# def k(self): -# """The shape parameter, k. -# """ -# return self._rng_args[0] + @_wraps( + _galsim.GammaDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array, self.k) + return array * self.theta -# @property -# def theta(self): -# """The scale parameter, theta. -# """ -# return self._rng_args[1] + @jax.jit + def _generate(key, array, k): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + key, res = jax.lax.scan( + GammaDeviate._generate_one, + key, + jnp.broadcast_to(k, array.ravel().shape), + length=array.ravel().shape[0], + ) + return key, res.reshape(array.shape) -# @property -# def has_reliable_discard(self): -# return False + def __call__(self): + self._key, val = self.__class__._generate_one(self._key, self.k) + return val * self.theta -# def __call__(self): -# """Draw a new random number from the distribution. + @jax.jit + def _generate_one(key, k): + _key, subkey = jrandom.split(key) + return _key, jrandom.gamma(subkey, k, dtype=float) -# Returns a Gamma-distributed deviate with the given k and theta. -# """ -# return self._rng.generate1() + def __repr__(self): + return "galsim.GammaDeviate(seed=%r, k=%r, theta=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.k), + ensure_hashable(self.theta), + ) -# def __repr__(self): -# return 'galsim.GammaDeviate(seed=%r, k=%r, theta=%r)'%( -# self._seed_repr(), self.k, self.theta) -# def __str__(self): -# return 'galsim.GammaDeviate(k=%r, theta=%r)'%(self.k, self.theta) + def __str__(self): + return "galsim.GammaDeviate(k=%r, theta=%r)" % ( + ensure_hashable(self.k), + ensure_hashable(self.theta), + ) @_wraps( @@ -544,7 +553,7 @@ def n(self): return self._params["n"] @_wraps( - _galsim.BaseDeviate.generate, + _galsim.Chi2Deviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 8c2ed7d1..18c06bba 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -50,7 +50,7 @@ gammaK = 1.5 gammaTheta = 4.5 # Tabulated results for Gamma -gammaResult = (4.7375613139927157, 15.272973580418618, 21.485016362839747) +gammaResult = (10.9318881415, 7.6074550007, 2.0526795529) # n to use for Chi2 tests chi2N = 30 @@ -1193,151 +1193,157 @@ def test_poisson(): # assert_raises(TypeError, galsim.WeibullDeviate, set()) -# @timer -# def test_gamma(): -# """Test Gamma random number generator -# """ -# g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) -# g2 = g.duplicate() -# g3 = galsim.GammaDeviate(g.serialize(), k=gammaK, theta=gammaTheta) -# testResult = (g(), g(), g()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gammaResult), precision, -# err_msg='Wrong Gamma random number sequence generated') -# testResult = (g2(), g2(), g2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gammaResult), precision, -# err_msg='Wrong Gamma random number sequence generated with duplicate') -# testResult = (g3(), g3(), g3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(gammaResult), precision, -# err_msg='Wrong Gamma random number sequence generated from serialize') +@timer +def test_gamma(): + """Test Gamma random number generator + """ + g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) + g2 = g.duplicate() + g3 = galsim.GammaDeviate(g.serialize(), k=gammaK, theta=gammaTheta) + testResult = (g(), g(), g()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gammaResult), precision, + err_msg='Wrong Gamma random number sequence generated') + testResult = (g2(), g2(), g2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gammaResult), precision, + err_msg='Wrong Gamma random number sequence generated with duplicate') + testResult = (g3(), g3(), g3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(gammaResult), precision, + err_msg='Wrong Gamma random number sequence generated from serialize') -# # Check that the mean and variance come out right -# g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) -# vals = [g() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = gammaK*gammaTheta -# v = gammaK*gammaTheta**2 -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from GammaDeviate') -# np.testing.assert_almost_equal(var, v, 0, -# err_msg='Wrong variance from GammaDeviate') + # Check that the mean and variance come out right + g = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) + vals = [g() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = gammaK * gammaTheta + v = gammaK * gammaTheta**2 + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from GammaDeviate') + np.testing.assert_almost_equal( + var, v, 0, + err_msg='Wrong variance from GammaDeviate') -# # Check discard -# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) -# g2.discard(nvals, suppress_warnings=True) -# v1,v2 = g(),g2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. -# assert v1 != v2 -# assert not g.has_reliable_discard -# assert not g.generates_in_pairs + # NOTE jax has a reliabble discard + # Check discard + g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) + g2.discard(nvals, suppress_warnings=True) + v1, v2 = g(), g2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. + assert v1 == v2 + assert g.has_reliable_discard + assert not g.generates_in_pairs -# # Discard normally emits a warning for Gamma -# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) -# with assert_warns(galsim.GalSimWarning): -# g2.discard(nvals) + # NOTE jax has a reliabble discard + # Discard normally emits a warning for Gamma + # g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) + # with assert_warns(galsim.GalSimWarning): + # g2.discard(nvals) -# # Check seed, reset -# g.seed(testseed) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated after seed') + # Check seed, reset + g.seed(testseed) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong gamma random number sequence generated after seed') -# g.reset(testseed) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated after reset(seed)') + g.reset(testseed) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong gamma random number sequence generated after reset(seed)') -# rng = galsim.BaseDeviate(testseed) -# g.reset(rng) -# testResult2 = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated after reset(rng)') + rng = galsim.BaseDeviate(testseed) + g.reset(rng) + testResult2 = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong gamma random number sequence generated after reset(rng)') -# ud = galsim.UniformDeviate(testseed) -# g.reset(ud) -# testResult = (g(), g(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated after reset(ud)') + ud = galsim.UniformDeviate(testseed) + g.reset(ud) + testResult = (g(), g(), g()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong gamma random number sequence generated after reset(ud)') -# # Check that two connected gamma deviates work correctly together. -# g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) -# g.reset(g2) -# testResult2 = (g(), g2(), g()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated using two gds') -# g.seed(testseed) -# testResult2 = (g2(), g(), g2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong gamma random number sequence generated using two gds after seed') + # NOTE jax cannot connect RNGs + # # Check that two connected gamma deviates work correctly together. + # g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) + # g.reset(g2) + # testResult2 = (g(), g2(), g()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong gamma random number sequence generated using two gds') + # g.seed(testseed) + # testResult2 = (g2(), g(), g2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong gamma random number sequence generated using two gds after seed') -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. -# g.seed() -# testResult2 = (g(), g(), g()) -# assert testResult2 != testResult -# g.reset() -# testResult3 = (g(), g(), g()) -# assert testResult3 != testResult -# assert testResult3 != testResult2 -# g.reset() -# testResult4 = (g(), g(), g()) -# assert testResult4 != testResult -# assert testResult4 != testResult2 -# assert testResult4 != testResult3 -# g = galsim.GammaDeviate(k=gammaK, theta=gammaTheta) -# testResult5 = (g(), g(), g()) -# assert testResult5 != testResult -# assert testResult5 != testResult2 -# assert testResult5 != testResult3 -# assert testResult5 != testResult4 + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. + g.seed() + testResult2 = (g(), g(), g()) + assert testResult2 != testResult + g.reset() + testResult3 = (g(), g(), g()) + assert testResult3 != testResult + assert testResult3 != testResult2 + g.reset() + testResult4 = (g(), g(), g()) + assert testResult4 != testResult + assert testResult4 != testResult2 + assert testResult4 != testResult3 + g = galsim.GammaDeviate(k=gammaK, theta=gammaTheta) + testResult5 = (g(), g(), g()) + assert testResult5 != testResult + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 -# # Test generate -# g.seed(testseed) -# test_array = np.empty(3) -# g.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gammaResult), precision, -# err_msg='Wrong gamma random number sequence from generate.') + # Test generate + g.seed(testseed) + test_array = np.empty(3) + test_array = g.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gammaResult), precision, + err_msg='Wrong gamma random number sequence from generate.') -# # Test generate with a float32 array -# g.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# g.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(gammaResult), precisionF, -# err_msg='Wrong gamma random number sequence from generate.') + # Test generate with a float32 array + g.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array = g.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(gammaResult), precisionF, + err_msg='Wrong gamma random number sequence from generate.') -# # Check picklability -# do_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) -# do_pickle(g, lambda x: (x(), x(), x(), x())) -# do_pickle(g) -# assert 'GammaDeviate' in repr(g) -# assert 'GammaDeviate' in str(g) -# assert isinstance(eval(repr(g)), galsim.GammaDeviate) -# assert isinstance(eval(str(g)), galsim.GammaDeviate) - -# # Check that we can construct a GammaDeviate from None, and that it depends on dev/random. -# g1 = galsim.GammaDeviate(None) -# g2 = galsim.GammaDeviate(None) -# assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" -# # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.GammaDeviate, dict()) -# assert_raises(TypeError, galsim.GammaDeviate, list()) -# assert_raises(TypeError, galsim.GammaDeviate, set()) + # Check picklability + do_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) + do_pickle(g, lambda x: (x(), x(), x(), x())) + do_pickle(g) + assert 'GammaDeviate' in repr(g) + assert 'GammaDeviate' in str(g) + assert isinstance(eval(repr(g)), galsim.GammaDeviate) + assert isinstance(eval(str(g)), galsim.GammaDeviate) + + # Check that we can construct a GammaDeviate from None, and that it depends on dev/random. + g1 = galsim.GammaDeviate(None) + g2 = galsim.GammaDeviate(None) + assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" + # NOTE jax does not raise for type errors + # # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.GammaDeviate, dict()) + # assert_raises(TypeError, galsim.GammaDeviate, list()) + # assert_raises(TypeError, galsim.GammaDeviate, set()) @timer From 034f4fd6ce969f1329ab29f48fd978a9e95954ea Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 15:06:18 -0500 Subject: [PATCH 14/33] ENH support permute --- jax_galsim/random.py | 50 ++++++++------------------- tests/jax/galsim/test_random_jax.py | 53 +++++++++++++++-------------- 2 files changed, 42 insertions(+), 61 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 3f8dbccc..1593c530 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -34,6 +34,10 @@ def __init__(self, seed=None): self.reset(seed=seed) self._params = {} + @property + def key(self): + return self._key + @_wraps( _galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.", @@ -835,38 +839,14 @@ def __str__(self): # self._npoints == other._npoints)) -# class GalSimBitGenerator(np.random.BitGenerator): -# """A numpy.random.BitGenerator that uses the GalSim C++-layer random number generator -# for the random bit generation. - -# Parameters: -# rng: The galsim.BaseDeviate object to use for the underlying bit generation. -# """ -# def __init__(self, rng): -# super().__init__(0) -# self.rng = rng -# self.rng._rng.setup_bitgen(self.capsule) - -# def permute(rng, *args): -# """Randomly permute one or more lists. - -# If more than one list is given, then all lists will have the same random permutation -# applied to it. - -# Parameters: -# rng: The random number generator to use. (This will be converted to a `UniformDeviate`.) -# args: Any number of lists to be permuted. -# """ -# from .random import UniformDeviate -# ud = UniformDeviate(rng) -# if len(args) == 0: -# raise TypeError("permute called with no lists to permute") - -# # We use an algorithm called the Knuth shuffle, which is based on the Fisher-Yates shuffle. -# # See http://en.wikipedia.org/wiki/Fisher-Yates_shuffle for more information. -# n = len(args[0]) -# for i in range(n-1,1,-1): -# j = int((i+1) * ud()) -# if j == i+1: j = i # I'm not sure if this is possible, but just in case... -# for lst in args: -# lst[i], lst[j] = lst[j], lst[i] +@_wraps( + _galsim.random.permute, + lax_description="The JAX implementation of this function cannot operate in-place and so returns a new list of arrays.", +) +def permute(rng, *args): + rng = BaseDeviate(rng) + arrs = [] + for arr in args: + arrs.append(jrandom.permutation(rng.key, arr)) + rng.discard(1) + return arrs diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 18c06bba..f3043676 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -1941,32 +1941,33 @@ def test_chi2(): # task_queue.put('STOP') -# @timer -# def test_permute(): -# """Simple tests of the permute() function.""" -# # Make a fake list, and another list consisting of indices. -# my_list = [3.7, 4.1, 1.9, 11.1, 378.3, 100.0] -# import copy -# my_list_copy = copy.deepcopy(my_list) -# n_list = len(my_list) -# ind_list = list(range(n_list)) - -# # Permute both at the same time. -# galsim.random.permute(312, my_list, ind_list) - -# # Make sure that everything is sensible -# for ind in range(n_list): -# assert my_list_copy[ind_list[ind]] == my_list[ind] - -# # Repeat with same seed, should do same permutation. -# my_list = copy.deepcopy(my_list_copy) -# galsim.random.permute(312, my_list) -# for ind in range(n_list): -# assert my_list_copy[ind_list[ind]] == my_list[ind] - -# # permute with no lists should raise TypeError -# with assert_raises(TypeError): -# galsim.random.permute(312) +@timer +def test_permute(): + """Simple tests of the permute() function.""" + # Make a fake list, and another list consisting of indices. + my_list = [3.7, 4.1, 1.9, 11.1, 378.3, 100.0] + import copy + my_list_copy = copy.deepcopy(my_list) + n_list = len(my_list) + ind_list = list(range(n_list)) + + # Permute both at the same time. + galsim.random.permute(312, np.array(my_list), np.array(ind_list)) + + # Make sure that everything is sensible + for ind in range(n_list): + assert my_list_copy[ind_list[ind]] == my_list[ind] + + # Repeat with same seed, should do same permutation. + my_list = copy.deepcopy(my_list_copy) + galsim.random.permute(312, np.array(my_list)) + for ind in range(n_list): + assert my_list_copy[ind_list[ind]] == my_list[ind] + + # NOTE no errors raised in JAX + # # permute with no lists should raise TypeError + # with assert_raises(TypeError): + # galsim.random.permute(312) # @timer From 1ded9fb68624dbd7c40904b84b19ce09e64f22f6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 15:10:35 -0500 Subject: [PATCH 15/33] TST enable more tests --- tests/jax/galsim/test_random_jax.py | 243 +++++++++++----------------- 1 file changed, 95 insertions(+), 148 deletions(-) diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index f3043676..c4741f2e 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -1874,71 +1874,71 @@ def test_chi2(): # assert isinstance(eval(str(d)), galsim.DistDeviate) -# @timer -# def test_multiprocess(): -# """Test that the same random numbers are generated in single-process and multi-process modes. -# """ -# from multiprocessing import current_process -# from multiprocessing import get_context -# ctx = get_context('fork') -# Process = ctx.Process -# Queue = ctx.Queue - -# def generate_list(seed): -# """Given a particular seed value, generate a list of random numbers. -# Should be deterministic given the input seed value. -# """ -# rng = galsim.UniformDeviate(seed) -# out = [] -# for i in range(20): -# out.append(rng()) -# return out - -# def worker(input, output): -# """input is a queue with seed values -# output is a queue storing the results of the tasks along with the process name, -# and which args the result is for. -# """ -# for args in iter(input.get, 'STOP'): -# result = generate_list(*args) -# output.put( (result, current_process().name, args) ) - -# # Use sequential numbers. -# # On inspection, can see that even the first value in each list is random with -# # respect to the other lists. i.e. "nearby" inputs do not produce nearby outputs. -# # I don't know of an actual assert to do for this, but it is clearly true. -# seeds = [ 1532424 + i for i in range(16) ] - -# nproc = 4 # Each process will do 4 lists (typically) - -# # First make lists in the single process: -# ref_lists = dict() -# for seed in seeds: -# list = generate_list(seed) -# ref_lists[seed] = list - -# # Now do this with multiprocessing -# # Put the seeds in a queue -# task_queue = Queue() -# for seed in seeds: -# task_queue.put( [seed] ) - -# # Run the tasks: -# done_queue = Queue() -# for k in range(nproc): -# Process(target=worker, args=(task_queue, done_queue)).start() - -# # Check the results in the order they finished -# for i in range(len(seeds)): -# list, proc, args = done_queue.get() -# seed = args[0] -# np.testing.assert_array_equal( -# list, ref_lists[seed], -# err_msg="Random numbers are different when using multiprocessing") - -# # Stop the processes: -# for k in range(nproc): -# task_queue.put('STOP') +@timer +def test_multiprocess(): + """Test that the same random numbers are generated in single-process and multi-process modes. + """ + from multiprocessing import current_process + from multiprocessing import get_context + ctx = get_context('fork') + Process = ctx.Process + Queue = ctx.Queue + + def generate_list(seed): + """Given a particular seed value, generate a list of random numbers. + Should be deterministic given the input seed value. + """ + rng = galsim.UniformDeviate(seed) + out = [] + for i in range(20): + out.append(rng()) + return out + + def worker(input, output): + """input is a queue with seed values + output is a queue storing the results of the tasks along with the process name, + and which args the result is for. + """ + for args in iter(input.get, 'STOP'): + result = generate_list(*args) + output.put((result, current_process().name, args)) + + # Use sequential numbers. + # On inspection, can see that even the first value in each list is random with + # respect to the other lists. i.e. "nearby" inputs do not produce nearby outputs. + # I don't know of an actual assert to do for this, but it is clearly true. + seeds = [1532424 + i for i in range(16)] + + nproc = 4 # Each process will do 4 lists (typically) + + # First make lists in the single process: + ref_lists = dict() + for seed in seeds: + list = generate_list(seed) + ref_lists[seed] = list + + # Now do this with multiprocessing + # Put the seeds in a queue + task_queue = Queue() + for seed in seeds: + task_queue.put([seed]) + + # Run the tasks: + done_queue = Queue() + for k in range(nproc): + Process(target=worker, args=(task_queue, done_queue)).start() + + # Check the results in the order they finished + for i in range(len(seeds)): + _list, proc, args = done_queue.get() + seed = args[0] + np.testing.assert_array_equal( + _list, ref_lists[seed], + err_msg="Random numbers are different when using multiprocessing") + + # Stop the processes: + for k in range(nproc): + task_queue.put('STOP') @timer @@ -1970,86 +1970,33 @@ def test_permute(): # galsim.random.permute(312) -# @timer -# def test_ne(): -# """ Check that inequality works as expected for corner cases where the reprs of two -# unequal BaseDeviates may be the same due to truncation. -# """ -# a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') -# b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') -# assert repr(a) == repr(b) -# assert a != b - -# # Check DistDeviate separately, since it overrides __repr__ and __eq__ -# d1 = galsim.DistDeviate(seed=a, function=galsim.LookupTable([1, 2, 3], [4, 5, 6])) -# d2 = galsim.DistDeviate(seed=b, function=galsim.LookupTable([1, 2, 3], [4, 5, 6])) -# assert repr(d1) == repr(d2) -# assert d1 != d2 - -# @timer -# def test_int64(): -# # cf. #1009 -# # Check that various possible integer types work as seeds. - -# rng1 = galsim.BaseDeviate(int(123)) -# # cf. https://www.numpy.org/devdocs/user/basics.types.html -# ivalues =[np.int8(123), # Note this one requires i < 128 -# np.int16(123), -# np.int32(123), -# np.int64(123), -# np.uint8(123), -# np.uint16(123), -# np.uint32(123), -# np.uint64(123), -# np.short(123), -# np.ushort(123), -# np.intc(123), -# np.uintc(123), -# np.intp(123), -# np.uintp(123), -# np.int_(123), -# np.longlong(123), -# np.ulonglong(123), -# np.array(123).astype(np.int64)] - -# for i in ivalues: -# rng2 = galsim.BaseDeviate(i) -# assert rng2 == rng1 - -# @timer -# def test_numpy_generator(): -# rng = galsim.BaseDeviate(1234) -# gen = galsim.BaseDeviate(1234).as_numpy_generator() - -# # The regular (and somewhat cumbersome) GalSim way: -# a1 = np.empty(10, dtype=float) -# galsim.UniformDeviate(rng).generate(a1) -# a1 *= 9. -# a1 += 1. - -# # The nicer numpy syntax -# a2 = gen.uniform(1.,10., size=10) -# print('a1 = ',a1) -# print('a2 = ',a2) -# np.testing.assert_array_equal(a1, a2) - -# # Can also use the np property as a quick shorthand -# a1 = rng.np.normal(0, 10, size=20) -# a2 = gen.normal(0, 10, size=20) -# print('a1 = ',a1) -# print('a2 = ',a2) -# np.testing.assert_array_equal(a1, a2) - -# # Check that normal gives statistically the right mean/var. -# # (Numpy's normal uses the next_uint64 function, so this is a non-trivial test of that -# # code, which I originally got wrong.) -# a3 = gen.normal(17, 23, size=1_000_000) -# print('mean = ',np.mean(a3)) -# print('std = ',np.std(a3)) -# assert np.isclose(np.mean(a3), 17, rtol=1.e-3) -# assert np.isclose(np.std(a3), 23, rtol=3.e-3) - -# if __name__ == "__main__": -# testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)] -# for testfn in testfns: -# testfn() +@timer +def test_int64(): + # cf. #1009 + # Check that various possible integer types work as seeds. + + rng1 = galsim.BaseDeviate(int(123)) + # cf. https://www.numpy.org/devdocs/user/basics.types.html + ivalues = [ + np.int8(123), # Note this one requires i < 128 + np.int16(123), + np.int32(123), + np.int64(123), + np.uint8(123), + np.uint16(123), + np.uint32(123), + np.uint64(123), + np.short(123), + np.ushort(123), + np.intc(123), + np.uintc(123), + np.intp(123), + np.uintp(123), + np.int_(123), + np.longlong(123), + np.ulonglong(123), + np.array(123).astype(np.int64)] + + for i in ivalues: + rng2 = galsim.BaseDeviate(i) + assert rng2 == rng1 From 0353e0e7b75dffc763409ee0fd9fc4aadc0810df Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 17:12:47 -0500 Subject: [PATCH 16/33] REF use Gaussian at high n for poisson --- jax_galsim/random.py | 11 +- tests/jax/galsim/test_random_jax.py | 298 ++++++++++++++-------------- 2 files changed, 157 insertions(+), 152 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 1593c530..98396f65 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -391,7 +391,16 @@ def __call__(self): @jax.jit def _generate_one(key, mean): _key, subkey = jrandom.split(key) - return _key, jrandom.poisson(subkey, mean, dtype=int) + val = jax.lax.cond( + mean < 2**17, + lambda subkey, mean: jrandom.poisson(subkey, mean, dtype=int).astype(float), + lambda subkey, mean: ( + jrandom.normal(subkey, dtype=float) * jnp.sqrt(mean) + mean + ), + subkey, + mean, + ) + return _key, val @_wraps(_galsim.PoissonDeviate.generate_from_expectation) def generate_from_expectation(self, array): diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index c4741f2e..566a52c8 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -878,161 +878,157 @@ def test_poisson(): # assert_raises(TypeError, galsim.PoissonDeviate, set()) -# @timer -# def test_poisson_highmean(): -# """Test Poisson random number generator with high (>2^30) mean (cf. Issue #881) +@timer +def test_poisson_highmean(): + # NOTE JAX has the same issues as boost with high mean poisson RVs but + # it happens at lower mean values + mean_vals = [ + 2**17 - 50, # Uses Poisson + 2**17, # Uses Poisson (highest value of mean that does) + 2**17 + 50, # Uses Gaussian + 2**18, # This is where problems happen if not using Gaussian + 5.e20, # Definitely would have problems with normal implementation. + ] + + nvals = 100000 + rtol_var = 1.e-2 + + for mean in mean_vals: + print('Test PoissonDeviate with mean = ', np.log(mean) / np.log(2)) + p = galsim.PoissonDeviate(testseed, mean=mean) + p2 = p.duplicate() + p3 = galsim.PoissonDeviate(p.serialize(), mean=mean) + testResult = (p(), p(), p()) + testResult2 = (p2(), p2(), p2()) + testResult3 = (p3(), p3(), p3()) + np.testing.assert_allclose( + testResult2, testResult, rtol=1.e-8, + err_msg='PoissonDeviate.duplicate not equivalent for mean=%s' % mean) + np.testing.assert_allclose( + testResult3, testResult, rtol=1.e-8, + err_msg='PoissonDeviate from serialize not equivalent for mean=%s' % mean) + + # Check that the mean and variance come out right + p = galsim.PoissonDeviate(testseed, mean=mean) + vals = [p() for i in range(nvals)] + mu = np.mean(vals) + var = np.var(vals) + print("rtol = ", 3 * np.sqrt(mean / nvals) / mean) + print('mean = ', mu, ' true mean = ', mean) + print('var = ', var, ' true var = ', mean) + np.testing.assert_allclose( + mu, mean, rtol=3 * np.sqrt(mean / nvals) / mean, + err_msg='Wrong mean from PoissonDeviate with mean=%s' % mean) + np.testing.assert_allclose( + var, mean, rtol=rtol_var, + err_msg='Wrong variance from PoissonDeviate with mean=%s' % mean) + + # Check discard + p2 = galsim.PoissonDeviate(testseed, mean=mean) + p2.discard(nvals, suppress_warnings=True) + v1, v2 = p(), p2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + assert v1 == v2 + + # NOTE: JAX cannot connect RNGs + # # Check that two connected poisson deviates work correctly together. + # p2 = galsim.PoissonDeviate(testseed, mean=mean) + # p.reset(p2) + # testResult2 = (p(), p(), p2()) + # np.testing.assert_array_equal( + # testResult2, testResult, + # err_msg='Wrong poisson random number sequence generated using two pds') + # p.seed(testseed) + # p2.clearCache() + # testResult2 = (p2(), p2(), p()) + # np.testing.assert_array_equal( + # testResult2, testResult, + # err_msg='Wrong poisson random number sequence generated using two pds after seed') + + # FIXME no noise methods for images in JAX yet + # Test filling an image + # p.seed(testseed) + # testimage = galsim.ImageD(np.zeros((3, 1))) + # testimage.addNoise(galsim.DeviateNoise(p)) + # np.testing.assert_array_equal( + # testimage.array.flatten(), testResult, + # err_msg='Wrong poisson random number sequence generated when applied to image.') + + # # The PoissonNoise version also subtracts off the mean value + # rng = galsim.BaseDeviate(testseed) + # pn = galsim.PoissonNoise(rng, sky_level=mean) + # testimage.fill(0) + # testimage.addNoise(pn) + # np.testing.assert_array_equal( + # testimage.array.flatten(), np.array(testResult) - mean, + # err_msg='Wrong poisson random number sequence generated using PoissonNoise') + + # # Check PoissonNoise variance: + # np.testing.assert_allclose( + # pn.getVariance(), mean, rtol=1.e-8, + # err_msg="PoissonNoise getVariance returns wrong variance") + # np.testing.assert_allclose( + # pn.sky_level, mean, rtol=1.e-8, + # err_msg="PoissonNoise sky_level returns wrong value") + + # # Check that the noise model really does produce this variance. + # big_im = galsim.Image(2048, 2048, dtype=float) + # big_im.addNoise(pn) + # var = np.var(big_im.array) + # print('variance = ', var) + # print('getVar = ', pn.getVariance()) + # np.testing.assert_allclose( + # var, pn.getVariance(), rtol=rtol_var, + # err_msg='Realized variance for PoissonNoise did not match getVariance()') -# It turns out that the boost poisson deviate class that we use maxes out at 2^31 and wraps -# around to -2^31. We have code to automatically switch over to using a Gaussian deviate -# instead if the mean > 2^30 (factor of 2 from the problem to be safe). Check that this -# works properly. -# """ -# mean_vals =[ 2**30 + 50, # Uses Gaussian -# 2**30 - 50, # Uses Poisson -# 2**30, # Uses Poisson (highest value of mean that does) -# 2**31, # This is where problems happen if not using Gaussian -# 5.e20, # Definitely would have problems with normal implementation. -# ] - -# if __name__ == '__main__': -# nvals = 10000000 -# rtol_var = 1.e-3 -# else: -# nvals = 100000 -# rtol_var = 1.e-2 - -# for mean in mean_vals: -# print('Test PoissonDeviate with mean = ',mean) -# p = galsim.PoissonDeviate(testseed, mean=mean) -# p2 = p.duplicate() -# p3 = galsim.PoissonDeviate(p.serialize(), mean=mean) -# testResult = (p(), p(), p()) -# testResult2 = (p2(), p2(), p2()) -# testResult3 = (p3(), p3(), p3()) -# np.testing.assert_allclose( -# testResult2, testResult, rtol=1.e-8, -# err_msg='PoissonDeviate.duplicate not equivalent for mean=%s'%mean) -# np.testing.assert_allclose( -# testResult3, testResult, rtol=1.e-8, -# err_msg='PoissonDeviate from serialize not equivalent for mean=%s'%mean) - -# # Check that the mean and variance come out right -# p = galsim.PoissonDeviate(testseed, mean=mean) -# vals = [p() for i in range(nvals)] -# mu = np.mean(vals) -# var = np.var(vals) -# print('mean = ',mu,' true mean = ',mean) -# print('var = ',var,' true var = ',mean) -# np.testing.assert_allclose(mu, mean, rtol=1.e-5, -# err_msg='Wrong mean from PoissonDeviate with mean=%s'%mean) -# np.testing.assert_allclose(var, mean, rtol=rtol_var, -# err_msg='Wrong variance from PoissonDeviate with mean=%s'%mean) - -# # Check discard -# p2 = galsim.PoissonDeviate(testseed, mean=mean) -# p2.discard(nvals, suppress_warnings=True) -# v1,v2 = p(),p2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# if mean > 2**30: -# # Poisson doesn't have a reliable rng count (unless the mean is vv small). -# # But above 2**30 we're back to Gaussian, which is reliable. -# assert v1 == v2 - -# # Check that two connected poisson deviates work correctly together. -# p2 = galsim.PoissonDeviate(testseed, mean=mean) -# p.reset(p2) -# testResult2 = (p(), p(), p2()) -# np.testing.assert_array_equal( -# testResult2, testResult, -# err_msg='Wrong poisson random number sequence generated using two pds') -# p.seed(testseed) -# p2.clearCache() -# testResult2 = (p2(), p2(), p()) -# np.testing.assert_array_equal( -# testResult2, testResult, -# err_msg='Wrong poisson random number sequence generated using two pds after seed') - -# # Test filling an image -# p.seed(testseed) -# testimage = galsim.ImageD(np.zeros((3, 1))) -# testimage.addNoise(galsim.DeviateNoise(p)) -# np.testing.assert_array_equal( -# testimage.array.flatten(), testResult, -# err_msg='Wrong poisson random number sequence generated when applied to image.') - -# # The PoissonNoise version also subtracts off the mean value -# rng = galsim.BaseDeviate(testseed) -# pn = galsim.PoissonNoise(rng, sky_level=mean) -# testimage.fill(0) -# testimage.addNoise(pn) -# np.testing.assert_array_equal( -# testimage.array.flatten(), np.array(testResult)-mean, -# err_msg='Wrong poisson random number sequence generated using PoissonNoise') - -# # Check PoissonNoise variance: -# np.testing.assert_allclose( -# pn.getVariance(), mean, rtol=1.e-8, -# err_msg="PoissonNoise getVariance returns wrong variance") -# np.testing.assert_allclose( -# pn.sky_level, mean, rtol=1.e-8, -# err_msg="PoissonNoise sky_level returns wrong value") - -# # Check that the noise model really does produce this variance. -# big_im = galsim.Image(2048,2048,dtype=float) -# big_im.addNoise(pn) -# var = np.var(big_im.array) -# print('variance = ',var) -# print('getVar = ',pn.getVariance()) -# np.testing.assert_allclose( -# var, pn.getVariance(), rtol=rtol_var, -# err_msg='Realized variance for PoissonNoise did not match getVariance()') +@timer +def test_poisson_zeromean(): + """Make sure Poisson Deviate behaves sensibly when mean=0. + """ + p = galsim.PoissonDeviate(testseed, mean=0) + p2 = p.duplicate() + p3 = galsim.PoissonDeviate(p.serialize(), mean=0) + do_pickle(p) -# @timer -# def test_poisson_zeromean(): -# """Make sure Poisson Deviate behaves sensibly when mean=0. -# """ -# p = galsim.PoissonDeviate(testseed, mean=0) -# p2 = p.duplicate() -# p3 = galsim.PoissonDeviate(p.serialize(), mean=0) -# do_pickle(p) - -# # Test direct draws -# testResult = (p(), p(), p()) -# testResult2 = (p2(), p2(), p2()) -# testResult3 = (p3(), p3(), p3()) -# np.testing.assert_array_equal(testResult, 0) -# np.testing.assert_array_equal(testResult2, 0) -# np.testing.assert_array_equal(testResult3, 0) + # Test direct draws + testResult = (p(), p(), p()) + testResult2 = (p2(), p2(), p2()) + testResult3 = (p3(), p3(), p3()) + np.testing.assert_array_equal(testResult, 0) + np.testing.assert_array_equal(testResult2, 0) + np.testing.assert_array_equal(testResult3, 0) + + # Test generate + test_array = np.empty(3, dtype=int) + p.generate(test_array) + np.testing.assert_array_equal(test_array, 0) + p2.generate(test_array) + np.testing.assert_array_equal(test_array, 0) + p3.generate(test_array) + np.testing.assert_array_equal(test_array, 0) + + # Test generate_from_expectation + test_array = np.array([0, 0, 0]) + np.testing.assert_allclose(test_array, 0) + test_array = np.array([1, 0, 4]) + assert test_array[0] != 0 + assert test_array[1] == 0 + assert test_array[2] != 0 + + # NOTE JAX does not raise an error for this + # # Error raised if mean<0 + # with assert_raises(ValueError): + # p = galsim.PoissonDeviate(testseed, mean=-0.1) + # with assert_raises(ValueError): + # p = galsim.PoissonDeviate(testseed, mean=-10) + # test_array = np.array([-1,1,4]) + # with assert_raises(ValueError): + # p.generate_from_expectation(test_array) + # test_array = np.array([1,-1,-4]) + # with assert_raises(ValueError): + # p.generate_from_expectation(test_array) -# # Test generate -# test_array = np.empty(3, dtype=int) -# p.generate(test_array) -# np.testing.assert_array_equal(test_array, 0) -# p2.generate(test_array) -# np.testing.assert_array_equal(test_array, 0) -# p3.generate(test_array) -# np.testing.assert_array_equal(test_array, 0) - -# # Test generate_from_expectation -# test_array = np.array([0,0,0]) -# np.testing.assert_allclose(test_array, 0) -# test_array = np.array([1,0,4]) -# assert test_array[0] != 0 -# assert test_array[1] == 0 -# assert test_array[2] != 0 - -# # Error raised if mean<0 -# with assert_raises(ValueError): -# p = galsim.PoissonDeviate(testseed, mean=-0.1) -# with assert_raises(ValueError): -# p = galsim.PoissonDeviate(testseed, mean=-10) -# test_array = np.array([-1,1,4]) -# with assert_raises(ValueError): -# p.generate_from_expectation(test_array) -# test_array = np.array([1,-1,-4]) -# with assert_raises(ValueError): -# p.generate_from_expectation(test_array) # @timer # def test_weibull(): From f1881e9194feccb2d81c3e53871545779dbf5cc1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 21:34:22 -0500 Subject: [PATCH 17/33] ENH add Wdist --- jax_galsim/__init__.py | 1 + jax_galsim/random.py | 106 +++++----- tests/jax/galsim/test_random_jax.py | 297 ++++++++++++++-------------- 3 files changed, 212 insertions(+), 192 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index f7b85f48..856a0cc0 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -15,6 +15,7 @@ PoissonDeviate, Chi2Deviate, GammaDeviate, + WeibullDeviate, ) # Basic building blocks diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 98396f65..3eadc91d 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -428,60 +428,74 @@ def __str__(self): return "galsim.PoissonDeviate(mean=%r)" % (ensure_hashable(self.mean),) -# class WeibullDeviate(BaseDeviate): -# """Pseudo-random Weibull-distributed deviate for shape parameter ``a`` and scale parameter ``b``. - -# The Weibull distribution is related to a number of other probability distributions; in -# particular, it interpolates between the exponential distribution (a=1) and the Rayleigh -# distribution (a=2). -# See http://en.wikipedia.org/wiki/Weibull_distribution (a=k and b=lambda in the notation adopted -# in the Wikipedia article) for more details. The Weibull distribution is real valued and -# produces deviates >= 0. +@_wraps( + _galsim.WeibullDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class WeibullDeviate(BaseDeviate): + def __init__(self, seed=None, a=1.0, b=1.0): + super().__init__(seed=seed) + self._params["a"] = a + self._params["b"] = b -# Successive calls to ``w()`` generate pseudo-random values distributed according to a Weibull -# distribution with the specified shape and scale parameters ``a`` and ``b``:: + @property + def a(self): + """The shape parameter, a.""" + return self._params["a"] -# >>> w = galsim.WeibullDeviate(31415926, a=1.3, b=4) -# >>> w() -# 1.1038481241018219 -# >>> w() -# 2.957052966368049 + @property + def b(self): + """The scale parameter, b.""" + return self._params["b"] -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# a: Shape parameter of the distribution. [default: 1; Must be > 0] -# b: Scale parameter of the distribution. [default: 1; Must be > 0] -# """ -# def __init__(self, seed=None, a=1., b=1.): -# self._rng_type = _galsim.WeibullDeviateImpl -# self._rng_args = (float(a), float(b)) -# self.reset(seed) + @_wraps( + _galsim.WeibullDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = self.__class__._generate(self._key, array, self.a, self.b) + return array -# @property -# def a(self): -# """The shape parameter, a. -# """ -# return self._rng_args[0] + @jax.jit + def _generate(key, array, a, b): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + carry, res = jax.lax.scan( + WeibullDeviate._generate_one, + (key, a, b), + None, + length=array.ravel().shape[0], + ) + key, _, _ = carry + return key, res.reshape(array.shape) -# @property -# def b(self): -# """The scale parameter, b. -# """ -# return self._rng_args[1] + def __call__(self): + carry, val = self.__class__._generate_one((self._key, self.a, self.b), None) + self._key, _, _ = carry + return val -# def __call__(self): -# """Draw a new random number from the distribution. + @jax.jit + def _generate_one(args, x): + key, a, b = args + _key, subkey = jrandom.split(key) + # argument order is scale, concentration + return (_key, a, b), jrandom.weibull_min(subkey, b, a, dtype=float) -# Returns a Weibull-distributed deviate with the given shape parameters a and b. -# """ -# return self._rng.generate1() + def __repr__(self): + return "galsim.WeibullDeviate(seed=%r, a=%r, b=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.a), + ensure_hashable(self.b), + ) -# def __repr__(self): -# return 'galsim.WeibullDeviate(seed=%r, a=%r, b=%r)'%(self._seed_repr(), self.a, self.b) -# def __str__(self): -# return 'galsim.WeibullDeviate(a=%r, b=%r)'%(self.a, self.b) + def __str__(self): + return "galsim.WeibullDeviate(a=%r, b=%r)" % ( + ensure_hashable(self.a), + ensure_hashable(self.b), + ) @_wraps( diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 566a52c8..6b01e651 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -1,3 +1,4 @@ +import math import numpy as np import os import galsim @@ -44,7 +45,7 @@ wA = 4.0 wB = 9.0 # Tabulated results for Weibull -wResult = (5.3648053017485591, 6.3093033550873878, 7.7982696798921074) +wResult = (3.2106530102, 6.4256210259, 5.8255498741) # k & theta to use for Gamma tests gammaK = 1.5 @@ -1030,163 +1031,167 @@ def test_poisson_zeromean(): # p.generate_from_expectation(test_array) -# @timer -# def test_weibull(): -# """Test Weibull random number generator -# """ -# w = galsim.WeibullDeviate(testseed, a=wA, b=wB) -# w2 = w.duplicate() -# w3 = galsim.WeibullDeviate(w.serialize(), a=wA, b=wB) -# testResult = (w(), w(), w()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(wResult), precision, -# err_msg='Wrong Weibull random number sequence generated') -# testResult = (w2(), w2(), w2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(wResult), precision, -# err_msg='Wrong Weibull random number sequence generated with duplicate') -# testResult = (w3(), w3(), w3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(wResult), precision, -# err_msg='Wrong Weibull random number sequence generated from serialize') +@timer +def test_weibull(): + """Test Weibull random number generator + """ + w = galsim.WeibullDeviate(testseed, a=wA, b=wB) + w2 = w.duplicate() + w3 = galsim.WeibullDeviate(w.serialize(), a=wA, b=wB) + testResult = (w(), w(), w()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(wResult), precision, + err_msg='Wrong Weibull random number sequence generated') + testResult = (w2(), w2(), w2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(wResult), precision, + err_msg='Wrong Weibull random number sequence generated with duplicate') + testResult = (w3(), w3(), w3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(wResult), precision, + err_msg='Wrong Weibull random number sequence generated from serialize') -# # Check that the mean and variance come out right -# w = galsim.WeibullDeviate(testseed, a=wA, b=wB) -# vals = [w() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# gammaFactor1 = math.gamma(1.+1./wA) -# gammaFactor2 = math.gamma(1.+2./wA) -# mu = wB * gammaFactor1 -# v = wB**2 * gammaFactor2 - mu**2 -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from WeibullDeviate') -# np.testing.assert_almost_equal(var, v, 1, -# err_msg='Wrong variance from WeibullDeviate') + # Check that the mean and variance come out right + w = galsim.WeibullDeviate(testseed, a=wA, b=wB) + vals = [w() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + gammaFactor1 = math.gamma(1. + 1. / wA) + gammaFactor2 = math.gamma(1. + 2. / wA) + mu = wB * gammaFactor1 + v = wB**2 * gammaFactor2 - mu**2 + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from WeibullDeviate') + np.testing.assert_almost_equal( + var, v, 1, + err_msg='Wrong variance from WeibullDeviate') -# # Check discard -# w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) -# w2.discard(nvals) -# v1,v2 = w(),w2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# assert v1 == v2 -# assert w.has_reliable_discard -# assert not w.generates_in_pairs + # Check discard + w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) + w2.discard(nvals) + v1, v2 = w(), w2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + assert v1 == v2 + assert w.has_reliable_discard + assert not w.generates_in_pairs -# # Check seed, reset -# w.seed(testseed) -# testResult2 = (w(), w(), w()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated after seed') + # Check seed, reset + w.seed(testseed) + testResult2 = (w(), w(), w()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong weibull random number sequence generated after seed') -# w.reset(testseed) -# testResult2 = (w(), w(), w()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated after reset(seed)') + w.reset(testseed) + testResult2 = (w(), w(), w()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong weibull random number sequence generated after reset(seed)') -# rng = galsim.BaseDeviate(testseed) -# w.reset(rng) -# testResult2 = (w(), w(), w()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated after reset(rng)') + rng = galsim.BaseDeviate(testseed) + w.reset(rng) + testResult2 = (w(), w(), w()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong weibull random number sequence generated after reset(rng)') -# ud = galsim.UniformDeviate(testseed) -# w.reset(ud) -# testResult = (w(), w(), w()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated after reset(ud)') + ud = galsim.UniformDeviate(testseed) + w.reset(ud) + testResult = (w(), w(), w()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong weibull random number sequence generated after reset(ud)') -# # Check that two connected weibull deviates work correctly together. -# w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) -# w.reset(w2) -# testResult2 = (w(), w2(), w()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated using two wds') -# w.seed(testseed) -# testResult2 = (w2(), w(), w2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong weibull random number sequence generated using two wds after seed') + # NOTE JAX does not allow connected deviates + # # Check that two connected weibull deviates work correctly together. + # w2 = galsim.WeibullDeviate(testseed, a=wA, b=wB) + # w.reset(w2) + # testResult2 = (w(), w2(), w()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong weibull random number sequence generated using two wds') + # w.seed(testseed) + # testResult2 = (w2(), w(), w2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong weibull random number sequence generated using two wds after seed') -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. -# w.seed() -# testResult2 = (w(), w(), w()) -# assert testResult2 != testResult -# w.reset() -# testResult3 = (w(), w(), w()) -# assert testResult3 != testResult -# assert testResult3 != testResult2 -# w.reset() -# testResult4 = (w(), w(), w()) -# assert testResult4 != testResult -# assert testResult4 != testResult2 -# assert testResult4 != testResult3 -# w = galsim.WeibullDeviate(a=wA, b=wB) -# testResult5 = (w(), w(), w()) -# assert testResult5 != testResult -# assert testResult5 != testResult2 -# assert testResult5 != testResult3 -# assert testResult5 != testResult4 + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. + w.seed() + testResult2 = (w(), w(), w()) + assert testResult2 != testResult + w.reset() + testResult3 = (w(), w(), w()) + assert testResult3 != testResult + assert testResult3 != testResult2 + w.reset() + testResult4 = (w(), w(), w()) + assert testResult4 != testResult + assert testResult4 != testResult2 + assert testResult4 != testResult3 + w = galsim.WeibullDeviate(a=wA, b=wB) + testResult5 = (w(), w(), w()) + assert testResult5 != testResult + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 -# # Test generate -# w.seed(testseed) -# test_array = np.empty(3) -# w.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(wResult), precision, -# err_msg='Wrong weibull random number sequence from generate.') + # Test generate + w.seed(testseed) + test_array = np.empty(3) + test_array = w.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(wResult), precision, + err_msg='Wrong weibull random number sequence from generate.') -# # Test generate with a float32 array -# w.seed(testseed) -# test_array = np.empty(3, dtype=np.float32) -# w.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(wResult), precisionF, -# err_msg='Wrong weibull random number sequence from generate.') + # Test generate with a float32 array + w.seed(testseed) + test_array = np.empty(3, dtype=np.float32) + test_array = w.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(wResult), precisionF, + err_msg='Wrong weibull random number sequence from generate.') -# # Check that generated values are independent of number of threads. -# w1 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) -# w2 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) -# v1 = np.empty(555) -# v2 = np.empty(555) -# with single_threaded(): -# w1.generate(v1) -# with single_threaded(num_threads=10): -# w2.generate(v2) -# np.testing.assert_array_equal(v1, v2) -# with single_threaded(): -# w1.add_generate(v1) -# with single_threaded(num_threads=10): -# w2.add_generate(v2) -# np.testing.assert_array_equal(v1, v2) + # Check that generated values are independent of number of threads. + w1 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) + w2 = galsim.WeibullDeviate(testseed, a=3.1, b=7.3) + v1 = np.empty(555) + v2 = np.empty(555) + with single_threaded(): + v1 = w1.generate(v1) + with single_threaded(num_threads=10): + v2 = w2.generate(v2) + np.testing.assert_array_equal(v1, v2) + with single_threaded(): + v1 = w1.add_generate(v1) + with single_threaded(num_threads=10): + v2 = w2.add_generate(v2) + np.testing.assert_array_equal(v1, v2) -# # Check picklability -# do_pickle(w, lambda x: (x.serialize(), x.a, x.b)) -# do_pickle(w, lambda x: (x(), x(), x(), x())) -# do_pickle(w) -# assert 'WeibullDeviate' in repr(w) -# assert 'WeibullDeviate' in str(w) -# assert isinstance(eval(repr(w)), galsim.WeibullDeviate) -# assert isinstance(eval(str(w)), galsim.WeibullDeviate) - -# # Check that we can construct a WeibullDeviate from None, and that it depends on dev/random. -# w1 = galsim.WeibullDeviate(None) -# w2 = galsim.WeibullDeviate(None) -# assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" -# # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.WeibullDeviate, dict()) -# assert_raises(TypeError, galsim.WeibullDeviate, list()) -# assert_raises(TypeError, galsim.WeibullDeviate, set()) + # Check picklability + do_pickle(w, lambda x: (x.serialize(), x.a, x.b)) + do_pickle(w, lambda x: (x(), x(), x(), x())) + do_pickle(w) + assert 'WeibullDeviate' in repr(w) + assert 'WeibullDeviate' in str(w) + assert isinstance(eval(repr(w)), galsim.WeibullDeviate) + assert isinstance(eval(str(w)), galsim.WeibullDeviate) + + # Check that we can construct a WeibullDeviate from None, and that it depends on dev/random. + w1 = galsim.WeibullDeviate(None) + w2 = galsim.WeibullDeviate(None) + assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" + # NOTE JAX does not do type checking + # # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.WeibullDeviate, dict()) + # assert_raises(TypeError, galsim.WeibullDeviate, list()) + # assert_raises(TypeError, galsim.WeibullDeviate, set()) @timer From ecc4b48048670279068a7580f9ca0e1b7b0456bc Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Oct 2023 23:00:49 -0500 Subject: [PATCH 18/33] ENH add bionomial --- jax_galsim/__init__.py | 1 + jax_galsim/random.py | 106 ++++++---- tests/jax/galsim/test_random_jax.py | 302 ++++++++++++++-------------- 3 files changed, 220 insertions(+), 189 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 856a0cc0..6682197a 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -16,6 +16,7 @@ Chi2Deviate, GammaDeviate, WeibullDeviate, + BinomialDeviate, ) # Basic building blocks diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 3eadc91d..e5bdbc38 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -1,4 +1,5 @@ import secrets +from functools import partial import galsim as _galsim import jax @@ -294,57 +295,76 @@ def __str__(self): ) -# class BinomialDeviate(BaseDeviate): -# """Pseudo-random Binomial deviate for ``N`` trials each of probability ``p``. - -# ``N`` is number of 'coin flips,' ``p`` is probability of 'heads,' and each call returns an -# integer value where 0 <= value <= N gives the number of heads. See -# http://en.wikipedia.org/wiki/Binomial_distribution for more information. +@_wraps( + _galsim.BinomialDeviate, + lax_description=LAX_FUNCTIONAL_RNG, +) +@register_pytree_node_class +class BinomialDeviate(BaseDeviate): + def __init__(self, seed=None, N=1, p=0.5): + super().__init__(seed=seed) + self._params["N"] = N + self._params["p"] = p -# Successive calls to ``b()`` generate pseudo-random integer values distributed according to a -# binomial distribution with the provided ``N``, ``p``:: + @property + def n(self): + """The shape parameter, a.""" + return self._params["N"] -# >>> b = galsim.BinomialDeviate(31415926, N=10, p=0.3) -# >>> b() -# 2 -# >>> b() -# 3 + @property + def p(self): + """The scale parameter, b.""" + return self._params["p"] -# Parameters: -# seed: Something that can seed a `BaseDeviate`: an integer seed or another -# `BaseDeviate`. Using 0 means to generate a seed from the system. -# [default: None] -# N: The number of 'coin flips' per trial. [default: 1; Must be > 0] -# p: The probability of success per coin flip. [default: 0.5; Must be > 0] -# """ -# def __init__(self, seed=None, N=1, p=0.5): -# self._rng_type = _galsim.BinomialDeviateImpl -# self._rng_args = (int(N), float(p)) -# self.reset(seed) + @_wraps( + _galsim.BinomialDeviate.generate, + lax_description=( + "JAX arrays cannot be changed in-place, so the JAX version of " + "this method returns a new array." + ), + ) + def generate(self, array): + self._key, array = BinomialDeviate._generate(self._key, array, self.n, self.p) + return array -# @property -# def n(self): -# """The number of 'coin flips'. -# """ -# return self._rng_args[0] + @partial(jax.jit, static_argnums=(2,)) + def _generate(key, array, n, p): + # we do it this way so that the RNG appears to have a fixed state that is advanced per value drawn + carry, res = jax.lax.scan( + BinomialDeviate._generate_one, + (key, jnp.broadcast_to(p, (n,))), + None, + length=array.ravel().shape[0], + ) + key = carry[0] + return key, res.reshape(array.shape) -# @property -# def p(self): -# """The probability of success per 'coin flip'. -# """ -# return self._rng_args[1] + def __call__(self): + carry, val = BinomialDeviate._generate_one( + (self._key, jnp.broadcast_to(self.p, (self.n,))), None + ) + self._key = carry[0] + return val -# def __call__(self): -# """Draw a new random number from the distribution. + @jax.jit + def _generate_one(args, x): + key, p = args + _key, subkey = jrandom.split(key) + # argument order is scale, concentration + return (_key, p), jnp.sum(jrandom.bernoulli(subkey, p)) -# Returns a Binomial deviate with the given n and p. -# """ -# return self._rng.generate1() + def __repr__(self): + return "galsim.BinomialDeviate(seed=%r, N=%r, p=%r)" % ( + ensure_hashable(jrandom.key_data(self._key)), + ensure_hashable(self.n), + ensure_hashable(self.p), + ) -# def __repr__(self): -# return 'galsim.BinomialDeviate(seed=%r, N=%r, p=%r)'%(self._seed_repr(), self.n, self.p) -# def __str__(self): -# return 'galsim.BinomialDeviate(N=%r, p=%r)'%(self.n, self.p) + def __str__(self): + return "galsim.BinomialDeviate(N=%r, p=%r)" % ( + ensure_hashable(self.n), + ensure_hashable(self.p), + ) @_wraps( diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 6b01e651..68e0b409 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -34,7 +34,7 @@ bN = 10 bp = 0.7 # the right answer for the first three binomial deviates produced from testseed -bResult = (9, 8, 7) +bResult = (5, 8, 7) # mean to use for Poisson tests pMean = 7 @@ -517,163 +517,173 @@ def test_gaussian(): # assert_raises(ValueError, galsim.GaussianDeviate, testseed, mean=1, sigma=-1) -# @timer -# def test_binomial(): -# """Test binomial random number generator -# """ -# b = galsim.BinomialDeviate(testseed, N=bN, p=bp) -# b2 = b.duplicate() -# b3 = galsim.BinomialDeviate(b.serialize(), N=bN, p=bp) -# testResult = (b(), b(), b()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(bResult), precision, -# err_msg='Wrong binomial random number sequence generated') -# testResult = (b2(), b2(), b2()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(bResult), precision, -# err_msg='Wrong binomial random number sequence generated with duplicate') -# testResult = (b3(), b3(), b3()) -# np.testing.assert_array_almost_equal( -# np.array(testResult), np.array(bResult), precision, -# err_msg='Wrong binomial random number sequence generated from serialize') +@timer +def test_binomial(): + """Test binomial random number generator + """ -# # Check that the mean and variance come out right -# b = galsim.BinomialDeviate(testseed, N=bN, p=bp) -# vals = [b() for i in range(nvals)] -# mean = np.mean(vals) -# var = np.var(vals) -# mu = bN*bp -# v = bN*bp*(1.-bp) -# print('mean = ',mean,' true mean = ',mu) -# print('var = ',var,' true var = ',v) -# np.testing.assert_almost_equal(mean, mu, 1, -# err_msg='Wrong mean from BinomialDeviate') -# np.testing.assert_almost_equal(var, v, 1, -# err_msg='Wrong variance from BinomialDeviate') + b = galsim.BinomialDeviate(testseed, N=bN, p=bp) + b2 = b.duplicate() + b3 = galsim.BinomialDeviate(b.serialize(), N=bN, p=bp) + testResult = (b(), b(), b()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(bResult), precision, + err_msg='Wrong binomial random number sequence generated') + testResult = (b2(), b2(), b2()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(bResult), precision, + err_msg='Wrong binomial random number sequence generated with duplicate') + testResult = (b3(), b3(), b3()) + np.testing.assert_array_almost_equal( + np.array(testResult), np.array(bResult), precision, + err_msg='Wrong binomial random number sequence generated from serialize') -# # Check discard -# b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) -# b2.discard(nvals) -# v1,v2 = b(),b2() -# print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) -# assert v1 == v2 -# assert b.has_reliable_discard -# assert not b.generates_in_pairs + # Check that the mean and variance come out right + b = galsim.BinomialDeviate(testseed, N=bN, p=bp) + vals = [b() for i in range(nvals)] + mean = np.mean(vals) + var = np.var(vals) + mu = bN * bp + v = bN * bp * (1. - bp) + print('mean = ', mean, ' true mean = ', mu) + print('var = ', var, ' true var = ', v) + np.testing.assert_almost_equal( + mean, mu, 1, + err_msg='Wrong mean from BinomialDeviate') + np.testing.assert_almost_equal( + var, v, 1, + err_msg='Wrong variance from BinomialDeviate') -# # Check seed, reset -# b.seed(testseed) -# testResult2 = (b(), b(), b()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated after seed') + # Check discard + b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) + b2.discard(nvals) + v1, v2 = b(), b2() + print('after %d vals, next one is %s, %s' % (nvals, v1, v2)) + assert v1 == v2 + assert b.has_reliable_discard + assert not b.generates_in_pairs -# b.reset(testseed) -# testResult2 = (b(), b(), b()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated after reset(seed)') + # Check seed, reset + b.seed(testseed) + testResult2 = (b(), b(), b()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated after seed') -# rng = galsim.BaseDeviate(testseed) -# b.reset(rng) -# testResult2 = (b(), b(), b()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated after reset(rng)') + b.reset(testseed) + testResult2 = (b(), b(), b()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated after reset(seed)') -# ud = galsim.UniformDeviate(testseed) -# b.reset(ud) -# testResult = (b(), b(), b()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated after reset(ud)') + rng = galsim.BaseDeviate(testseed) + b.reset(rng) + testResult2 = (b(), b(), b()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated after reset(rng)') -# # Check that two connected binomial deviates work correctly together. -# b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) -# b.reset(b2) -# testResult2 = (b(), b2(), b()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated using two bds') -# b.seed(testseed) -# testResult2 = (b2(), b(), b2()) -# np.testing.assert_array_equal( -# np.array(testResult), np.array(testResult2), -# err_msg='Wrong binomial random number sequence generated using two bds after seed') + ud = galsim.UniformDeviate(testseed) + b.reset(ud) + testResult = (b(), b(), b()) + np.testing.assert_array_equal( + np.array(testResult), np.array(testResult2), + err_msg='Wrong binomial random number sequence generated after reset(ud)') -# # Check that seeding with the time works (although we cannot check the output). -# # We're mostly just checking that this doesn't raise an exception. -# # The output could be anything. However, in this case, there are few enough options -# # for the output that occasionally two of these match. So we don't do the normal -# # testResult2 != testResult, etc. -# b.seed() -# testResult2 = (b(), b(), b()) -# #assert testResult2 != testResult -# b.reset() -# testResult3 = (b(), b(), b()) -# #assert testResult3 != testResult -# #assert testResult3 != testResult2 -# b.reset() -# testResult4 = (b(), b(), b()) -# #assert testResult4 != testResult -# #assert testResult4 != testResult2 -# #assert testResult4 != testResult3 -# b = galsim.BinomialDeviate(N=bN, p=bp) -# testResult5 = (b(), b(), b()) -# #assert testResult5 != testResult -# #assert testResult5 != testResult2 -# #assert testResult5 != testResult3 -# #assert testResult5 != testResult4 + # NOTE JAX does not support connected RNGs + # # Check that two connected binomial deviates work correctly together. + # b2 = galsim.BinomialDeviate(testseed, N=bN, p=bp) + # b.reset(b2) + # testResult2 = (b(), b2(), b()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong binomial random number sequence generated using two bds') + # b.seed(testseed) + # testResult2 = (b2(), b(), b2()) + # np.testing.assert_array_equal( + # np.array(testResult), np.array(testResult2), + # err_msg='Wrong binomial random number sequence generated using two bds after seed') -# # Test generate -# b.seed(testseed) -# test_array = np.empty(3) -# b.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(bResult), precision, -# err_msg='Wrong binomial random number sequence from generate.') + # Check that seeding with the time works (although we cannot check the output). + # We're mostly just checking that this doesn't raise an exception. + # The output could be anything. However, in this case, there are few enough options + # for the output that occasionally two of these match. So we don't do the normal + # testResult2 != testResult, etc. + b.seed() + testResult2 = (b(), b(), b()) + b.reset() + testResult3 = (b(), b(), b()) + b.reset() + testResult4 = (b(), b(), b()) + b = galsim.BinomialDeviate(testseed, N=bN, p=bp) + testResult5 = (b(), b(), b()) + assert ( + (testResult2 != testResult) + or (testResult3 != testResult) + or (testResult4 != testResult) + or (testResult5 != testResult) + ) + try: + assert testResult3 != testResult2 + assert testResult4 != testResult2 + assert testResult4 != testResult3 + assert testResult5 != testResult2 + assert testResult5 != testResult3 + assert testResult5 != testResult4 + except AssertionError: + print("one of the poisson results was equal but this can happen occasionally") -# # Test generate with an int array -# b.seed(testseed) -# test_array = np.empty(3, dtype=int) -# b.generate(test_array) -# np.testing.assert_array_almost_equal( -# test_array, np.array(bResult), precisionI, -# err_msg='Wrong binomial random number sequence from generate.') + # Test generate + b.seed(testseed) + test_array = np.empty(3) + test_array = b.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(bResult), precision, + err_msg='Wrong binomial random number sequence from generate.') -# # Check that generated values are independent of number of threads. -# b1 = galsim.BinomialDeviate(testseed, N=17, p=0.7) -# b2 = galsim.BinomialDeviate(testseed, N=17, p=0.7) -# v1 = np.empty(555) -# v2 = np.empty(555) -# with single_threaded(): -# b1.generate(v1) -# with single_threaded(num_threads=10): -# b2.generate(v2) -# np.testing.assert_array_equal(v1, v2) -# with single_threaded(): -# b1.add_generate(v1) -# with single_threaded(num_threads=10): -# b2.add_generate(v2) -# np.testing.assert_array_equal(v1, v2) + # Test generate with an int array + b.seed(testseed) + test_array = np.empty(3, dtype=int) + test_array = b.generate(test_array) + np.testing.assert_array_almost_equal( + test_array, np.array(bResult), precisionI, + err_msg='Wrong binomial random number sequence from generate.') -# # Check picklability -# do_pickle(b, lambda x: (x.serialize(), x.n, x.p)) -# do_pickle(b, lambda x: (x(), x(), x(), x())) -# do_pickle(b) -# assert 'BinomialDeviate' in repr(b) -# assert 'BinomialDeviate' in str(b) -# assert isinstance(eval(repr(b)), galsim.BinomialDeviate) -# assert isinstance(eval(str(b)), galsim.BinomialDeviate) - -# # Check that we can construct a BinomialDeviate from None, and that it depends on dev/random. -# b1 = galsim.BinomialDeviate(None) -# b2 = galsim.BinomialDeviate(None) -# assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" -# # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, -# # or None. -# assert_raises(TypeError, galsim.BinomialDeviate, dict()) -# assert_raises(TypeError, galsim.BinomialDeviate, list()) -# assert_raises(TypeError, galsim.BinomialDeviate, set()) + # Check that generated values are independent of number of threads. + b1 = galsim.BinomialDeviate(testseed, N=17, p=0.7) + b2 = galsim.BinomialDeviate(testseed, N=17, p=0.7) + v1 = np.empty(555) + v2 = np.empty(555) + with single_threaded(): + v1 = b1.generate(v1) + with single_threaded(num_threads=10): + v2 = b2.generate(v2) + np.testing.assert_array_equal(v1, v2) + with single_threaded(): + v1 = b1.add_generate(v1) + with single_threaded(num_threads=10): + v2 = b2.add_generate(v2) + np.testing.assert_array_equal(v1, v2) + + # Check picklability + do_pickle(b, lambda x: (x.serialize(), x.n, x.p)) + do_pickle(b, lambda x: (x(), x(), x(), x())) + do_pickle(b) + assert 'BinomialDeviate' in repr(b) + assert 'BinomialDeviate' in str(b) + assert isinstance(eval(repr(b)), galsim.BinomialDeviate) + assert isinstance(eval(str(b)), galsim.BinomialDeviate) + + # Check that we can construct a BinomialDeviate from None, and that it depends on dev/random. + b1 = galsim.BinomialDeviate(None) + b2 = galsim.BinomialDeviate(None) + assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" + # NOTE JAX does not do type checking + # # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, + # # or None. + # assert_raises(TypeError, galsim.BinomialDeviate, dict()) + # assert_raises(TypeError, galsim.BinomialDeviate, list()) + # assert_raises(TypeError, galsim.BinomialDeviate, set()) @timer From f4a3f41129842ae152314de63f27b5697380d1c3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Oct 2023 07:50:13 -0500 Subject: [PATCH 19/33] TST add tests for api --- jax_galsim/random.py | 37 +++++++++++----------- tests/jax/test_api.py | 71 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index e5bdbc38..c5d1307a 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -28,16 +28,19 @@ @register_pytree_node_class class BaseDeviate: # always the case for JAX - has_reliable_discard = True - generates_in_pairs = False - def __init__(self, seed=None): self.reset(seed=seed) self._params = {} @property - def key(self): - return self._key + @_wraps(_galsim.BaseDeviate.has_reliable_discard) + def has_reliable_discard(self): + return True + + @property + @_wraps(_galsim.BaseDeviate.generates_in_pairs) + def generates_in_pairs(self): + return False @_wraps( _galsim.BaseDeviate.seed, @@ -139,7 +142,7 @@ def generate(self, array): return array @_wraps( - _galsim.BaseDeviate.generate, + _galsim.BaseDeviate.add_generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " "this method returns a new array." @@ -239,13 +242,13 @@ def __init__(self, seed=None, mean=0.0, sigma=1.0): self._params["sigma"] = sigma @property + @_wraps(_galsim.GaussianDeviate.mean) def mean(self): - """The mean of the Gaussian distribution.""" return self._params["mean"] @property + @_wraps(_galsim.GaussianDeviate.sigma) def sigma(self): - """The sigma of the Gaussian distribution.""" return self._params["sigma"] @_wraps( @@ -307,13 +310,13 @@ def __init__(self, seed=None, N=1, p=0.5): self._params["p"] = p @property + @_wraps(_galsim.BinomialDeviate.n) def n(self): - """The shape parameter, a.""" return self._params["N"] @property + @_wraps(_galsim.BinomialDeviate.p) def p(self): - """The scale parameter, b.""" return self._params["p"] @_wraps( @@ -378,8 +381,8 @@ def __init__(self, seed=None, mean=1.0): self._params["mean"] = mean @property + @_wraps(_galsim.PoissonDeviate.mean) def mean(self): - """The mean of the Gaussian distribution.""" return self._params["mean"] @_wraps( @@ -460,13 +463,13 @@ def __init__(self, seed=None, a=1.0, b=1.0): self._params["b"] = b @property + @_wraps(_galsim.WeibullDeviate.a) def a(self): - """The shape parameter, a.""" return self._params["a"] @property + @_wraps(_galsim.WeibullDeviate.b) def b(self): - """The scale parameter, b.""" return self._params["b"] @_wraps( @@ -530,13 +533,13 @@ def __init__(self, seed=None, k=1.0, theta=1.0): self._params["theta"] = theta @property + @_wraps(_galsim.GammaDeviate.k) def k(self): - """The shape parameter, k.""" return self._params["k"] @property + @_wraps(_galsim.GammaDeviate.theta) def theta(self): - """The scale parameter, theta.""" return self._params["theta"] @_wraps( @@ -595,8 +598,8 @@ def __init__(self, seed=None, n=1.0): self._params["n"] = n @property + @_wraps(_galsim.Chi2Deviate.n) def n(self): - """The number of degrees of freedom.""" return self._params["n"] @_wraps( @@ -890,6 +893,6 @@ def permute(rng, *args): rng = BaseDeviate(rng) arrs = [] for arr in args: - arrs.append(jrandom.permutation(rng.key, arr)) + arrs.append(jrandom.permutation(rng._key, arr)) rng.discard(1) return arrs diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 0aedebb0..76f42a76 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -194,6 +194,50 @@ def _reg_fun(x): # check vmap grad np.testing.assert_allclose(_gradfun_vmap(x), [_gradfun(_x) for _x in x]) + elif kind == "vmap-jit-grad-random": + assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj + + for key in obj._params: + if key in ["N", "n"]: + continue + + if key == "p": + cen = 0.6 + x = jnp.linspace(0.1, 0.9, 10) + else: + cen = 2.0 + x = jnp.arange(10) + 2.0 + + if key == "k": + rtol = 2e-2 + else: + rtol = 1e-7 + + def _reg_fun(p): + kwargs = {key: p} + return cls(seed=10, **kwargs)().astype(float) + + _fun = jax.jit(_reg_fun) + _gradfun = jax.jit(jax.grad(_fun)) + _fun_vmap = jax.jit(jax.vmap(_fun)) + _gradfun_vmap = jax.jit(jax.vmap(_gradfun)) + + # we can jit the object + np.testing.assert_allclose(_fun(cen), _reg_fun(cen)) + + # check derivs + eps = 1e-6 + grad = _gradfun(cen) + finite_diff = (_reg_fun(cen + eps) - _reg_fun(cen - eps)) / (2 * eps) + np.testing.assert_allclose(grad, finite_diff, rtol=rtol) + + # check vmap + np.testing.assert_allclose(_fun_vmap(x), [_reg_fun(_x) for _x in x]) + + # check vmap grad + np.testing.assert_allclose( + _gradfun_vmap(x), [_gradfun(_x) for _x in x], rtol=rtol + ) elif kind == "docs-methods": # always has gsparams if isinstance(obj, jax_galsim.GSObject): @@ -676,3 +720,30 @@ def _reg_sfun(g1): # check vmap grad np.testing.assert_allclose(_sgradfun_vmap(x), [_sgradfun(_x) for _x in x]) + + +def test_api_random(): + classes = [] + for item in sorted(dir(jax_galsim.random)): + cls = getattr(jax_galsim.random, item) + if inspect.isclass(cls) and issubclass(cls, jax_galsim.random.BaseDeviate): + classes.append(getattr(jax_galsim.random, item)) + + tested = set() + for cls in classes: + obj = cls(seed=42) + print(obj) + tested.add(cls.__name__) + _run_object_checks(obj, cls, "docs-methods") + _run_object_checks(obj, cls, "pickle-eval-repr-img") + _run_object_checks(obj, cls, "vmap-jit-grad-random") + + assert { + "UniformDeviate", + "GaussianDeviate", + "BinomialDeviate", + "PoissonDeviate", + "WeibullDeviate", + "GammaDeviate", + "Chi2Deviate", + } <= tested From e7bb102a3b9946233b52dea976126768f74061b2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Oct 2023 08:13:37 -0500 Subject: [PATCH 20/33] TST fix tests --- jax_galsim/convolve.py | 10 +++++-- jax_galsim/transform.py | 16 +++++++----- tests/jax/test_api.py | 58 ++++++++++++++++++++++------------------- 3 files changed, 48 insertions(+), 36 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 5e3abd38..0a349039 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -367,14 +367,20 @@ def __init__(self, obj, gsparams=None, propagate_gsparams=True): # Save the original object as an attribute, so it can be inspected later if necessary. self._gsparams = GSParams.check(gsparams, obj.gsparams) - self._min_acc_kvalue = obj.flux * self.gsparams.kvalue_accuracy - self._inv_min_acc_kvalue = 1.0 / self._min_acc_kvalue self._propagate_gsparams = propagate_gsparams if self._propagate_gsparams: self._orig_obj = obj.withGSParams(self._gsparams) else: self._orig_obj = obj + @property + def _min_acc_kvalue(self): + return self._orig_obj.flux * self.gsparams.kvalue_accuracy + + @property + def _inv_min_acc_kvalue(self): + return 1.0 / self._min_acc_kvalue + @property def orig_obj(self): """The original object that is being deconvolved.""" diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 34bd65c3..a0144018 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -54,12 +54,7 @@ def __init__( obj = obj.withGSParams(self._gsparams) self._params = { - "jac": jax.lax.cond( - jac is not None, - lambda jac: jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)), - lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), - jac, - ), + "jac": jac, "offset": PositionD(offset), "flux_ratio": flux_ratio, } @@ -77,7 +72,14 @@ def __init__( @property def _jac(self): - return jnp.asarray(self._params["jac"], dtype=float).reshape(2, 2) + jac = self._params["jac"] + jac = jax.lax.cond( + jac is not None, + lambda jac: jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)), + lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), + jac, + ) + return jnp.asarray(jac, dtype=float).reshape(2, 2) @property def original(self): diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 76f42a76..35719624 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -66,13 +66,14 @@ def _attempt_init(cls, kwargs): else: raise e - try: - return cls(jax_galsim.Gaussian(**kwargs)) - except Exception as e: - if any(estr in repr(e) for estr in OK_ERRORS): - pass - else: - raise e + if cls in [jax_galsim.Convolution, jax_galsim.Deconvolution]: + try: + return cls(jax_galsim.Gaussian(**kwargs)) + except Exception as e: + if any(estr in repr(e) for estr in OK_ERRORS): + pass + else: + raise e return None @@ -133,17 +134,30 @@ def _run_object_checks(obj, cls, kind): # JAX tracing should be an identity assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj - # we can jit the object - np.testing.assert_allclose(_xfun(0.3, obj), obj.xValue(x=0.3, y=-0.3)) - np.testing.assert_allclose(_kfun(0.3, obj), obj.kValue(kx=0.3, ky=-0.3).real) - - # check derivs eps = 1e-6 - grad = _xgradfun(0.3, obj) - finite_diff = ( - obj.xValue(x=0.3 + eps, y=-0.3) - obj.xValue(x=0.3 - eps, y=-0.3) - ) / (2 * eps) - np.testing.assert_allclose(grad, finite_diff) + x = jnp.linspace(-1, 1, 10) + + if cls not in [jax_galsim.Convolution, jax_galsim.Deconvolution]: + # we can jit the object + np.testing.assert_allclose(_xfun(0.3, obj), obj.xValue(x=0.3, y=-0.3)) + + # check derivs + grad = _xgradfun(0.3, obj) + finite_diff = ( + obj.xValue(x=0.3 + eps, y=-0.3) - obj.xValue(x=0.3 - eps, y=-0.3) + ) / (2 * eps) + np.testing.assert_allclose(grad, finite_diff) + + # check vmap + np.testing.assert_allclose( + _xfun_vmap(x, obj), [obj.xValue(x=_x, y=-0.3) for _x in x] + ) + # check vmap grad + np.testing.assert_allclose( + _xgradfun_vmap(x, obj), [_xgradfun(_x, obj) for _x in x] + ) + + np.testing.assert_allclose(_kfun(0.3, obj), obj.kValue(kx=0.3, ky=-0.3).real) grad = _kgradfun(0.3, obj) finite_diff = ( @@ -152,19 +166,9 @@ def _run_object_checks(obj, cls, kind): ) / (2 * eps) np.testing.assert_allclose(grad, finite_diff) - # check vmap - x = jnp.linspace(-1, 1, 10) - np.testing.assert_allclose( - _xfun_vmap(x, obj), [obj.xValue(x=_x, y=-0.3) for _x in x] - ) np.testing.assert_allclose( _kfun_vmap(x, obj), [obj.kValue(kx=_x, ky=-0.3).real for _x in x] ) - - # check vmap grad - np.testing.assert_allclose( - _xgradfun_vmap(x, obj), [_xgradfun(_x, obj) for _x in x] - ) np.testing.assert_allclose( _kgradfun_vmap(x, obj), [_kgradfun(_x, obj) for _x in x] ) From e529ef7a4cd9061db501e55a1e4a01dd8ad90145 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 07:50:53 -0500 Subject: [PATCH 21/33] latest changes --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 66092bdf..ca90d938 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 66092bdf7215983bab4d2d953a700eb8a0ddcbe4 +Subproject commit ca90d938a3b16450b84452720068e0b558842bbb From 62d146e36e422bef9824925b5cc04bc378d5a8a9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 14:43:46 -0500 Subject: [PATCH 22/33] TST fix more tests --- jax_galsim/core/utils.py | 55 ++++++- jax_galsim/moffat.py | 78 ++------- jax_galsim/transform.py | 67 +++++--- jax_galsim/wcs.py | 96 ++++++++++- pyproject.toml | 12 ++ pytest.ini | 10 -- tests/GalSim | 2 +- tests/conftest.py | 31 +++- tests/galsim_tests_config.yaml | 2 + tests/jax/galsim/test_image_jax.py | 80 ++++----- tests/jax/galsim/test_shear_jax.py | 2 +- tests/jax/galsim/test_wcs_jax.py | 238 ++++++++++++++------------- tests/jax/test_moffat_comp_galsim.py | 40 +++++ 13 files changed, 445 insertions(+), 268 deletions(-) delete mode 100644 pytest.ini create mode 100644 tests/jax/test_moffat_comp_galsim.py diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f2e82df4..872912ec 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -1,10 +1,22 @@ +from functools import partial + import jax +def convert_to_float(x): + if isinstance(x, jax.Array): + if x.shape == (): + return x.item() + else: + return x[0].astype(float).item() + else: + return float(x) + + def cast_scalar_to_float(x): """Cast the input to a float. Works on python floats and jax arrays.""" - if isinstance(x, float): - return float(x) + if isinstance(x, jax.Array): + return x.astype(float) elif hasattr(x, "astype"): return x.astype(float) else: @@ -51,3 +63,42 @@ def ensure_hashable(v): return v else: return v + + +@partial(jax.jit, static_argnames=("niter",)) +def bisect_for_root(func, low, high, niter=75): + def _func(i, args): + func, low, flow, high, fhigh = args + mid = (low + high) / 2.0 + fmid = func(mid) + return jax.lax.cond( + fmid * fhigh < 0, + lambda func, low, flow, mid, fmid, high, fhigh: ( + func, + mid, + fmid, + high, + fhigh, + ), + lambda func, low, flow, mid, fmid, high, fhigh: ( + func, + low, + flow, + mid, + fmid, + ), + func, + low, + flow, + mid, + fmid, + high, + fhigh, + ) + + low = 0.0 + high = 1e5 + flow = func(low) + fhigh = func(high) + args = (func, low, flow, high, fhigh) + return jax.lax.fori_loop(0, niter, _func, args)[-2] diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 09da2579..f9ec4839 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,7 +1,6 @@ import galsim as _galsim import jax import jax.numpy as jnp -import jax.scipy as jsc import tensorflow_probability as tfp from jax._src.numpy.util import _wraps from jax.tree_util import Partial as partial @@ -10,8 +9,9 @@ from jax_galsim.core.bessel import j0 from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import bisect_for_root, ensure_hashable from jax_galsim.gsobject import GSObject +from jax_galsim.position import PositionD @jax.jit @@ -269,75 +269,21 @@ def __str__(self): s += ")" return s - @property - def _maxk_untrunc(self): - """untruncated Moffat maxK - - The 2D Fourier Transform of f(r)=C (1+(r/rd)^2)^(-beta) leads - C rd^2 = Flux (beta-1)/pi (no truc) - and - f(k) = C rd^2 int_0^infty (1+x^2)^(-beta) J_0(krd x) x dx - = 2 F (k rd /2)^(\beta-1) K[beta-1, k rd]/Gamma[beta-1] - with k->infty asymptotic behavior - f(k)/F \approx sqrt(pi)/Gamma(beta-1) e^(-k') (k'/2)^(beta -3/2) with k' = k rd - So we solve f(maxk)/F = thr (aka maxk_threshold in gsparams.py) - leading to the iterative search of - let alpha = -log(thr Gamma(beta-1)/sqrt(pi)) - k = (\beta -3/2)log(k/2) + alpha - starting with k = alpha - - note : in the code "alternative code" is related to issue #1208 in GalSim github - """ - - def body(i, val): - kcur, alpha = val - knew = (self.beta - 0.5) * jnp.log(kcur) + alpha - # knew = (self.beta -1.5)* jnp.log(kcur/2) + alpha # alternative code - return knew, alpha - - # alpha = -jnp.log(self.gsparams.maxk_threshold - # * jnp.exp(jsc.special.gammaln(self._beta-1))/jnp.sqrt(jnp.pi) ) # alternative code - - alpha = -jnp.log( - self.gsparams.maxk_threshold - * jnp.power(2.0, self.beta - 0.5) - * jnp.exp(jsc.special.gammaln(self.beta - 1)) - / (2 * jnp.sqrt(jnp.pi)) - ) - - val_init = ( - alpha, - alpha, - ) - val = jax.lax.fori_loop(0, 5, body, val_init) - maxk, alpha = val - return maxk / self._r0 - @property def _prefactor(self): return 2.0 * (self.beta - 1.0) / (self._fluxFactor) - @property - def _maxk_trunc(self): - """truncated Moffat maxK""" - # a for gaussian profile... this is f(k_max)/Flux = maxk_threshold - maxk_val = self.gsparams.maxk_threshold - dk = self.gsparams.table_spacing * jnp.sqrt( - jnp.sqrt(self.gsparams.kvalue_accuracy / 10.0) + @jax.jit + def _maxk_func(self, k): + return ( + jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux) + - self.gsparams.maxk_threshold ) - # 50 is a max (GalSim) but it may be lowered if necessary - ki = jnp.arange(0.0, 50.0, dk) - quad = ClenshawCurtisQuad.init(150) - g = partial(_xMoffatIntegrant, beta=self.beta, rmax=self._maxRrD, quad=quad) - fki_1 = jax.jit(jax.vmap(g))(ki) - fki = fki_1 * self._prefactor - cond = jnp.abs(fki) > maxk_val - maxk = ki[cond][-1] - return maxk / self._r0 @property + @jax.jit def _maxk(self): - return jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc) + return bisect_for_root(partial(self._maxk_func), 0.0, 1e5, niter=75) @property def _stepk_lowbeta(self): @@ -353,12 +299,10 @@ def _stepk_highbeta(self): jnp.power(self.gsparams.folding_threshold, 0.5 / (1.0 - self.beta)) * self._r0 ) - if R > self._maxR: - R = self._maxR + R = jnp.minimum(R, self._maxR) # at least R should be 5 HLR R5hlr = self.gsparams.stepk_minimum_hlr * self.half_light_radius - if R < R5hlr: - R = R5hlr + R = jnp.maximum(R, R5hlr) return jnp.pi / R @property diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 4dc2eac8..a0144018 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class @@ -47,34 +48,38 @@ def __init__( gsparams=None, propagate_gsparams=True, ): - self._offset = PositionD(offset) - self._flux_ratio = flux_ratio self._gsparams = GSParams.check(gsparams, obj.gsparams) self._propagate_gsparams = propagate_gsparams if self._propagate_gsparams: obj = obj.withGSParams(self._gsparams) self._params = { - "obj": obj, "jac": jac, - "offset": self._offset, - "flux_ratio": self._flux_ratio, + "offset": PositionD(offset), + "flux_ratio": flux_ratio, } if isinstance(obj, Transformation): # Combine the two affine transformations into one. dx, dy = self._fwd(obj.offset.x, obj.offset.y) - self._offset.x += dx - self._offset.y += dy + self._params["offset"].x += dx + self._params["offset"].y += dy self._params["jac"] = self._jac.dot(obj.jac) - self._flux_ratio *= obj._flux_ratio + self._params["flux_ratio"] *= obj._params["flux_ratio"] self._original = obj.original else: self._original = obj @property def _jac(self): - return jnp.asarray(self._params["jac"], dtype=float).reshape(2, 2) + jac = self._params["jac"] + jac = jax.lax.cond( + jac is not None, + lambda jac: jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)), + lambda jax: jnp.array([1.0, 0.0, 0.0, 1.0]), + jac, + ) + return jnp.asarray(jac, dtype=float).reshape(2, 2) @property def original(self): @@ -89,17 +94,21 @@ def jac(self): @property def offset(self): """The offset of the transformation.""" - return self._offset + return self._params["offset"] @property def flux_ratio(self): """The flux ratio of the transformation.""" - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux(self): return self._flux_scaling * self._original.flux + @property + def _offset(self): + return self._params["offset"] + def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -110,13 +119,14 @@ def withGSParams(self, gsparams=None, **kwargs): """ if gsparams == self.gsparams: return self - from copy import copy - ret = copy(self) - ret._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs) + chld, aux = self.tree_flatten() + aux["gsparams"] = GSParams.check(gsparams, self.gsparams, **kwargs) if self._propagate_gsparams: - ret._original = self._original.withGSParams(ret._gsparams) - return ret + new_obj = chld[0].withGSParams(aux["gsparams"]) + chld = (new_obj,) + chld[1:] + + return self.tree_unflatten(aux, chld) def __eq__(self, other): return self is other or ( @@ -149,7 +159,7 @@ def __repr__(self): "propagate_gsparams=%r)" ) % ( self.original, - ensure_hashable(self._jac), + ensure_hashable(self._jac.ravel()), self.offset, ensure_hashable(self.flux_ratio), self.gsparams, @@ -221,11 +231,11 @@ def _invjac(self): # than flux_ratio, which is really an amplitude scaling. @property def _amp_scaling(self): - return self._flux_ratio + return self._params["flux_ratio"] @property def _flux_scaling(self): - return jnp.abs(self._det) * self._flux_ratio + return jnp.abs(self._det) * self._params["flux_ratio"] def _fwd(self, x, y): res = jnp.dot(self._jac, jnp.array([x, y])) @@ -240,8 +250,8 @@ def _inv(self, x, y): return res[0], res[1] def _kfactor(self, kx, ky): - kx *= -1j * self._offset.x - ky *= -1j * self._offset.y + kx *= -1j * self.offset.x + ky *= -1j * self.offset.y kx += ky return self._flux_scaling * jnp.exp(kx) @@ -269,7 +279,7 @@ def _stepk(self): # stepk = Pi/R # R <- R + |shift| # stepk <- Pi/(Pi/stepk + |shift|) - dr = jnp.hypot(self._offset.x, self._offset.y) + dr = jnp.hypot(self.offset.x, self.offset.y) stepk = jnp.pi / (jnp.pi / stepk + dr) return stepk @@ -283,7 +293,7 @@ def _is_axisymmetric(self): self._original.is_axisymmetric and self._jac[0, 0] == self._jac[1, 1] and self._jac[0, 1] == -self._jac[1, 0] - and self._offset == PositionD(0.0, 0.0) + and self.offset == PositionD(0.0, 0.0) ) @property @@ -314,7 +324,7 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self._offset + pos -= self.offset inv_pos = PositionD(self._inv(pos.x, pos.y)) return self._original._xValue(inv_pos) * self._amp_scaling @@ -360,10 +370,10 @@ def _drawKImage(self, image, jac=None): return image def tree_flatten(self): - """This function flattens the GSObject into a list of children + """This function flattens the Transform into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.params,) + children = (self._original, self._params) # Define auxiliary static data that doesn’t need to be traced aux_data = { "gsparams": self.gsparams, @@ -371,6 +381,11 @@ def tree_flatten(self): } return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + return cls(children[0], **(children[1]), **aux_data) + def _Transform( obj, diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index cb7320fa..3cebeeb0 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -3,7 +3,7 @@ from jax._src.numpy.util import _wraps from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import ensure_hashable +from jax_galsim.core.utils import convert_to_float, ensure_hashable from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI from jax_galsim.shear import Shear @@ -18,6 +18,8 @@ def toWorld(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToWorld(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToWorld(*args, **kwargs) else: return self.posToWorld(*args, **kwargs) elif len(args) == 2: @@ -52,11 +54,19 @@ def profileToWorld( image_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToWorld) + def shearToWorld(self, image_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToWorld(image_shear) + @_wraps(_galsim.BaseWCS.toImage) def toImage(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): return self.profileToImage(*args, **kwargs) + elif isinstance(args[0], Shear): + return self.shearToImage(*args, **kwargs) else: return self.posToImage(*args, **kwargs) elif len(args) == 2: @@ -94,6 +104,12 @@ def profileToImage( world_profile, flux_ratio, PositionD(offset) ) + @_wraps(_galsim.BaseWCS.shearToImage) + def shearToImage(self, world_shear, image_pos=None, world_pos=None, color=None): + if color is None: + color = self._color + return self.local(image_pos, world_pos, color=color)._shearToImage(world_shear) + @_wraps(_galsim.BaseWCS.local) def local(self, image_pos=None, world_pos=None, color=None): if color is None: @@ -622,6 +638,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # These are trivial for PixelScale. + return image_shear + + def _shearToImage(self, world_shear): + return world_shear + def _pixelArea(self): return self._scale**2 @@ -728,6 +751,13 @@ def _profileToImage(self, world_profile, flux_ratio, offset): * flux_ratio ) + def _shearToWorld(self, image_shear): + # This isn't worth customizing. Just use the jacobian. + return self._toJacobian()._shearToWorld(image_shear) + + def _shearToImage(self, world_shear): + return self._toJacobian()._shearToImage(world_shear) + def _pixelArea(self): return self._scale**2 @@ -752,6 +782,13 @@ def _toJacobian(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self._scale, self._shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("ShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return ShearWCS(self._scale, self._shear) @@ -846,6 +883,24 @@ def _profileToImage(self, world_profile, flux_ratio, offset): offset=offset, ) + def _shearToWorld(self, image_shear): + # Code from https://github.com/rmjarvis/DESWL/blob/y3a1-v23/psf/run_piff.py#L691 + e1 = image_shear.e1 + e2 = image_shear.e2 + + M = jnp.array([[1 + e1, e2], [e2, 1 - e1]]) + J = self.getMatrix() + M = J.dot(M).dot(J.T) + + e1 = (M[0, 0] - M[1, 1]) / (M[0, 0] + M[1, 1]) + e2 = (2.0 * M[0, 1]) / (M[0, 0] + M[1, 1]) + + return Shear(e1=e1, e2=e2) + + def _shearToImage(self, world_shear): + # Same as above but inverse J matrix. + return self._inverse()._shearToWorld(world_shear) + def _pixelArea(self): return abs(self._det) @@ -1096,6 +1151,17 @@ def world_origin(self): def _newOrigin(self, origin, world_origin): return OffsetShearWCS(self.scale, self.shear, origin, world_origin) + def _writeHeader(self, header, bounds): + header["GS_WCS"] = ("OffsetShearWCS", "GalSim WCS name") + header["GS_SCALE"] = (self.scale, "GalSim image scale") + header["GS_G1"] = (self.shear.g1, "GalSim image shear g1") + header["GS_G2"] = (self.shear.g2, "GalSim image shear g2") + header["GS_X0"] = (self.origin.x, "GalSim image origin x coordinate") + header["GS_Y0"] = (self.origin.y, "GalSim image origin y coordinate") + header["GS_U0"] = (self.world_origin.x, "GalSim world origin u coordinate") + header["GS_V0"] = (self.world_origin.y, "GalSim world origin v coordinate") + return self.affine()._writeLinearWCS(header, bounds) + def copy(self): return OffsetShearWCS(self.scale, self.shear, self.origin, self.world_origin) @@ -1173,14 +1239,26 @@ def _writeHeader(self, header, bounds): def _writeLinearWCS(self, header, bounds): header["CTYPE1"] = ("LINEAR", "name of the world coordinate axis") header["CTYPE2"] = ("LINEAR", "name of the world coordinate axis") - header["CRVAL1"] = (self.u0, "world coordinate at reference pixel = u0") - header["CRVAL2"] = (self.v0, "world coordinate at reference pixel = v0") - header["CRPIX1"] = (self.x0, "image coordinate of reference pixel = x0") - header["CRPIX2"] = (self.y0, "image coordinate of reference pixel = y0") - header["CD1_1"] = (self.dudx, "CD1_1 = dudx") - header["CD1_2"] = (self.dudy, "CD1_2 = dudy") - header["CD2_1"] = (self.dvdx, "CD2_1 = dvdx") - header["CD2_2"] = (self.dvdy, "CD2_2 = dvdy") + header["CRVAL1"] = ( + convert_to_float(self.u0), + "world coordinate at reference pixel = u0", + ) + header["CRVAL2"] = ( + convert_to_float(self.v0), + "world coordinate at reference pixel = v0", + ) + header["CRPIX1"] = ( + convert_to_float(self.x0), + "image coordinate of reference pixel = x0", + ) + header["CRPIX2"] = ( + convert_to_float(self.y0), + "image coordinate of reference pixel = y0", + ) + header["CD1_1"] = (convert_to_float(self.dudx), "CD1_1 = dudx") + header["CD1_2"] = (convert_to_float(self.dudy), "CD1_2 = dudy") + header["CD2_1"] = (convert_to_float(self.dvdx), "CD2_1 = dvdx") + header["CD2_2"] = (convert_to_float(self.dvdy), "CD2_2 = dvdy") return header @staticmethod diff --git a/pyproject.toml b/pyproject.toml index c6fe79ab..3f456117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,3 +21,15 @@ skip = [ "tests/Galsim/", "tests/Coord/", ] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q" +testpaths = [ + "tests/GalSim/tests/", + "tests/jax", + "tests/Coord/tests/", +] +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index d85bef3a..00000000 --- a/pytest.ini +++ /dev/null @@ -1,10 +0,0 @@ -# pytest.ini -[pytest] -minversion = 6.0 -addopts = -ra -q -testpaths = - tests/GalSim/tests/ - tests/jax - tests/Coord/tests/ -filterwarnings = - ignore::DeprecationWarning diff --git a/tests/GalSim b/tests/GalSim index ca90d938..81509041 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit ca90d938a3b16450b84452720068e0b558842bbb +Subproject commit 815090419d343d0e840bbc53e79c7bc4469ec79d diff --git a/tests/conftest.py b/tests/conftest.py index 8095105e..dfdc96cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,6 +68,19 @@ def _infile(val, fname): return False +def _convert_galsim_to_jax_galsim(obj): + import galsim as _galsim # noqa: F401 + from numpy import array # noqa: F401 + + import jax_galsim as galsim # noqa: F401 + + if isinstance(obj, _galsim.GSObject): + ret_obj = eval(repr(obj)) + return ret_obj + else: + return obj + + def pytest_pycollect_makemodule(module_path, path, parent): """This hook is tasked with overriding the galsim import at the top of each test file. Replaces it by jax-galsim. @@ -91,7 +104,10 @@ def pytest_pycollect_makemodule(module_path, path, parent): if ( callable(v) and hasattr(v, "__globals__") - and inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + and ( + inspect.getsourcefile(v).endswith("galsim_test_helpers.py") + or inspect.getsourcefile(v).endswith("galsim/utilities.py") + ) and _infile("def " + k, inspect.getsourcefile(v)) and "galsim" in v.__globals__ ): @@ -111,6 +127,19 @@ def pytest_pycollect_makemodule(module_path, path, parent): v.__globals__["coord"] = __import__("jax_galsim") v.__globals__["galsim"] = __import__("jax_galsim") + # the galsim WCS tests have some items that are galsim objects that need conversions + # to jax_galsim objects + if module.name.endswith("tests/GalSim/tests/test_wcs.py"): + for k, v in module.obj.__dict__.items(): + if isinstance(v, __import__("galsim").GSObject): + module.obj.__dict__[k] = _convert_galsim_to_jax_galsim(v) + elif isinstance(v, list): + module.obj.__dict__[k] = [ + _convert_galsim_to_jax_galsim(obj) for obj in v + ] + + module.obj._convert_galsim_to_jax_galsim = _convert_galsim_to_jax_galsim + return module diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index dd761691..25ab8225 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -58,3 +58,5 @@ allowed_failures: - "TypeError not raised by __mul__" - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'TanWCS'" + - "'Image' object has no attribute 'FindAdaptiveMom'" + - " module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" diff --git a/tests/jax/galsim/test_image_jax.py b/tests/jax/galsim/test_image_jax.py index 945d2da2..c14e87ca 100644 --- a/tests/jax/galsim/test_image_jax.py +++ b/tests/jax/galsim/test_image_jax.py @@ -509,22 +509,22 @@ def test_Image_basic(): # ------------------------- # We will not be doing pickles # Check picklability - # do_pickle(im1) - # do_pickle(im1_view) - # do_pickle(im2) - # do_pickle(im2_view) - # do_pickle(im2_cview) - # do_pickle(im3_view) - # do_pickle(im4_view) + # check_pickle(im1) + # check_pickle(im1_view) + # check_pickle(im2) + # check_pickle(im2_view) + # check_pickle(im2_cview) + # check_pickle(im3_view) + # check_pickle(im4_view) # JAX specific modification # ------------------------- # We will not be doing pickles # Also check picklability of Bounds, Position here. - # do_pickle(galsim.PositionI(2,3)) - # do_pickle(galsim.PositionD(2.2,3.3)) - # do_pickle(galsim.BoundsI(2,3,7,8)) - # do_pickle(galsim.BoundsD(2.1, 4.3, 6.5, 9.1)) + # check_pickle(galsim.PositionI(2,3)) + # check_pickle(galsim.PositionD(2.2,3.3)) + # check_pickle(galsim.BoundsI(2,3,7,8)) + # check_pickle(galsim.BoundsD(2.1, 4.3, 6.5, 9.1)) @timer @@ -632,10 +632,10 @@ def test_undefined_image(): # JAX specific modification # ------------------------- # We will not be doing pickles - # do_pickle(im1.bounds) - # do_pickle(im1) - # do_pickle(im1.view()) - # do_pickle(im1.view(make_const=True)) + # check_pickle(im1.bounds) + # check_pickle(im1) + # check_pickle(im1.view()) + # check_pickle(im1.view(make_const=True)) @timer @@ -2908,7 +2908,7 @@ def test_Image_subImage(): # JAX specific modification # ------------------------- # We won't do any pickling - # do_pickle(image) + # check_pickle(image) assert_raises(TypeError, image.subImage, bounds=None) assert_raises(TypeError, image.subImage, bounds=galsim.BoundsD(0, 4, 0, 4)) @@ -3035,9 +3035,9 @@ def test_Image_resize(): im3_full.array, 23, err_msg="im3_full changed" ) - do_pickle(im1) - do_pickle(im2) - do_pickle(im3) + check_pickle(im1) + check_pickle(im2) + check_pickle(im3) assert_raises(TypeError, im1.resize, bounds=None) assert_raises(TypeError, im1.resize, bounds=galsim.BoundsD(0, 5, 0, 5)) @@ -3083,7 +3083,7 @@ def test_Image_resize(): # assert_raises(galsim.GalSimImmutableError, image.setZero) # assert_raises(galsim.GalSimImmutableError, image.invertSelf) -# do_pickle(image) +# check_pickle(image) @timer @@ -3164,7 +3164,7 @@ def test_Image_constructor(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(test_im) + # check_pickle(test_im) # Check that some invalid sets of construction args raise the appropriate errors # Invalid args @@ -3246,7 +3246,7 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(im) + # check_pickle(im) # Test view with no arguments imv = im.view() @@ -3262,8 +3262,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(im) - # do_pickle(imv) + # check_pickle(im) + # check_pickle(imv) # Test view with new origin imv = im.view(origin=(0, 0)) @@ -3288,8 +3288,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new center imv = im.view(center=(0, 0)) @@ -3316,8 +3316,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new scale imv = im.view(scale=0.17) @@ -3342,8 +3342,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Test view with new wcs imv = im.view(wcs=galsim.JacobianWCS(0.0, 0.23, -0.23, 0.0)) @@ -3365,8 +3365,8 @@ def test_Image_view(): # JAX specific modification # ------------------------- # No picklibng for JAX images - # do_pickle(imv) - # do_pickle(imv2) + # check_pickle(imv) + # check_pickle(imv2) # Go back to original value for that pixel and make sure all are still equal to 17 im.setValue(11, 19, 17) @@ -3427,7 +3427,7 @@ def test_ne(): galsim.ImageD(array1, wcs=galsim.PixelScale(0.2)), galsim.ImageD(array1, xmin=2), ] - all_obj_diff(objs) + check_all_diff(objs) @timer @@ -3653,13 +3653,13 @@ def test_complex_image(): # ------------------------- # No picklibng for JAX images # Check picklability - # do_pickle(im1) - # do_pickle(im1_view) - # do_pickle(im1_cview) - # do_pickle(im2) - # do_pickle(im2_view) - # do_pickle(im3_view) - # do_pickle(im4_view) + # check_pickle(im1) + # check_pickle(im1_view) + # check_pickle(im1_cview) + # check_pickle(im2) + # check_pickle(im2_view) + # check_pickle(im3_view) + # check_pickle(im4_view) @timer diff --git a/tests/jax/galsim/test_shear_jax.py b/tests/jax/galsim/test_shear_jax.py index 55a4fde1..890baea7 100644 --- a/tests/jax/galsim/test_shear_jax.py +++ b/tests/jax/galsim/test_shear_jax.py @@ -242,7 +242,7 @@ def test_shear_initialization(): # JAX specific modification # ------------------------- # We don't allow jax objects to be pickled. - # do_pickle(s) + # check_pickle(s) # finally check some examples of invalid initializations for Shear assert_raises(TypeError, galsim.Shear, 0.3) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index 49f6d8d4..f2b7e791 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -5,9 +5,17 @@ import time import warnings -import galsim import numpy as np -from galsim_test_helpers import * +from galsim_test_helpers import ( + Profile, + assert_raises, + assert_warns, + check_pickle, + gsobject_compare, + timer, +) + +import jax_galsim as galsim # These positions will be used a few times below, so define them here. # One of the tests requires that the last pair are integers, so don't change that. @@ -476,7 +484,9 @@ def do_wcs_image(wcs, name, approx=False): # Use the "blank" image as our test image. It's not blank in the sense of having all # zeros. Rather, there are basically random values that we can use to test that # the shifted values are correct. And it is a conveniently small-ish, non-square image. - dir = "fits_files" + dir = os.path.join( + os.path.dirname(__file__), "..", "..", "GalSim", "tests", "fits_files" + ) file_name = "blankimg.fits" im = galsim.fits.read(file_name, dir=dir) np.testing.assert_equal(im.origin.x, 1, "initial origin is not 1,1 as expected") @@ -801,7 +811,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): do_wcs_pos(wcs, ufunc, vfunc, name) # Check picklability - do_pickle(wcs) + check_pickle(wcs) # Test the transformation of a GSObject # These only work for local WCS projections! @@ -910,7 +920,7 @@ def do_jac_decomp(wcs, name): M = scale * S.dot(R).dot(F) J = wcs.getMatrix() - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( M, J, 8, "Decomposition was inconsistent with jacobian for " + name ) @@ -964,7 +974,6 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): wcs4 = wcs.local(wcs.origin, color=color) assert wcs != wcs4, name + " is not != wcs.local()" assert wcs4 != wcs, name + " is not != wcs.local() (reverse)" - world_origin = wcs.toWorld(wcs.origin, color=color) if wcs.isUniform(): if wcs.world_origin == galsim.PositionD(0, 0): wcs2 = wcs.local(wcs.origin, color=color).withOrigin(wcs.origin) @@ -1014,7 +1023,7 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) # The GSObject transformation tests are only valid for a local WCS. # But it should work for wcs.local() @@ -1028,8 +1037,8 @@ def do_nonlocal_wcs(wcs, ufunc, vfunc, name, test_pickle=True, color=None): full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) for x0, y0, u0, v0 in zip(far_x_list, far_y_list, far_u_list, far_v_list): - local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 - local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 + local_ufunc = lambda x, y: ufunc(x + x0, y + y0) - u0 # noqa: E731 + local_vfunc = lambda x, y: vfunc(x + x0, y + y0) - v0 # noqa: E731 image_pos = galsim.PositionD(x0, y0) world_pos = galsim.PositionD(u0, v0) do_wcs_pos( @@ -1204,8 +1213,6 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): "shiftOrigin(new_origin) returned wrong world position", ) - world_origin = wcs.toWorld(wcs.origin) - full_im1 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), wcs=wcs) full_im2 = galsim.Image(galsim.BoundsI(-1023, 1024, -1023, 1024), scale=1.0) @@ -1223,7 +1230,7 @@ def do_celestial_wcs(wcs, name, test_pickle=True, approx=False): # Check picklability if test_pickle: - do_pickle(wcs) + check_pickle(wcs) near_ra_list = [] near_dec_list = [] @@ -1521,8 +1528,8 @@ def test_pixelscale(): # assert_raises(TypeError, galsim.PixelScale, scale=scale, origin=galsim.PositionD(0, 0)) # assert_raises(TypeError, galsim.PixelScale, scale=scale, world_origin=galsim.PositionD(0, 0)) - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "PixelScale") @@ -1593,8 +1600,8 @@ def test_pixelscale(): assert wcs != wcs3b, "OffsetWCS is not != a different one (origin)" assert wcs != wcs3c, "OffsetWCS is not != a different one (world_origin)" - ufunc = lambda x, y: scale * (x - x0) - vfunc = lambda x, y: scale * (y - y0) + ufunc = lambda x, y: scale * (x - x0) # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 1") # Add a world origin offset @@ -1602,8 +1609,8 @@ def test_pixelscale(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, world_origin=world_origin) - ufunc = lambda x, y: scale * x + u0 - vfunc = lambda x, y: scale * y + v0 + ufunc = lambda x, y: scale * x + u0 # noqa: E731 + vfunc = lambda x, y: scale * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 2") # Add both kinds of offsets @@ -1614,8 +1621,8 @@ def test_pixelscale(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetWCS(scale, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: scale * (x - x0) + u0 - vfunc = lambda x, y: scale * (y - y0) + v0 + ufunc = lambda x, y: scale * (x - x0) + u0 # noqa: E731 + vfunc = lambda x, y: scale * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1657,8 +1664,8 @@ def test_shearwcs(): assert wcs != wcs3b, "ShearWCS is not != a different one (shear)" factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 # Do generic tests that apply to all WCS types do_local_wcs(wcs, ufunc, vfunc, "ShearWCS") @@ -1743,8 +1750,12 @@ def test_shearwcs(): assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + ufunc = ( + lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + ) # noqa: E731 + vfunc = ( + lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + ) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") # Add a world origin offset @@ -1752,8 +1763,8 @@ def test_shearwcs(): v0 = -141.9 world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 + ufunc = lambda x, y: ((1 - g1) * x - g2 * y) * scale * factor + u0 # noqa: E731 + vfunc = lambda x, y: ((1 + g1) * y - g2 * x) * scale * factor + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 2") # Add both kinds of offsets @@ -1764,8 +1775,12 @@ def test_shearwcs(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - vfunc = lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 + ufunc = ( + lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 + ) # noqa: E731 + vfunc = ( + lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 + ) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") # Check that using a wcs in the context of an image works correctly @@ -1825,8 +1840,8 @@ def test_affinetransform(): assert wcs != wcs3c, "JacobianWCS is not != a different one (dvdx)" assert wcs != wcs3d, "JacobianWCS is not != a different one (dvdy)" - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 1") # Check the decomposition: @@ -1882,8 +1897,8 @@ def test_affinetransform(): assert wcs != wcs3e, "AffineTransform is not != a different one (origin)" assert wcs != wcs3f, "AffineTransform is not != a different one (world_origin)" - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 1") # Next one with a flip and significant rotation and a large (u,v) offset @@ -1893,8 +1908,8 @@ def test_affinetransform(): dvdy = 0.1409 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "JacobianWCS 2") # Check the decomposition: @@ -1906,8 +1921,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, world_origin=galsim.PositionD(u0, v0) ) - ufunc = lambda x, y: dudx * x + dudy * y + u0 - vfunc = lambda x, y: dvdx * x + dvdy * y + v0 + ufunc = lambda x, y: dudx * x + dudy * y + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 2") # Finally a really crazy one that isn't remotely regular @@ -1917,8 +1932,8 @@ def test_affinetransform(): dvdy = -0.3013 wcs = galsim.JacobianWCS(dudx, dudy, dvdx, dvdy) - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 do_local_wcs(wcs, ufunc, vfunc, "Jacobian 3") # Check the decomposition: @@ -1937,8 +1952,8 @@ def test_affinetransform(): wcs = galsim.AffineTransform( dudx, dudy, dvdx, dvdy, origin=origin, world_origin=world_origin ) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "AffineTransform 3") # Check that using a wcs in the context of an image works correctly @@ -2008,8 +2023,8 @@ def test_uvfunction(): # First make some that are identical to simpler WCS classes: # 1. Like PixelScale scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like PixelScale", test_pickle=False) assert wcs.ufunc(2.9, 3.7) == ufunc(2.9, 3.7) @@ -2022,8 +2037,8 @@ def test_uvfunction(): assert not wcs.isCelestial() # Also check with inverse functions. - xfunc = lambda u, v: u / scale - yfunc = lambda u, v: v / scale + xfunc = lambda u, v: u / scale # noqa: E731 + yfunc = lambda u, v: v / scale # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like PixelScale with inverse", test_pickle=False @@ -2057,14 +2072,14 @@ def test_uvfunction(): g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction like ShearWCS", test_pickle=False) # Also check with inverse functions. - xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor - yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor + xfunc = lambda u, v: (u + g1 * u + g2 * v) / scale * factor # noqa: E731 + yfunc = lambda u, v: (v - g1 * v + g2 * u) / scale * factor # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like ShearWCS with inverse", test_pickle=False @@ -2076,8 +2091,8 @@ def test_uvfunction(): dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs( wcs, ufunc, vfunc, "UVFunction like AffineTransform", test_pickle=False @@ -2113,7 +2128,7 @@ def test_uvfunction(): uses_color=True, ) do_nonlocal_wcs( - wcs, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True + wcsc, ufunc, vfunc, "UVFunction with unused color term", test_pickle=True ) # 4. Next some UVFunctions with non-trivial offsets @@ -2123,8 +2138,8 @@ def test_uvfunction(): v0 = -141.9 origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) - ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc2 = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc2 = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 wcs = galsim.UVFunction(ufunc2, vfunc2) do_nonlocal_wcs( wcs, ufunc2, vfunc2, "UVFunction with origins in funcs", test_pickle=False @@ -2197,8 +2212,8 @@ def test_uvfunction(): "UVFunction dvdy does not match expected value.", ) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic radial UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2209,8 +2224,8 @@ def test_uvfunction(): cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") wcs = galsim.UVFunction(cubic_u, cubic_v, origin=galsim.PositionD(x0, y0)) - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 do_nonlocal_wcs(wcs, ufunc, vfunc, "Cubic object UVFunction", test_pickle=False) # Check that using a wcs in the context of an image works correctly @@ -2218,8 +2233,8 @@ def test_uvfunction(): # 7. Test the UVFunction that is used in demo9 to confirm that I got the # inverse function correct! - ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) - vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) + ufunc = lambda x, y: 0.05 * x * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * (1.0 + 2.0e-6 * (x**2 + y**2)) # noqa: E731 # w = 0.05 (r + 2.e-6 r^3) # 0 = r^3 + 5e5 r - 1e7 w # @@ -2231,7 +2246,7 @@ def test_uvfunction(): # ( 5 sqrt( w^2 + 5.e3/27 ) - 5 w )^1/3 ) import math - xfunc = lambda u, v: ( + xfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2244,7 +2259,7 @@ def test_uvfunction(): ) ) )(math.sqrt(u**2 + v**2)) - yfunc = lambda u, v: ( + yfunc = lambda u, v: ( # noqa: E731 lambda w: ( 0.0 if w == 0.0 @@ -2281,19 +2296,23 @@ def test_uvfunction(): # This version doesn't work with numpy arrays because of the math functions. # This provides a test of that branch of the makeSkyImage function. - ufunc = lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - vfunc = lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ufunc = ( + lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ) # noqa: E731 + vfunc = ( + lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) + ) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) do_wcs_image(wcs, "UVFunction_math") # 8. A non-trivial color example - ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y - vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y - xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( + ufunc = lambda x, y, c: (dudx + 0.1 * c) * x + dudy * y # noqa: E731 + vfunc = lambda x, y, c: dvdx * x + (dvdy - 0.2 * c) * y # noqa: E731 + xfunc = lambda u, v, c: ((dvdy - 0.2 * c) * u - dudy * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) - yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( + yfunc = lambda u, v, c: (-dvdx * u + (dudx + 0.1 * c) * v) / ( # noqa: E731 (dudx + 0.1 * c) * (dvdy - 0.2 * c) - dudy * dvdx ) wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) @@ -2326,10 +2345,10 @@ def test_uvfunction(): ) # 9. A non-trivial color example that fails for arrays - ufunc = lambda x, y, c: math.exp(c * x) - vfunc = lambda x, y, c: math.exp(c * y / 2) - xfunc = lambda u, v, c: math.log(u) / c - yfunc = lambda u, v, c: math.log(v) * 2 / c + ufunc = lambda x, y, c: math.exp(c * x) # noqa: E731 + vfunc = lambda x, y, c: math.exp(c * y / 2) # noqa: E731 + xfunc = lambda u, v, c: math.log(u) / c # noqa: E731 + yfunc = lambda u, v, c: math.log(v) * 2 / c # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) do_nonlocal_wcs( wcs, @@ -2341,20 +2360,20 @@ def test_uvfunction(): ) # 10. One with invalid functions, which raise errors. (Just for coverage really.) - ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) - vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y: 0.05 * x * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y: 0.05 * y * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v: 0.05 * u * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v: 0.05 * v * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6)) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(3, 3)) assert_raises(ValueError, wcs.toImage, galsim.PositionD(6, 0)) # Repeat with color - ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) - vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) - xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) - yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) + ufunc = lambda x, y, c: 0.05 * c * math.sqrt(x**2 - y**2) # noqa: E731 + vfunc = lambda x, y, c: 0.05 * c * math.sqrt(3 * y**2 - x**2) # noqa: E731 + xfunc = lambda u, v, c: 0.05 * c * math.sqrt(2 * u**2 - 7 * v**2) # noqa: E731 + yfunc = lambda u, v, c: 0.05 * c * math.sqrt(8 * v**2 - u**2) # noqa: E731 wcs = galsim.UVFunction(ufunc, vfunc, xfunc, yfunc, uses_color=True) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(5, 6), color=0.2) assert_raises(ValueError, wcs.toWorld, galsim.PositionD(8, 2), color=0.2) @@ -2369,47 +2388,45 @@ def test_radecfunction(): funcs = [] scale = 0.17 - ufunc = lambda x, y: x * scale - vfunc = lambda x, y: y * scale + ufunc = lambda x, y: x * scale # noqa: E731 + vfunc = lambda x, y: y * scale # noqa: E731 funcs.append((ufunc, vfunc, "like PixelScale")) scale = 0.23 g1 = 0.14 g2 = -0.37 factor = 1.0 / np.sqrt(1.0 - g1 * g1 - g2 * g2) - ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor - vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor + ufunc = lambda x, y: (x - g1 * x - g2 * y) * scale * factor # noqa: E731 + vfunc = lambda x, y: (y + g1 * y - g2 * x) * scale * factor # noqa: E731 funcs.append((ufunc, vfunc, "like ShearWCS")) dudx = 0.2342 dudy = 0.1432 dvdx = 0.1409 dvdy = 0.2391 - ufunc = lambda x, y: dudx * x + dudy * y - vfunc = lambda x, y: dvdx * x + dvdy * y + ufunc = lambda x, y: dudx * x + dudy * y # noqa: E731 + vfunc = lambda x, y: dvdx * x + dvdy * y # noqa: E731 funcs.append((ufunc, vfunc, "like JacobianWCS")) x0 = 1.3 y0 = -0.9 u0 = 124.3 v0 = -141.9 - origin = galsim.PositionD(x0, y0) - world_origin = galsim.PositionD(u0, v0) - ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 - vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 + ufunc = lambda x, y: dudx * (x - x0) + dudy * (y - y0) + u0 # noqa: E731 + vfunc = lambda x, y: dvdx * (x - x0) + dvdy * (y - y0) + v0 # noqa: E731 funcs.append((ufunc, vfunc, "like AffineTransform")) funcs.append((radial_u, radial_v, "Cubic radial")) - ufunc = lambda x, y: radial_u(x - x0, y - y0) - vfunc = lambda x, y: radial_v(x - x0, y - y0) + ufunc = lambda x, y: radial_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: radial_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic radial")) cubic_u = Cubic(2.9e-5, 2000.0, "u") cubic_v = Cubic(-3.7e-5, 2000.0, "v") - ufunc = lambda x, y: cubic_u(x - x0, y - y0) - vfunc = lambda x, y: cubic_v(x - x0, y - y0) + ufunc = lambda x, y: cubic_u(x - x0, y - y0) # noqa: E731 + vfunc = lambda x, y: cubic_v(x - x0, y - y0) # noqa: E731 funcs.append((ufunc, vfunc, "offset Cubic object")) # The last one needs to not have a lambda, since we use it for the image test, which @@ -2434,7 +2451,7 @@ def test_radecfunction(): ) scale = galsim.arcsec / galsim.radians - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 ufunc(x, y) * scale, vfunc(x, y) * scale, projection="lambert" ) wcs2 = galsim.RaDecFunction(radec_func) @@ -2447,12 +2464,12 @@ def test_radecfunction(): # code does the right thing in that case too, since local and makeSkyImage # try the numpy option first and do something else if it fails. # This also tests the alternate initialization using separate ra_func, dec_fun. - ra_func = lambda x, y: center.deproject( + ra_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", ).ra.rad - dec_func = lambda x, y: center.deproject( + dec_func = lambda x, y: center.deproject( # noqa: E731 ufunc(x, y) * galsim.arcsec, vfunc(x, y) * galsim.arcsec, projection="lambert", @@ -2521,7 +2538,6 @@ def test_radecfunction(): image_pos = galsim.PositionD(x, y) world_pos1 = wcs1.toWorld(image_pos) world_pos2 = test_wcs.toWorld(image_pos) - origin = test_wcs.toWorld(galsim.PositionD(0.0, 0.0)) d3 = np.sqrt(world_pos1.x**2 + world_pos1.y**2) d4 = center.distanceTo(world_pos2) d4 = 2.0 * np.sin(d4 / 2) * galsim.radians / galsim.arcsec @@ -2712,7 +2728,7 @@ def test_radecfunction(): do_wcs_image(wcs3, "RaDecFunction") # One with invalid functions, which raise errors. (Just for coverage really.) - radec_func = lambda x, y: center.deproject_rad( + radec_func = lambda x, y: center.deproject_rad( # noqa: E731 math.sqrt(x), math.sqrt(y), projection="lambert" ) wcs = galsim.RaDecFunction(radec_func) @@ -2780,8 +2796,8 @@ def test_astropywcs(): """Test the AstropyWCS class""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. # These all work, but it is quite slow, so only test a few of them for the regular unit tests. # (1.8 seconds for 4 tags.) @@ -3130,9 +3146,9 @@ def test_inverseab_convergence(): [0.0003767412741890354, 0.00019733136932198898], ] ), - coord.CelestialCoord( - coord.Angle(2.171481673601117, coord.radians), - coord.Angle(-0.47508762601580773, coord.radians), + galsim.CelestialCoord( + galsim.Angle(2.171481673601117, galsim.radians), + galsim.Angle(-0.47508762601580773, galsim.radians), ), None, np.array( @@ -3320,13 +3336,13 @@ def test_fitswcs(): # mostly just tests the basic interface of the FitsWCS function. test_tags = ["TAN", "TPV"] try: - import starlink.Ast + import starlink.Ast # noqa: F401 # Useful also to test one that GSFitsWCS doesn't work on. This works on Travis at # least, and helps to cover some of the FitsWCS functionality where the first try # isn't successful. test_tags.append("HPX") - except: + except Exception: pass dir = "fits_files" @@ -3361,7 +3377,7 @@ def test_fitswcs(): # We don't really have any accuracy checks here. This really just checks that the # read function doesn't raise an exception. hdu, hdu_list, fin = galsim.fits.readFile(file_name, dir) - affine = galsim.AffineTransform._readHeader(hdu.header) + galsim.AffineTransform._readHeader(hdu.header) galsim.fits.closeHDUList(hdu_list, fin) # This does support LINEAR WCS types. @@ -3419,7 +3435,7 @@ def check_sphere(ra1, dec1, ra2, dec2, atol=1): w = dsq >= 3.99 if np.any(w): cross = np.cross(np.array([x1, y1, z1])[w], np.array([x2, y2, z2])[w]) - crosssq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 + crossq = cross[0] ** 2 + cross[1] ** 2 + cross[2] ** 2 dist[w] = np.pi - np.arcsin(np.sqrt(crossq)) dist = np.rad2deg(dist) * 3600 np.testing.assert_allclose(dist, 0.0, rtol=0.0, atol=atol) @@ -3448,7 +3464,7 @@ def test_fittedsipwcs(): "ZTF": (0.1, 0.1), } - dir = "fits_files" + dir = os.path.join(os.path.dirname(__file__), "..", "..", "GalSim/tests/fits_files") if __name__ == "__main__": test_tags = all_tags @@ -3964,7 +3980,7 @@ def test_int_args(): # is unnecessary. dir = "des_data" file_name = "DECam_00158414_01.fits.fz" - with profile(): + with Profile(): t0 = time.time() wcs = galsim.FitsWCS(file_name, dir=dir) t1 = time.time() @@ -4008,8 +4024,8 @@ def test_razero(): # do this. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - import astropy.wcs - import scipy # AstropyWCS constructor will do this, so check now. + import astropy.wcs # noqa: F401 + import scipy # noqa: F401 - AstropyWCS constructor will do this, so check now. dir = "fits_files" # This file is based in sipsample.fits, but with the CRVAL1 changed to 0.002322805429 diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py new file mode 100644 index 00000000..60f8c831 --- /dev/null +++ b/tests/jax/test_moffat_comp_galsim.py @@ -0,0 +1,40 @@ +import galsim as _galsim +import jax_galsim as galsim +import numpy as np + + +def test_moffat_comp_galsim_maxk(): + psfs = [ + # Make sure to include all the specialized betas we have in C++ layer. + # The scale_radius and flux don't matter, but vary themm too. + # Note: We also specialize beta=1, but that seems to be impossible to realize, + # even when it is trunctatd. + galsim.Moffat(beta=1.5, scale_radius=1, flux=1), + galsim.Moffat(beta=1.5001, scale_radius=1, flux=1), + galsim.Moffat(beta=2, scale_radius=0.8, flux=23), + galsim.Moffat(beta=2.5, scale_radius=1.8e-3, flux=2), + galsim.Moffat(beta=3, scale_radius=1.8e3, flux=35), + galsim.Moffat(beta=3.5, scale_radius=1.3, flux=123), + galsim.Moffat(beta=4, scale_radius=4.9, flux=23), + galsim.Moffat(beta=1.22, scale_radius=23, flux=23), + galsim.Moffat(beta=3.6, scale_radius=2, flux=23), + galsim.Moffat(beta=12.9, scale_radius=5, flux=23), + galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), + galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), + galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), + ] + threshs = [1.e-3, 1.e-4, 0.03] + print('\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk') + for psf in psfs: + for thresh in threshs: + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat(beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, trunc=psf.trunc) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + + print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}') + np.testing.assert_allclose(psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) From 425f3c3c33796b6e3747fb2394243a18efa31910 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 14:46:37 -0500 Subject: [PATCH 23/33] STY please the flake8 --- tests/jax/galsim/test_wcs_jax.py | 24 ++++++++++---------- tests/jax/test_moffat_comp_galsim.py | 34 ++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/tests/jax/galsim/test_wcs_jax.py b/tests/jax/galsim/test_wcs_jax.py index f2b7e791..d5ab3628 100644 --- a/tests/jax/galsim/test_wcs_jax.py +++ b/tests/jax/galsim/test_wcs_jax.py @@ -1750,12 +1750,12 @@ def test_shearwcs(): assert wcs != wcs3c, "OffsetShearWCS is not != a different one (origin)" assert wcs != wcs3d, "OffsetShearWCS is not != a different one (world_origin)" - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor - ) # noqa: E731 + ) do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 1") # Add a world origin offset @@ -1775,12 +1775,12 @@ def test_shearwcs(): origin = galsim.PositionD(x0, y0) world_origin = galsim.PositionD(u0, v0) wcs = galsim.OffsetShearWCS(scale, shear, origin=origin, world_origin=world_origin) - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: ((1 - g1) * (x - x0) - g2 * (y - y0)) * scale * factor + u0 - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: ((1 + g1) * (y - y0) - g2 * (x - x0)) * scale * factor + v0 - ) # noqa: E731 + ) do_nonlocal_wcs(wcs, ufunc, vfunc, "OffsetShearWCS 3") # Check that using a wcs in the context of an image works correctly @@ -2296,12 +2296,12 @@ def test_uvfunction(): # This version doesn't work with numpy arrays because of the math functions. # This provides a test of that branch of the makeSkyImage function. - ufunc = ( + ufunc = ( # noqa: E731 lambda x, y: 0.17 * x * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) # noqa: E731 - vfunc = ( + ) + vfunc = ( # noqa: E731 lambda x, y: 0.17 * y * (1.0 + 1.0e-5 * math.sqrt(x**2 + y**2)) - ) # noqa: E731 + ) wcs = galsim.UVFunction(ufunc, vfunc) do_nonlocal_wcs(wcs, ufunc, vfunc, "UVFunction with math funcs", test_pickle=False) do_wcs_image(wcs, "UVFunction_math") diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 60f8c831..d4549420 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -1,7 +1,8 @@ import galsim as _galsim -import jax_galsim as galsim import numpy as np +import jax_galsim as galsim + def test_moffat_comp_galsim_maxk(): psfs = [ @@ -23,18 +24,33 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), ] - threshs = [1.e-3, 1.e-4, 0.03] - print('\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk') + threshs = [1.0e-3, 1.0e-4, 0.03] + print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") for psf in psfs: for thresh in threshs: psf = psf.withGSParams(maxk_threshold=thresh) - gpsf = _galsim.Moffat(beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, trunc=psf.trunc) + gpsf = _galsim.Moffat( + beta=psf.beta, + scale_radius=psf.scale_radius, + flux=psf.flux, + trunc=psf.trunc, + ) gpsf = gpsf.withGSParams(maxk_threshold=thresh) fk = psf.kValue(psf.maxk, 0).real / psf.flux - print(f'{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}') - np.testing.assert_allclose(psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5) - np.testing.assert_allclose(psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5) + print( + f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" + ) + np.testing.assert_allclose( + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5 + ) np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) From bc7f3550d8ae942fcad0649d14de358261505e7e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 17 Oct 2023 14:47:18 -0500 Subject: [PATCH 24/33] Update jax_galsim/core/utils.py --- jax_galsim/core/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 872912ec..c41a8a40 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -96,8 +96,6 @@ def _func(i, args): fhigh, ) - low = 0.0 - high = 1e5 flow = func(low) fhigh = func(high) args = (func, low, flow, high, fhigh) From 851b4f74022e72586974968bf5cdfa408302c619 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Oct 2023 23:00:15 -0500 Subject: [PATCH 25/33] TST patch galsim in check_pickle --- tests/GalSim | 2 +- tests/conftest.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/GalSim b/tests/GalSim index 81509041..b018d57f 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 815090419d343d0e840bbc53e79c7bc4469ec79d +Subproject commit b018d57fba88eabbaacf40d34d3029a77e7071f2 diff --git a/tests/conftest.py b/tests/conftest.py index dfdc96cc..4d8bab9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,41 @@ import inspect import os +import sys from functools import lru_cache +from unittest.mock import patch +import galsim import pytest import yaml # Define the accuracy for running the tests from jax.config import config +import jax_galsim + config.update("jax_enable_x64", True) # Identify the path to this current file test_directory = os.path.dirname(os.path.abspath(__file__)) - # Loading which tests to run with open(os.path.join(test_directory, "galsim_tests_config.yaml"), "r") as f: test_config = yaml.safe_load(f) +# we need to patch the galsim utilities check_pickle function +# to use jax_galsim. it has an import inside a function so +# we patch sys.modules. +# see https://stackoverflow.com/questions/34213088/mocking-a-module-imported-inside-of-a-function +orig_check_pickle = galsim.utilities.check_pickle + + +def _check_pickle(*args, **kwargs): + with patch.dict(sys.modules, {"galsim": jax_galsim}): + return orig_check_pickle(*args, **kwargs) + + +galsim.utilities.check_pickle = _check_pickle + def pytest_ignore_collect(collection_path, path, config): """This hook will skip collecting tests that are not From cd61feac35b52d15b68fc5fd5eef9bee0c1961a0 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 17 Oct 2023 23:28:07 -0500 Subject: [PATCH 26/33] Update tests/jax/galsim/test_random_jax.py Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- tests/jax/galsim/test_random_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 68e0b409..9b70b945 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -1251,7 +1251,7 @@ def test_gamma(): assert g.has_reliable_discard assert not g.generates_in_pairs - # NOTE jax has a reliabble discard + # NOTE jax has a reliable discard # Discard normally emits a warning for Gamma # g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) # with assert_warns(galsim.GalSimWarning): From d9d57dc241063bc101fc1fe644f8f7f13acca0fd Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 17 Oct 2023 23:28:14 -0500 Subject: [PATCH 27/33] Update tests/jax/galsim/test_random_jax.py Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- tests/jax/galsim/test_random_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 9b70b945..151c11d4 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -1240,7 +1240,7 @@ def test_gamma(): var, v, 0, err_msg='Wrong variance from GammaDeviate') - # NOTE jax has a reliabble discard + # NOTE jax has a reliable discard # Check discard g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) g2.discard(nvals, suppress_warnings=True) From 93088dc0f53aa2496eb4b69eb2d597e178ffb34c Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 06:45:50 -0500 Subject: [PATCH 28/33] REF remove unused attribute --- jax_galsim/transform.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index a0144018..f86e7b3e 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -105,10 +105,6 @@ def flux_ratio(self): def _flux(self): return self._flux_scaling * self._original.flux - @property - def _offset(self): - return self._params["offset"] - def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams @@ -308,7 +304,7 @@ def _is_analytic_k(self): def _centroid(self): cen = self._original.centroid cen = PositionD(self._fwd(cen.x, cen.y)) - cen += self._offset + cen += self.offset return cen @property From a1d5d9baf70627271b2485e5c437987020b1a389 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 18 Oct 2023 11:02:17 -0500 Subject: [PATCH 29/33] Update jax_galsim/transform.py --- jax_galsim/transform.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 56f3e1ee..f86e7b3e 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -105,10 +105,6 @@ def flux_ratio(self): def _flux(self): return self._flux_scaling * self._original.flux - @property - def _offset(self): - return self._params["offset"] - def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams From 0e9a4b84892eeed982da448d88687ab859bfd984 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 11:04:52 -0500 Subject: [PATCH 30/33] TST fix tests --- tests/jax/galsim/test_random_jax.py | 56 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/jax/galsim/test_random_jax.py b/tests/jax/galsim/test_random_jax.py index 151c11d4..879cdad2 100644 --- a/tests/jax/galsim/test_random_jax.py +++ b/tests/jax/galsim/test_random_jax.py @@ -3,7 +3,7 @@ import os import galsim from galsim.utilities import single_threaded -from galsim_test_helpers import timer, do_pickle # noqa: E402 +from galsim_test_helpers import timer, check_pickle # noqa: E402 precision = 10 # decimal point at which agreement is required for all double precision tests @@ -274,10 +274,10 @@ def test_uniform(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(u, lambda x: x.serialize()) - do_pickle(u, lambda x: (x(), x(), x(), x())) - do_pickle(u) - do_pickle(rng) + check_pickle(u, lambda x: x.serialize()) + check_pickle(u, lambda x: (x(), x(), x(), x())) + check_pickle(u) + check_pickle(rng) assert "UniformDeviate" in repr(u) assert "UniformDeviate" in str(u) assert isinstance(eval(repr(u)), galsim.UniformDeviate) @@ -495,9 +495,9 @@ def test_gaussian(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) - do_pickle(g, lambda x: (x(), x(), x(), x())) - do_pickle(g) + check_pickle(g, lambda x: (x.serialize(), x.mean, x.sigma)) + check_pickle(g, lambda x: (x(), x(), x(), x())) + check_pickle(g) assert 'GaussianDeviate' in repr(g) assert 'GaussianDeviate' in str(g) assert isinstance(eval(repr(g)), galsim.GaussianDeviate) @@ -666,9 +666,9 @@ def test_binomial(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(b, lambda x: (x.serialize(), x.n, x.p)) - do_pickle(b, lambda x: (x(), x(), x(), x())) - do_pickle(b) + check_pickle(b, lambda x: (x.serialize(), x.n, x.p)) + check_pickle(b, lambda x: (x(), x(), x(), x())) + check_pickle(b) assert 'BinomialDeviate' in repr(b) assert 'BinomialDeviate' in str(b) assert isinstance(eval(repr(b)), galsim.BinomialDeviate) @@ -869,9 +869,9 @@ def test_poisson(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(p, lambda x: (x.serialize(), x.mean)) - do_pickle(p, lambda x: (x(), x(), x(), x())) - do_pickle(p) + check_pickle(p, lambda x: (x.serialize(), x.mean)) + check_pickle(p, lambda x: (x(), x(), x(), x())) + check_pickle(p) assert 'PoissonDeviate' in repr(p) assert 'PoissonDeviate' in str(p) assert isinstance(eval(repr(p)), galsim.PoissonDeviate) @@ -1000,7 +1000,7 @@ def test_poisson_zeromean(): p = galsim.PoissonDeviate(testseed, mean=0) p2 = p.duplicate() p3 = galsim.PoissonDeviate(p.serialize(), mean=0) - do_pickle(p) + check_pickle(p) # Test direct draws testResult = (p(), p(), p()) @@ -1184,9 +1184,9 @@ def test_weibull(): np.testing.assert_array_equal(v1, v2) # Check picklability - do_pickle(w, lambda x: (x.serialize(), x.a, x.b)) - do_pickle(w, lambda x: (x(), x(), x(), x())) - do_pickle(w) + check_pickle(w, lambda x: (x.serialize(), x.a, x.b)) + check_pickle(w, lambda x: (x(), x(), x(), x())) + check_pickle(w) assert 'WeibullDeviate' in repr(w) assert 'WeibullDeviate' in str(w) assert isinstance(eval(repr(w)), galsim.WeibullDeviate) @@ -1337,9 +1337,9 @@ def test_gamma(): err_msg='Wrong gamma random number sequence from generate.') # Check picklability - do_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) - do_pickle(g, lambda x: (x(), x(), x(), x())) - do_pickle(g) + check_pickle(g, lambda x: (x.serialize(), x.k, x.theta)) + check_pickle(g, lambda x: (x(), x(), x(), x())) + check_pickle(g) assert 'GammaDeviate' in repr(g) assert 'GammaDeviate' in str(g) assert isinstance(eval(repr(g)), galsim.GammaDeviate) @@ -1489,9 +1489,9 @@ def test_chi2(): err_msg='Wrong Chi^2 random number sequence from generate.') # Check picklability - do_pickle(c, lambda x: (x.serialize(), x.n)) - do_pickle(c, lambda x: (x(), x(), x(), x())) - do_pickle(c) + check_pickle(c, lambda x: (x.serialize(), x.n)) + check_pickle(c, lambda x: (x(), x(), x(), x())) + check_pickle(c) assert 'Chi2Deviate' in repr(c) assert 'Chi2Deviate' in str(c) assert isinstance(eval(repr(c)), galsim.Chi2Deviate) @@ -1710,8 +1710,8 @@ def test_chi2(): # np.testing.assert_array_equal(v1, v2) # # Check picklability -# do_pickle(d, lambda x: (x(), x(), x(), x())) -# do_pickle(d) +# check_pickle(d, lambda x: (x(), x(), x(), x())) +# check_pickle(d) # assert 'DistDeviate' in repr(d) # assert 'DistDeviate' in str(d) # assert isinstance(eval(repr(d)), galsim.DistDeviate) @@ -1877,8 +1877,8 @@ def test_chi2(): # err_msg='Two DistDeviates with near-flat probabilities generated different values.') # # Check picklability -# do_pickle(d, lambda x: (x(), x(), x(), x())) -# do_pickle(d) +# check_pickle(d, lambda x: (x(), x(), x(), x())) +# check_pickle(d) # assert 'DistDeviate' in repr(d) # assert 'DistDeviate' in str(d) # assert isinstance(eval(repr(d)), galsim.DistDeviate) From e18bb4734989b5cf2a937435ac674942d3a084bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 11:26:05 -0500 Subject: [PATCH 31/33] ENH change api to get tests to pass without using internal attributes --- jax_galsim/core/draw.py | 4 ++-- jax_galsim/transform.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 4ccf2789..a8edfe51 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -51,12 +51,12 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): return Image(array=im, bounds=image.bounds, wcs=image.wcs, check_bounds=False) -def apply_kImage_phases(gsobject, image, jacobian=jnp.eye(2)): +def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): # Create an array of coordinates kcoords = jnp.stack(image.get_pixel_centers(), axis=-1) kcoords = kcoords * image.scale # Scale by the image pixel scale kcoords = jnp.dot(kcoords, jacobian) - cenx, ceny = gsobject._offset.x, gsobject._offset.y + cenx, ceny = offset.x, offset.y # # flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) ) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index f86e7b3e..c64863f8 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -360,7 +360,7 @@ def _drawKImage(self, image, jac=None): image = self._original._drawKImage(image, jac1) _jac = jnp.eye(2) if jac is None else jac - image = apply_kImage_phases(self, image, _jac) + image = apply_kImage_phases(self.offset, image, _jac) image = image * self._flux_scaling return image From 20cd1439983fbce08c8713a3d5fd104657eb21c2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Oct 2023 12:34:26 -0500 Subject: [PATCH 32/33] TST make sure to patch BaseDeviate --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4d8bab9a..5db13ef4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ # we patch sys.modules. # see https://stackoverflow.com/questions/34213088/mocking-a-module-imported-inside-of-a-function orig_check_pickle = galsim.utilities.check_pickle +orig_check_pickle.__globals__["BaseDeviate"] = jax_galsim.BaseDeviate def _check_pickle(*args, **kwargs): From ab8589e440112c3b3fbbbb929cea7019950321e9 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 19 Oct 2023 06:43:45 -0500 Subject: [PATCH 33/33] Update python_package.yaml --- .github/workflows/python_package.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index b4b6ca67..dac1d792 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -2,9 +2,9 @@ name: Python package on: push: - branches: ["main"] + branches: + - main pull_request: - branches: ["main"] jobs: build: @@ -42,4 +42,4 @@ jobs: - name: Test with pytest run: | git submodule update --init --recursive - pytest + pytest --durations=0