Skip to content

Commit

Permalink
Add TableRow
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeyers314 committed May 17, 2024
1 parent b6a633f commit 948fec3
Show file tree
Hide file tree
Showing 3 changed files with 382 additions and 0 deletions.
1 change: 1 addition & 0 deletions imsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@
from .vignetting import *
from .sag import *
from .process_info import *
from .table_row import *
147 changes: 147 additions & 0 deletions imsim/table_row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import astropy.units as u
import galsim
from astropy.table import QTable
from galsim.config import (
GetAllParams,
GetInputObj,
InputLoader,
RegisterInputType,
RegisterValueType,
)


class TableRow:
    """Class to extract one row from an astropy QTable and make it available to the
    galsim config layer.

    The table is read with ``QTable.read`` and filtered down by requiring each
    key column to equal the corresponding value.  Exactly one row must remain.

    Parameters
    ----------
    file_name: str
        Name of the table file; passed directly to ``QTable.read``.
    keys: list
        Column names to use as keys.
    values: list
        Values to match in the key columns.  Must have the same length as keys.

    Raises
    ------
    ValueError
        If keys and values do not have the same length.
    KeyError
        If no row, or more than one row, matches the key/value pairs.
    """

    _req_params = {
        "file_name": str,
        "keys": list,
        "values": list,
    }

    def __init__(self, file_name, keys, values):
        # zip() below stops at the shorter of the two lists, so a length mismatch
        # would silently drop selection criteria.  Reject it up front, before
        # touching the file.
        if len(keys) != len(values):
            raise ValueError(
                "keys and values must have the same length; "
                f"got keys = {keys}, values = {values}"
            )

        self.file_name = file_name
        self.keys = keys
        self.values = values
        self.data = QTable.read(file_name)

        # Successively narrow the table to the rows matching every key/value pair.
        for key, value in zip(keys, values):
            self.data = self.data[self.data[key] == value]

        if len(self.data) == 0:
            raise KeyError(f"No rows found with keys = {keys}, values = {values}")
        if len(self.data) > 1:
            raise KeyError(
                f"Multiple rows found with keys = {keys}, values = {values}"
            )

    @staticmethod
    def _require_unit(data, field, purpose):
        """Raise ValueError if ``data`` has no unit but ``purpose`` needs one."""
        if data.unit is None:
            raise ValueError(
                f"field {field} has no units; from_unit must be specified {purpose}."
            )

    def get(self, field, value_type, from_unit=None, to_unit=None, subfield=None):
        """Get a value from the table row.

        Parameters
        ----------
        field: str
            The name of the column to extract.
        value_type: [float, int, bool, str, galsim.Angle, list]
            The type to convert the value to.
        from_unit: str, optional
            The units of the value in the table.  If the table column already has a
            unit, then this must match that unit or be omitted.
        to_unit: str, optional
            The units to convert the value to.  Only allowed if value_type is one of
            float, int, or list.  Use of this parameter requires that the original
            units of the column are inferrable from the table itself or from
            from_unit.
        subfield: str, optional
            The name of a subfield to extract from a structured array column.  If
            omitted, and field refers to a structured array, the entire array is
            still readable as a list value type.

        Returns
        -------
        value: value_type
            The value from the table.

        Raises
        ------
        ValueError
            If from_unit conflicts with the unit already on the column, if to_unit
            is given for an Angle, str, or bool value_type, or if a unit is needed
            (Angle value_type, or to_unit given) but neither the column nor
            from_unit provides one.
        """
        data = self.data[field]
        if subfield is not None:
            data = data[subfield]

        # See if data already has a unit; if not, attach from_unit.  If it does,
        # from_unit may only confirm the existing unit.
        if data.unit is None:
            if from_unit is not None:
                data = data * getattr(u, from_unit)
        elif from_unit is not None and data.unit != getattr(u, from_unit):
            raise ValueError(
                f"from_unit = {from_unit} specified, but field {field} already "
                f"has units of {data.unit}."
            )

        # Angles are special
        if value_type == galsim.Angle:
            if to_unit is not None:
                raise ValueError("to_unit is not allowed for Angle types.")
            # Without a unit there is no way to interpret the number as an angle.
            self._require_unit(data, field, "to read it as an Angle")
            return float(data.to_value(u.rad)[0]) * galsim.radians

        # For non-angles, we leverage astropy units.
        if to_unit is not None:
            to_unit = getattr(u, to_unit)

        if value_type == list:
            if to_unit is None:
                out = data.value[0].tolist()
            else:
                self._require_unit(data, field, "to convert it with to_unit")
                out = data.to_value(to_unit)[0].tolist()
            # If we had a structured array, `out` is still a tuple here, so
            # use an extra list() here to finish the cast.
            return list(out)

        # We have to be careful with strings, as using .value on the table datum will
        # convert to bytes, which is not what we want.
        if value_type == str:
            if to_unit is not None:
                raise ValueError("to_unit is not allowed for str types.")
            return str(data[0])

        # No units allowed for bool
        if value_type == bool and to_unit is not None:
            raise ValueError("to_unit is not allowed for bool types.")

        if to_unit is None:
            return value_type(data.value[0])

        self._require_unit(data, field, "to convert it with to_unit")
        return value_type(data.to_value(to_unit)[0])


def RowData(config, base, value_type):
    """Config-layer value generator that reads one field from the table row.

    Looks up the "table_row" input object, parses ``config`` for the required
    "field" entry and the optional "from_unit", "to_unit" and "subfield"
    entries, and forwards them to the input object's ``get`` method.

    Returns
    -------
    (value, safe): tuple
        The extracted value converted to ``value_type``, together with the
        ``safe`` flag reported by ``GetAllParams``.
    """
    required = {"field": str}
    optional = {"from_unit": str, "to_unit": str, "subfield": str}

    # Fetch the input object first, then parse this item's parameters.
    table_row = GetInputObj("table_row", config, base, "table_row")
    params, safe = GetAllParams(config, base, req=required, opt=optional)

    # Optional entries that were not supplied are passed along as None.
    value = table_row.get(
        params["field"],
        value_type,
        params.get("from_unit"),
        params.get("to_unit"),
        params.get("subfield"),
    )
    return value, safe


# Register TableRow as the "table_row" input type.
RegisterInputType(
    loader=InputLoader(TableRow, file_scope=True),
    input_type="table_row",
)

# Register RowData as a value type backed by the "table_row" input.
RegisterValueType(
    input_type="table_row",
    type_name="RowData",
    valid_types=[float, int, bool, str, galsim.Angle, list],
    gen_func=RowData,
)
234 changes: 234 additions & 0 deletions tests/test_table_row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from tempfile import TemporaryDirectory

import astropy.units as u
import galsim
import imsim
import numpy as np
from astropy.table import QTable, Table
from imsim.table_row import RowData


def create_table():
    """Build the four-row QTable used as the fixture for the TableRow tests.

    The table holds index columns, a unitless float column, two angle columns
    (arcsec and deg), int/bool/str columns, a length column, and two structured
    columns ("tilt" in arcsec, "shift" in m).  The table is pretty-printed
    before being returned.
    """
    tbl = QTable()

    # For testing single and multi indexing
    tbl["idx"] = [0, 1, 2, 3]
    tbl["idx0"] = [0, 0, 1, 1]
    tbl["idx1"] = [0, 1, 0, 1]
    idx_col = tbl["idx"]

    tbl["unitless"] = np.array(idx_col, dtype=float)
    tbl["angle1"] = idx_col * u.arcsec
    tbl["angle2"] = tbl["angle1"].to(u.deg)

    # structured array column
    tilt_rows = [(i, i + 1) for i in range(4)]
    tilt_fields = [("rx", "<f8"), ("ry", "<f8")]
    tbl["tilt"] = np.array(tilt_rows, dtype=np.dtype(tilt_fields))
    tbl["tilt"].unit = u.arcsec

    # Add columns with other types
    tbl["int"] = np.array(idx_col, dtype=int)
    tbl["bool"] = np.array([True, False, True, False], dtype=bool)
    tbl["str"] = ["a", "b", "c", "d"]

    # And some other units
    tbl["length"] = idx_col * u.m

    # Structured array with length units
    shift_rows = [(i, i + 1, i + 2) for i in range(4)]
    shift_fields = [("dx", "<f8"), ("dy", "<f8"), ("dz", "<f8")]
    tbl["shift"] = np.array(shift_rows, dtype=np.dtype(shift_fields))
    tbl["shift"].unit = u.m

    tbl.pprint_all()
    return tbl


def check_row_data(config, idx):
    """Check every RowData access pattern against the table row numbered ``idx``.

    ``config`` must already have its "table_row" input processed so that the
    selected row is the one whose "idx" column equals ``idx``.
    """
    rad_per_arcsec = u.arcsec.to(u.rad)

    def read(value_type, **spec):
        # Each call builds a fresh spec dict for RowData.
        return RowData(spec, config, value_type)

    def close(actual, desired):
        np.testing.assert_allclose(actual, desired, rtol=0, atol=1e-15)

    # Read unitful column as Angle
    ang1, ok1 = read(galsim.Angle, field="angle1")
    ang2, ok2 = read(galsim.Angle, field="angle2")
    assert ok1 == ok2 == True
    close(ang1.rad, ang2.rad)
    close(ang1.rad, idx * rad_per_arcsec)

    # Reading the unitful columns directly as floats is permitted.  They'll
    # come back unequal this time due to unit differences.
    raw1, ok1 = read(float, field="angle1")
    raw2, ok2 = read(float, field="angle2")
    assert ok1 == ok2 == True
    assert isinstance(raw1, float)
    assert isinstance(raw2, float)
    close((raw1 * galsim.arcsec).rad, (raw2 * galsim.degrees).rad)

    # We can specify a to_unit though to get them to both come back in,
    # e.g., radians.
    rad1, ok1 = read(float, field="angle1", to_unit="rad")
    rad2, ok2 = read(float, field="angle2", to_unit="rad")
    assert ok1 == ok2 == True
    assert isinstance(rad1, float)
    assert isinstance(rad2, float)
    close(rad1, rad2)
    close(rad1, idx * rad_per_arcsec)

    # Reading the unitless column as an Angle is permitted with from_unit.
    ang, ok = read(galsim.Angle, field="unitless", from_unit="arcsec")
    assert ok == True
    close(ang.rad, idx * rad_per_arcsec)

    # Read the unitless column as initially arcsec then convert to rad
    rad, ok = read(float, field="unitless", from_unit="arcsec", to_unit="rad")
    assert ok == True
    close(rad, idx * rad_per_arcsec)

    # Using from_unit with a unitful column will raise if the units don't
    # match...
    with np.testing.assert_raises(ValueError):
        read(galsim.Angle, field="angle1", from_unit="deg")
    # ...but is accepted when they do match.
    read(galsim.Angle, field="angle1", from_unit="arcsec")

    # It's an error to try to convert an Angle to a different unit.
    with np.testing.assert_raises(ValueError):
        read(galsim.Angle, field="angle1", to_unit="deg")

    # Read a structured array subfield
    rx, ok1 = read(galsim.Angle, field="tilt", subfield="rx")
    ry, ok2 = read(galsim.Angle, field="tilt", subfield="ry")
    assert ok1 == ok2 == True
    close(rx.rad, idx * rad_per_arcsec)
    close(ry.rad, (idx + 1) * rad_per_arcsec)

    # Read the full structured array as a list
    # The config layer is not set up to handle lists of Angles, though, so
    # we have to interpret as floats directly.
    tilt, ok = read(list, field="tilt")
    assert ok == True
    assert len(tilt) == 2
    for offset, component in enumerate(tilt):
        close(component, idx + offset)

    # Read the full structured array and convert to rad
    tilt, ok = read(list, field="tilt", to_unit="rad")
    assert ok == True
    assert len(tilt) == 2
    for offset, component in enumerate(tilt):
        close(component, (idx + offset) * rad_per_arcsec)

    # Read the int column
    as_int, ok = read(int, field="int")
    assert ok == True
    assert as_int == idx

    # Read the bool column
    as_bool, ok = read(bool, field="bool")
    assert ok == True
    assert as_bool == (idx % 2 == 0)

    # Can't specify to_unit for bool
    with np.testing.assert_raises(ValueError):
        read(bool, field="bool", to_unit="rad")

    # Read the str column
    as_str, ok = read(str, field="str")
    assert ok == True
    assert as_str == chr(ord("a") + idx)

    # Can't specify to_unit for str
    with np.testing.assert_raises(ValueError):
        read(str, field="str", to_unit="rad")

    # Read the length column
    length, ok = read(float, field="length")
    assert ok == True
    assert length == idx

    # Read the length column and convert its units
    length, ok = read(float, field="length", to_unit="cm")
    assert ok == True
    assert length == idx * 100

    # Read the structured array subfield with units
    dx, ok1 = read(float, field="shift", subfield="dx")
    dy, ok2 = read(float, field="shift", subfield="dy")
    dz, ok3 = read(float, field="shift", subfield="dz")
    assert ok1 == ok2 == ok3 == True
    for offset, component in enumerate((dx, dy, dz)):
        close(component, idx + offset)

    # Read the full structured array with units
    shift, ok = read(list, field="shift")
    assert ok == True
    assert len(shift) == 3
    for offset, component in enumerate(shift):
        close(component, idx + offset)

    # Read the full structured array with units and convert to cm
    shift, ok = read(list, field="shift", to_unit="cm")
    assert ok == True
    assert len(shift) == 3
    for offset, component in enumerate(shift):
        close(component, (idx + offset) * 100)


def test_table_row():
    """Write the fixture table in several formats and check TableRow on each.

    Both a QTable and a plain Table are written as FITS, ECSV and Parquet
    files; every row is then selected once through the single "idx" key and
    once through the ("idx0", "idx1") key pair.
    """
    qtable = create_table()
    plain_table = Table(qtable)  # Regular Table (not QTable)
    assert not isinstance(plain_table, QTable)

    def reload_and_check(config, idx, **selection):
        # Update the row selection, drop cached values, and reprocess the
        # input before checking the row contents.
        config["input"]["table_row"].update(selection)
        galsim.config.RemoveCurrent(config["input"]["table_row"])
        galsim.config.ProcessInput(config)
        check_row_data(config, idx)

    # Works for both QTable and Table
    for table in (plain_table, qtable):
        with TemporaryDirectory() as tmpdir:
            for ext in (".fits", ".ecsv", ".parq"):
                file_name = f"{tmpdir}/table_row{ext}"
                table.write(file_name, overwrite=True)

                config = {
                    "input": {
                        "table_row": {
                            "file_name": file_name,
                            "keys": ["idx"],
                            "values": [0],
                        },
                    },
                }

                # Check single indexing during table load
                for idx in range(4):
                    reload_and_check(config, idx, values=[idx])

                # Check multi indexing during table load; idx = idx0 * 2 + idx1
                for idx in range(4):
                    idx0, idx1 = divmod(idx, 2)
                    reload_and_check(
                        config, idx, keys=["idx0", "idx1"], values=[idx0, idx1]
                    )


if __name__ == "__main__":
    # Collect every module-level callable named test_* first, then run them.
    runnable = [
        obj
        for name, obj in vars().items()
        if name.startswith("test_") and callable(obj)
    ]
    for run in runnable:
        run()

0 comments on commit 948fec3

Please sign in to comment.