-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b6a633f
commit 948fec3
Showing
3 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,3 +36,4 @@ | |
from .vignetting import * | ||
from .sag import * | ||
from .process_info import * | ||
from .table_row import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import astropy.units as u | ||
import galsim | ||
from astropy.table import QTable | ||
from galsim.config import ( | ||
GetAllParams, | ||
GetInputObj, | ||
InputLoader, | ||
RegisterInputType, | ||
RegisterValueType, | ||
) | ||
|
||
|
||
class TableRow: | ||
"""Class to extract one row from an astropy QTable and make it available to the | ||
galsim config layer. | ||
Parameters | ||
---------- | ||
file_name: str | ||
keys: list | ||
Column names to use as keys. | ||
values: list | ||
Values to match in the key columns. | ||
""" | ||
|
||
_req_params = { | ||
"file_name": str, | ||
"keys": list, | ||
"values": list, | ||
} | ||
|
||
def __init__(self, file_name, keys, values): | ||
self.file_name = file_name | ||
self.keys = keys | ||
self.values = values | ||
self.data = QTable.read(file_name) | ||
|
||
for key, value in zip(keys, values): | ||
self.data = self.data[self.data[key] == value] | ||
|
||
if len(self.data) == 0: | ||
raise KeyError("No rows found with keys = %s, values = %s" % (keys, values)) | ||
if len(self.data) > 1: | ||
raise KeyError( | ||
"Multiple rows found with keys = %s, values = %s" % (keys, values) | ||
) | ||
|
||
def get(self, field, value_type, from_unit=None, to_unit=None, subfield=None): | ||
"""Get a value from the table row. | ||
Parameters | ||
---------- | ||
field: str | ||
The name of the column to extract. | ||
value_type: [float, int, bool, str, galsim.Angle, list] | ||
The type to convert the value to. | ||
from_unit: str, optional | ||
The units of the value in the table. If the table column already has a | ||
unit, then this must match that unit or be omitted. | ||
to_unit: str, optional | ||
The units to convert the value to. Only allowed if value_type is one of | ||
float, int, or list. Use of this parameter requires that the original units | ||
of the column are inferrable from the table itself or from from_unit. | ||
subfield: str, optional | ||
The name of a subfield to extract from a structured array column. If | ||
ommitted, and field refers to a structured array, the entire array is still | ||
readable as a list value type. | ||
Returns | ||
------- | ||
value: value_type | ||
The value from the table. | ||
""" | ||
data = self.data[field] | ||
if subfield is not None: | ||
data = data[subfield] | ||
|
||
# See if data already has a unit, if not, add it. | ||
if data.unit is None: | ||
if from_unit is not None: | ||
data = data * getattr(u, from_unit) | ||
else: | ||
if from_unit is not None: | ||
if data.unit != getattr(u, from_unit): | ||
raise ValueError( | ||
f"from_unit = {from_unit} specified, but field {field} already " | ||
f"has units of {data.unit}." | ||
) | ||
|
||
# Angles are special | ||
if value_type == galsim.Angle: | ||
if to_unit is not None: | ||
raise ValueError("to_unit is not allowed for Angle types.") | ||
|
||
return float(data.to_value(u.rad)[0]) * galsim.radians | ||
|
||
# For non-angles, we leverage astropy units. | ||
if to_unit is not None: | ||
to_unit = getattr(u, to_unit) | ||
if value_type == list: | ||
if to_unit is None: | ||
out = data.value[0].tolist() | ||
else: | ||
out = data.to_value(to_unit)[0].tolist() | ||
# If we had a structured array, `out`` is still a tuple here, so | ||
# use an extra list() here to finish the cast. | ||
return list(out) | ||
|
||
# We have to be careful with strings, as using .value on the table datum will | ||
# convert to bytes, which is not what we want. | ||
if value_type == str: | ||
if to_unit is not None: | ||
raise ValueError("to_unit is not allowed for str types.") | ||
return str(data[0]) | ||
|
||
# No units allowed for bool | ||
if value_type == bool and to_unit is not None: | ||
raise ValueError("to_unit is not allowed for bool types.") | ||
|
||
if to_unit is None: | ||
return value_type(data.value[0]) | ||
else: | ||
return value_type(data.to_value(to_unit)[0]) | ||
|
||
|
||
def RowData(config, base, value_type): | ||
row = GetInputObj("table_row", config, base, "table_row") | ||
req = {"field": str} | ||
opt = {"from_unit": str, "to_unit": str, "subfield": str} | ||
kwargs, safe = GetAllParams(config, base, req=req, opt=opt) | ||
field = kwargs["field"] | ||
from_unit = kwargs.get("from_unit", None) | ||
to_unit = kwargs.get("to_unit", None) | ||
subfield = kwargs.get("subfield", None) | ||
val = row.get(field, value_type, from_unit, to_unit, subfield) | ||
return val, safe | ||
|
||
|
||
RegisterInputType( | ||
input_type="table_row", loader=InputLoader(TableRow, file_scope=True) | ||
) | ||
RegisterValueType( | ||
type_name="RowData", | ||
gen_func=RowData, | ||
valid_types=[float, int, bool, str, galsim.Angle, list], | ||
input_type="table_row", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
from tempfile import TemporaryDirectory | ||
|
||
import astropy.units as u | ||
import galsim | ||
import imsim | ||
import numpy as np | ||
from astropy.table import QTable, Table | ||
from imsim.table_row import RowData | ||
|
||
|
||
def create_table(): | ||
table = QTable() | ||
# For testing single and multi indexing | ||
table["idx"] = [0, 1, 2, 3] | ||
table["idx0"] = [0, 0, 1, 1] | ||
table["idx1"] = [0, 1, 0, 1] | ||
table["unitless"] = np.array(table["idx"], dtype=float) | ||
table["angle1"] = table["idx"] * u.arcsec | ||
table["angle2"] = table["angle1"].to(u.deg) | ||
# structured array column | ||
tilt_dtype = np.dtype([("rx", "<f8"), ("ry", "<f8")]) | ||
table["tilt"] = np.array([(0, 1), (1, 2), (2, 3), (3, 4)], dtype=tilt_dtype) | ||
table["tilt"].unit = u.arcsec | ||
# Add columns with other types | ||
table["int"] = np.array(table["idx"], dtype=int) | ||
table["bool"] = np.array([True, False, True, False], dtype=bool) | ||
table["str"] = ["a", "b", "c", "d"] | ||
# And some other units | ||
table["length"] = table["idx"] * u.m | ||
# Structured array with length units | ||
shift_dtype = np.dtype([("dx", "<f8"), ("dy", "<f8"), ("dz", "<f8")]) | ||
table["shift"] = np.array( | ||
[(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)], dtype=shift_dtype | ||
) | ||
table["shift"].unit = u.m | ||
|
||
table.pprint_all() | ||
return table | ||
|
||
|
||
def check_row_data(config, idx): | ||
# Read unitful column as Angle | ||
a1, safe1 = RowData({"field": "angle1"}, config, galsim.Angle) | ||
a2, safe2 = RowData({"field": "angle2"}, config, galsim.Angle) | ||
assert safe1 == safe2 == True | ||
np.testing.assert_allclose(a1.rad, a2.rad, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(a1.rad, idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
|
||
# Reading the unitful columns directly as floats is permitted. They'll | ||
# come back unequal this time due to unit differences. | ||
a1, safe1 = RowData({"field": "angle1"}, config, float) | ||
a2, safe2 = RowData({"field": "angle2"}, config, float) | ||
assert safe1 == safe2 == True | ||
assert isinstance(a1, float) | ||
assert isinstance(a2, float) | ||
np.testing.assert_allclose( | ||
(a1 * galsim.arcsec).rad, | ||
(a2 * galsim.degrees).rad, | ||
rtol=0, | ||
atol=1e-15, | ||
) | ||
|
||
# We can specify a to_unit though to get them to both come back in, | ||
# e.g., radians. | ||
a1, safe1 = RowData({"field": "angle1", "to_unit": "rad"}, config, float) | ||
a2, safe2 = RowData({"field": "angle2", "to_unit": "rad"}, config, float) | ||
assert safe1 == safe2 == True | ||
assert isinstance(a1, float) | ||
assert isinstance(a2, float) | ||
np.testing.assert_allclose(a1, a2, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(a1, idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
|
||
# Reading the unitless column as an Angle is permitted with from_unit. | ||
a, safe = RowData( | ||
{"field": "unitless", "from_unit": "arcsec"}, config, galsim.Angle | ||
) | ||
assert safe == True | ||
np.testing.assert_allclose(a.rad, idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
|
||
# Read the unitless column as initially arcsec then convert to rad | ||
a, safe = RowData( | ||
{"field": "unitless", "from_unit": "arcsec", "to_unit": "rad"}, | ||
config, | ||
float, | ||
) | ||
assert safe == True | ||
np.testing.assert_allclose(a, idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
|
||
# Using from_unit with a unitful column will raise if the units don't | ||
# match. | ||
with np.testing.assert_raises(ValueError): | ||
a, safe = RowData( | ||
{"field": "angle1", "from_unit": "deg"}, config, galsim.Angle | ||
) | ||
a, safe = RowData( | ||
{"field": "angle1", "from_unit": "arcsec"}, config, galsim.Angle | ||
) | ||
|
||
# It's an error to try to convert an Angle to a different unit. | ||
with np.testing.assert_raises(ValueError): | ||
a, safe = RowData( | ||
{"field": "angle1", "to_unit": "deg"}, config, galsim.Angle | ||
) | ||
|
||
# Read a structured array subfield | ||
rx, safe1 = RowData({"field": "tilt", "subfield": "rx"}, config, galsim.Angle) | ||
ry, safe2 = RowData({"field": "tilt", "subfield": "ry"}, config, galsim.Angle) | ||
assert safe1 == safe2 == True | ||
np.testing.assert_allclose(rx.rad, idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
np.testing.assert_allclose( | ||
ry.rad, (idx + 1) * u.arcsec.to(u.rad), rtol=0, atol=1e-15 | ||
) | ||
|
||
# Read the full structured array as a list | ||
# The config layer is not set up to handle lists of Angles, though, so | ||
# we have to interpret as floats directly. | ||
tilt, safe = RowData({"field": "tilt"}, config, list) | ||
assert safe == True | ||
assert len(tilt) == 2 | ||
np.testing.assert_allclose(tilt[0], idx, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(tilt[1], (idx + 1), rtol=0, atol=1e-15) | ||
|
||
# Read the full structured array and convert to rad | ||
tilt, safe = RowData({"field": "tilt", "to_unit": "rad"}, config, list) | ||
assert safe == True | ||
assert len(tilt) == 2 | ||
np.testing.assert_allclose(tilt[0], idx * u.arcsec.to(u.rad), rtol=0, atol=1e-15) | ||
np.testing.assert_allclose( | ||
tilt[1], (idx + 1) * u.arcsec.to(u.rad), rtol=0, atol=1e-15 | ||
) | ||
|
||
# Read the int column | ||
i, safe = RowData({"field": "int"}, config, int) | ||
assert safe == True | ||
assert i == idx | ||
|
||
# Read the bool column | ||
b, safe = RowData({"field": "bool"}, config, bool) | ||
assert safe == True | ||
assert b == (idx % 2 == 0) | ||
|
||
# Can't specify to_unit for bool | ||
with np.testing.assert_raises(ValueError): | ||
b, safe = RowData({"field": "bool", "to_unit": "rad"}, config, bool) | ||
|
||
# Read the str column | ||
s, safe = RowData({"field": "str"}, config, str) | ||
assert safe == True | ||
assert s == chr(ord("a") + idx) | ||
|
||
# Can't specify to_unit for str | ||
with np.testing.assert_raises(ValueError): | ||
s, safe = RowData({"field": "str", "to_unit": "rad"}, config, str) | ||
|
||
# Read the length column | ||
l, safe = RowData({"field": "length"}, config, float) | ||
assert safe == True | ||
assert l == idx | ||
|
||
# Read the length column and convert its units | ||
l, safe = RowData({"field": "length", "to_unit": "cm"}, config, float) | ||
assert safe == True | ||
assert l == idx * 100 | ||
|
||
# Read the structured array subfield with units | ||
dx, safe1 = RowData({"field": "shift", "subfield": "dx"}, config, float) | ||
dy, safe2 = RowData({"field": "shift", "subfield": "dy"}, config, float) | ||
dz, safe3 = RowData({"field": "shift", "subfield": "dz"}, config, float) | ||
assert safe1 == safe2 == safe3 == True | ||
np.testing.assert_allclose(dx, idx, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(dy, (idx + 1), rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(dz, (idx + 2), rtol=0, atol=1e-15) | ||
|
||
# Read the full structured array with units | ||
shift, safe = RowData({"field": "shift"}, config, list) | ||
assert safe == True | ||
assert len(shift) == 3 | ||
np.testing.assert_allclose(shift[0], idx, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(shift[1], (idx + 1), rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(shift[2], (idx + 2), rtol=0, atol=1e-15) | ||
|
||
# Read the full structured array with units and convert to cm | ||
shift, safe = RowData({"field": "shift", "to_unit": "cm"}, config, list) | ||
assert safe == True | ||
assert len(shift) == 3 | ||
np.testing.assert_allclose(shift[0], idx * 100, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(shift[1], (idx + 1) * 100, rtol=0, atol=1e-15) | ||
np.testing.assert_allclose(shift[2], (idx + 2) * 100, rtol=0, atol=1e-15) | ||
|
||
|
||
def test_table_row(): | ||
qtable = create_table() | ||
regular_table = Table(qtable) # Regular Table (not QTable) | ||
assert not isinstance(regular_table, QTable) | ||
|
||
# Works for both QTable and Table | ||
for table in [regular_table, qtable]: | ||
with TemporaryDirectory() as tmpdir: | ||
for ext in [".fits", ".ecsv", ".parq"]: | ||
file_name = tmpdir + "/table_row" + ext | ||
table.write(file_name, overwrite=True) | ||
|
||
config = { | ||
"input": { | ||
"table_row": { | ||
"file_name": file_name, | ||
"keys": ["idx"], | ||
"values": [0], | ||
}, | ||
}, | ||
} | ||
|
||
# Check single indexing during table load | ||
for idx in range(4): | ||
config["input"]["table_row"]["values"] = [idx] | ||
galsim.config.RemoveCurrent(config["input"]["table_row"]) | ||
galsim.config.ProcessInput(config) | ||
check_row_data(config, idx) | ||
|
||
# Check multi indexing during table load | ||
for idx0 in range(2): | ||
for idx1 in range(2): | ||
config["input"]["table_row"]["keys"] = ["idx0", "idx1"] | ||
config["input"]["table_row"]["values"] = [idx0, idx1] | ||
galsim.config.RemoveCurrent(config["input"]["table_row"]) | ||
galsim.config.ProcessInput(config) | ||
idx = idx0 * 2 + idx1 | ||
check_row_data(config, idx) | ||
|
||
|
||
if __name__ == "__main__": | ||
testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)] | ||
for testfn in testfns: | ||
testfn() |