Skip to content

Commit

Permalink
Merge pull request #72 from beckermr/base-deviate
Browse files Browse the repository at this point in the history
ENH add random deviates
  • Loading branch information
beckermr authored Oct 18, 2023
2 parents 3b830bb + a2bebbe commit e1c6c73
Show file tree
Hide file tree
Showing 16 changed files with 3,339 additions and 201 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -27,7 +27,7 @@ jobs:
python -m pip install .
- name: Ensure black formatting
run: |
black --check jax_galsim/ tests/ --exclude tests/GalSim/
black --check jax_galsim/ tests/ --exclude "tests/GalSim/|tests/Coord/|tests/jax/galsim/"
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ repos:
hooks:
- id: black
language: python
exclude: tests/GalSim/|tests/Coord/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
entry: pflake8
additional_dependencies: [pyproject-flake8]
exclude: tests/
exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: tests/Galsim|tests/Coord/
exclude: tests/Galsim/|tests/Coord/|tests/jax/galsim/
38 changes: 19 additions & 19 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# Inherit all Exception and Warning classes from galsim
from galsim.errors import (
GalSimBoundsError,
GalSimConfigError,
GalSimConfigValueError,
GalSimDeprecationWarning,
GalSimError,
GalSimFFTSizeError,
GalSimHSMError,
GalSimImmutableError,
GalSimIncompatibleValuesError,
GalSimIndexError,
GalSimKeyError,
GalSimNotImplementedError,
GalSimRangeError,
GalSimSEDError,
GalSimUndefinedBoundsError,
GalSimValueError,
GalSimWarning,
# Exception and Warning classes
from .errors import GalSimError, GalSimRangeError, GalSimValueError
from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError
from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError
from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError
from .errors import GalSimFFTSizeError
from .errors import GalSimConfigError, GalSimConfigValueError
from .errors import GalSimWarning, GalSimDeprecationWarning

# noise
from .random import (
BaseDeviate,
UniformDeviate,
GaussianDeviate,
PoissonDeviate,
Chi2Deviate,
GammaDeviate,
WeibullDeviate,
BinomialDeviate,
)

# Basic building blocks
Expand Down
10 changes: 8 additions & 2 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,20 @@ def __init__(self, obj, gsparams=None, propagate_gsparams=True):

# Save the original object as an attribute, so it can be inspected later if necessary.
self._gsparams = GSParams.check(gsparams, obj.gsparams)
self._min_acc_kvalue = obj.flux * self.gsparams.kvalue_accuracy
self._inv_min_acc_kvalue = 1.0 / self._min_acc_kvalue
self._propagate_gsparams = propagate_gsparams
if self._propagate_gsparams:
self._orig_obj = obj.withGSParams(self._gsparams)
else:
self._orig_obj = obj

@property
def _min_acc_kvalue(self):
return self._orig_obj.flux * self.gsparams.kvalue_accuracy

@property
def _inv_min_acc_kvalue(self):
return 1.0 / self._min_acc_kvalue

@property
def orig_obj(self):
"""The original object that is being deconvolved."""
Expand Down
14 changes: 12 additions & 2 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import jax


def convert_to_float(x):
if isinstance(x, jax.Array):
if x.shape == ():
return x.item()
else:
return x[0].astype(float).item()
else:
return float(x)


def cast_scalar_to_float(x):
"""Cast the input to a float. Works on python floats and jax arrays."""
if isinstance(x, float):
return float(x)
if isinstance(x, jax.Array):
return x.astype(float)
elif hasattr(x, "astype"):
return x.astype(float)
else:
Expand Down
19 changes: 19 additions & 0 deletions jax_galsim/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from galsim.errors import ( # noqa: F401
GalSimBoundsError,
GalSimConfigError,
GalSimConfigValueError,
GalSimDeprecationWarning,
GalSimError,
GalSimFFTSizeError,
GalSimHSMError,
GalSimImmutableError,
GalSimIncompatibleValuesError,
GalSimIndexError,
GalSimKeyError,
GalSimNotImplementedError,
GalSimRangeError,
GalSimSEDError,
GalSimUndefinedBoundsError,
GalSimValueError,
GalSimWarning,
)
Loading

0 comments on commit e1c6c73

Please sign in to comment.