Skip to content

Commit

Permalink
Merge pull request #6380 from nilsvu/transform_vol_stride
Browse files Browse the repository at this point in the history
TransformVol.py: add start, stop, stride
  • Loading branch information
knelli2 authored Nov 19, 2024
2 parents ad0e5e8 + d49b6d2 commit 5efe9e7
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/IO/H5/Python/IterElements.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def iter_elements(
"""
if isinstance(volfiles, spectre_h5.H5Vol):
volfiles = [volfiles]
if isinstance(obs_ids, int):
if isinstance(obs_ids, (int, np.integer)):
obs_ids = [obs_ids]
# Assuming the domain is the same in all volfiles at all observations to
# speed up the script
Expand Down
53 changes: 49 additions & 4 deletions src/Visualization/Python/TransformVolumeData.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ def __call__(
def transform_volume_data(
volfiles: Union[spectre_h5.H5Vol, Iterable[spectre_h5.H5Vol]],
kernels: Sequence[Kernel],
start_time: Optional[float] = None,
stop_time: Optional[float] = None,
stride: int = 1,
integrate: bool = False,
force: bool = False,
) -> Union[None, Dict[str, Sequence[float]]]:
Expand All @@ -497,6 +500,11 @@ def transform_volume_data(
files will be transformed.
kernels: List of transformations to apply to the volume data in the form
of 'Kernel' objects.
start_time: The earliest time at which to start processing data. The
start-time value is included.
stop_time: The time at which to stop processing data. The stop-time value
is included.
stride: Process only every 'stride'th time step.
integrate: Compute the volume integral over the kernels instead of
writing them back into the volume files. The integral is computed in
inertial coordinates for every tensor component of all kernels and over
Expand Down Expand Up @@ -539,13 +547,28 @@ def transform_volume_data(
if isinstance(volfiles, spectre_h5.H5Vol):
volfiles = [volfiles]
for volfile in volfiles:
all_observation_ids = volfile.list_observation_ids()
num_obs = len(all_observation_ids)
if integrate and "Time" not in integrals:
integrals["Time"] = [
all_observation_ids = np.array(
volfile.list_observation_ids(),
dtype=np.uint64,
)
all_times = np.array(
[
volfile.get_observation_value(obs_id)
for obs_id in all_observation_ids
]
)
# Filter observations
observation_filter = np.ones_like(all_times, dtype=bool)
if start_time is not None:
observation_filter &= all_times >= start_time
if stop_time is not None:
observation_filter &= all_times <= stop_time
all_observation_ids = all_observation_ids[observation_filter][::stride]
all_times = all_times[observation_filter][::stride]
num_obs = len(all_observation_ids)

if integrate and "Time" not in integrals:
integrals["Time"] = all_times

for i_obs, obs_id in enumerate(all_observation_ids):
# Load tensor data for all kernels
Expand Down Expand Up @@ -755,6 +778,28 @@ def parse_kernels(kernels, exec_files, map_input_names, interactive=False):
" transformed to CamelCase."
),
)
@click.option(
"--start-time",
type=float,
help=(
"The earliest time at which to start processing data. The start-time "
"value is included."
),
)
@click.option(
"--stop-time",
type=float,
help=(
"The time at which to stop processing data. The stop-time value is "
"included."
),
)
@click.option(
"--stride",
default=1,
type=int,
help="Process only every stride'th time step",
)
@click.option(
"--integrate",
is_flag=True,
Expand Down
13 changes: 12 additions & 1 deletion tests/Unit/Visualization/Python/Test_TransformVolumeData.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ def test_transform_volume_data(self):
with self.assertRaisesRegex(RuntimeError, "already exists"):
transform_volume_data(volfiles=open_volfiles, kernels=kernels)
transform_volume_data(
volfiles=open_volfiles, kernels=kernels, force=True
volfiles=open_volfiles,
kernels=kernels,
force=True,
start_time=0,
stop_time=1,
stride=2,
)

obs_id = open_volfiles[0].list_observation_ids()[0]
Expand Down Expand Up @@ -246,6 +251,12 @@ def test_cli(self):
"element_data.vol",
"-e",
__file__,
"--start-time",
"0",
"--stop-time",
"1",
"--stride",
"2",
]
result = runner.invoke(
transform_volume_data_command,
Expand Down

0 comments on commit 5efe9e7

Please sign in to comment.