From 2c0bbf781505001ad00ccf3b720521768d8e20a5 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Thu, 7 Nov 2024 17:25:22 -0500 Subject: [PATCH] modified generation of hkle list in CN class --- .../instrument/resolution/cooper_nathans.py | 38 +++++++------------ tests/test_cooper_nathans.py | 4 +- tests/test_plotter.py | 36 ++++++++++++++++-- tests/test_scan_group.py | 6 ++- 4 files changed, 52 insertions(+), 32 deletions(-) diff --git a/src/tavi/instrument/resolution/cooper_nathans.py b/src/tavi/instrument/resolution/cooper_nathans.py index f6b6803..b1e0f86 100755 --- a/src/tavi/instrument/resolution/cooper_nathans.py +++ b/src/tavi/instrument/resolution/cooper_nathans.py @@ -130,32 +130,22 @@ def _generate_hkle_list( ef: Union[float, list[float]], ) -> list[tuple[tuple[float, float, float], float, float]]: """Generate a list containing tuple ((h, k, l), ei, ef)""" + hkle_list = [] + if not isinstance(ei, list): + ei = [ei] + if not isinstance(ef, list): + ef = [ef] + if not isinstance(hkl_list, list): + hkl_list = [hkl_list] if isinstance(hkl_list, list): - num = len(hkl_list) - if not isinstance(ei, list): - ei_list = [ei] * num - elif len(ei) == num: - ei_list = ei - else: - raise ValueError("length of ei and hkl_list do not match.") - - if not isinstance(ef, list): - ef_list = [ef] * num - elif len(ef) == num: - ef_list = ef - else: - raise ValueError("length of ef and hkl_list do not match.") - - hkle_list = list(zip(hkl_list, ei_list, ef_list)) - - elif isinstance(hkl_list, tuple) and len(hkl_list) == 3: - if (not isinstance(ei, list)) and (not isinstance(ef, list)): - hkle_list = [(hkl_list, ei, ef)] - else: - raise ValueError("length of ei and hkl_list do not match.") - else: - raise ValueError(f"hkl_list={hkl_list} should be either a list or a tuple.") + for one_ei in ei: + for one_ef in ef: + for hkl in hkl_list: + if isinstance(hkl, tuple) and len(hkl) == 3: + hkle_list.append((hkl, one_ei, one_ef)) + else: + raise ValueError(f"hkl={hkl} is not a tuple of length 3.") return hkle_list def validate_instrument_parameters(self): diff --git a/tests/test_cooper_nathans.py b/tests/test_cooper_nathans.py index 3fec281..5bcbf25 100755 --- a/tests/test_cooper_nathans.py +++ b/tests/test_cooper_nathans.py @@ -55,7 +55,7 @@ def test_copper_nathans_projection(tas_params): def test_copper_nathans_list(tas_params): tas, ei, ef, _, _, r0 = tas_params - rez_list = tas.cooper_nathans(hkl_list=[(0, 0, 3), (0, 0, -3)], ei=ei, ef=ef, projection=None, R0=r0) + rez_list = tas.cooper_nathans(hkl_list=[(0, 0, 3), (0, 0, -3)], ei=[ei, ei + 1], ef=ef, projection=None, R0=r0) mat = np.array( [ [9583.2881, -4671.0614, -0.0000, 986.5610], @@ -64,7 +64,7 @@ def test_copper_nathans_list(tas_params): [986.5610, -4129.1553, -0.0000, 864.3494], ] ) - assert len(rez_list) == 2 + assert len(rez_list) == 4 assert np.allclose(rez_list[0].mat, mat, atol=1e-1) assert np.allclose(rez_list[1].mat, mat, atol=1e-1) diff --git a/tests/test_plotter.py b/tests/test_plotter.py index 6face5d..291996e 100644 --- a/tests/test_plotter.py +++ b/tests/test_plotter.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -* import matplotlib.pyplot as plt +import numpy as np from mpl_toolkits.axisartist import Axes from tavi.data.tavi import TAVI +from tavi.instrument.resolution.cooper_nathans import CN from tavi.plotter import Plot2D +from tavi.sample.xtal import Xtal -# TODO overplot resolution def test_plot2d(): + + # load data tavi = TAVI("./test_data/tavi_exp424.h5") scan_list = list(range(42, 49, 1)) + list(range(70, 76, 1)) @@ -18,12 +22,36 @@ def test_plot2d(): norm_to=(1, "mcu"), grid=(0.025, (-0.5, 4.5, 0.1)), ) - + # load experimental parameters + instrument_config_json_path = "./src/tavi/instrument/instrument_params/cg4c.json" + tas = CN(SPICE_CONVENTION=False) + tas.load_instrument_params_from_json(instrument_config_json_path) + + sample_json_path = "./test_data/test_samples/nitio3.json" + sample = Xtal.from_json(sample_json_path) + tas.mount_sample(sample) + + # calculate resolution ellipses + R0 = False + hkl_list = [(qh, qh, 3) for qh in np.arange(-0.5, 0.1, 0.05)] + ef = 4.8 + ei_list = [e + ef for e in np.arange(0, 4.1, 0.4)] + projection = ((1, 1, 0), (0, 0, 1), (1, -1, 0)) + rez_list = tas.cooper_nathans(hkl_list=hkl_list, ei=ei_list, ef=ef, projection=projection, R0=R0) + + # genreate plot p = Plot2D() - p.add_contour(scan_data_2d, cmap="turbo", vmax=1) + im = p.add_contour(scan_data_2d, cmap="turbo", vmax=1) + + for rez in rez_list: + e_co = rez.get_ellipse(axes=(0, 3), PROJECTION=False) + e_inco = rez.get_ellipse(axes=(0, 3), PROJECTION=True) + p.add_reso(e_co, c="k", linestyle="solid") + p.add_reso(e_inco, c="k", linestyle="dashed") fig = plt.figure() ax = fig.add_subplot(111, axes_class=Axes, grid_helper=p.grid_helper) - p.plot(ax) + im = p.plot(ax) + fig.colorbar(im, ax=ax) plt.show() diff --git a/tests/test_scan_group.py b/tests/test_scan_group.py index f308510..7f03456 100644 --- a/tests/test_scan_group.py +++ b/tests/test_scan_group.py @@ -50,7 +50,8 @@ def test_scan_group_2d(): plot2d = Plot2D() plot2d.add_contour(scan_data_2d, cmap="turbo", vmax=80) fig, ax = plt.subplots() - plot2d.plot(ax) + im = plot2d.plot(ax) + fig.colorbar(im, ax=ax) plt.show() @@ -68,5 +69,6 @@ def test_scan_group_2d_rebin(): plot2d = Plot2D() plot2d.add_contour(scan_data_2d, cmap="turbo", vmax=1) fig, ax = plt.subplots() - plot2d.plot(ax) + im = plot2d.plot(ax) + fig.colorbar(im, ax=ax) plt.show()