Skip to content

Commit

Permalink
updated plotter for plotting resolution ellipses
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 7, 2024
1 parent 453c29a commit 89b732c
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 91 deletions.
2 changes: 1 addition & 1 deletion src/tavi/instrument/resolution/cooper_nathans.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def cooper_nathans(
else:
rez.STATUS = True

rez.set_labels()
rez._set_labels()

rez_list.append(rez)

Expand Down
39 changes: 0 additions & 39 deletions src/tavi/instrument/resolution/ellipse.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as la
from mpl_toolkits.axisartist import Subplot

from tavi.utilities import sig2fwhm

np.set_printoptions(floatmode="fixed", precision=4)


class ResoEllipse(object):
"""2D ellipses
Expand Down Expand Up @@ -49,38 +45,3 @@ def get_points(self, num_points=128):
pts[0] += self.centers[0]
pts[1] += self.centers[1]
return pts

def generate_plot(self, ax, c="black", linestyle="solid"):
"""Gnerate the ellipse for plotting"""

pts = self.get_points()

if self.grid_helper is None:

s = ax.plot(
pts[0],
pts[1],
c=c,
linestyle=linestyle,
)
else: # askew axes
s = ax.plot(
*self._tr(pts[0], pts[1]),
c=c,
linestyle=linestyle,
)

ax.set_xlabel(self.axes_labels[0])
ax.set_ylabel(self.axes_labels[1])
ax.grid(alpha=0.6)

return None

def plot(self):
"""Plot the ellipses."""

fig = plt.figure()
ax = Subplot(fig, 1, 1, 1, grid_helper=self.grid_helper)
fig.add_subplot(ax)
self.generate_plot(ax)
fig.show()
76 changes: 31 additions & 45 deletions src/tavi/instrument/resolution/ellipsoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

from tavi.instrument.resolution.curve import ResoCurve
from tavi.instrument.resolution.ellipse import ResoEllipse
from tavi.plotter import Plot2D
from tavi.sample.xtal import Xtal
from tavi.utilities import get_angle_vec, sig2fwhm

np.set_printoptions(floatmode="fixed", precision=4)


class ResoEllipsoid(object):
"""Manage the 4D resolution ellipoid
Expand Down Expand Up @@ -51,7 +50,7 @@ def __init__(

self.projection: Optional[tuple] = projection
self.angles: tuple[float, float, float] = (90.0, 90.0, 90.0)
self.axes_labels = None
self.axes_labels: tuple[str]

self.mat: np.ndarray
self.r0: Optional[float] = None
Expand Down Expand Up @@ -87,7 +86,7 @@ def __init__(
self.q = hkl_prime
self.angles = (get_angle_vec(v1, v2), get_angle_vec(v2, v3), get_angle_vec(v3, v1))

self.set_labels()
self._set_labels()

def _project_to_frame(self, mat_reso, phi, conv_mat):
"""determinate the frame from the projection vectors"""
Expand Down Expand Up @@ -190,7 +189,12 @@ def quadric_proj(quadric, idx):
# print("\nProjected row/column %d:\n%s\n->\n%s.\n" % (idx, str(quadric), str(ortho_proj)))
return np.delete(np.delete(ortho_proj, idx, axis=0), idx, axis=1)

def get_ellipse(self, axes=(0, 1), PROJECTION=False, ORIGIN=True) -> ResoEllipse:
def get_ellipse(
self,
axes: tuple[int, int] = (0, 1),
PROJECTION: bool = False,
ORIGIN: bool = True,
) -> ResoEllipse:
"""Gnerate a 2D ellipse by either making a cut or projection
Arguments:
Expand All @@ -199,7 +203,7 @@ def get_ellipse(self, axes=(0, 1), PROJECTION=False, ORIGIN=True) -> ResoEllipse
ORIGIN: shift the center if True
"""

x_axis, y_axis = axes
qe_list = np.concatenate((self.q, self.en), axis=None)
# axes = np.sort(axes)
# match tuple(np.sort(axes)):
Expand All @@ -212,10 +216,10 @@ def get_ellipse(self, axes=(0, 1), PROJECTION=False, ORIGIN=True) -> ResoEllipse
angle = np.round(self.angles[2], 2)
case _:
angle = 90.0
axes_labels = (self.axes_labels[axes[0]], self.axes_labels[axes[1]])
axes_labels = (self.axes_labels[x_axis], self.axes_labels[y_axis])

if ORIGIN:
centers = (qe_list[axes[0]], qe_list[axes[1]])
centers = (qe_list[x_axis], qe_list[y_axis])
else:
centers = (0.0, 0.0)

Expand All @@ -235,7 +239,7 @@ def get_ellipse(self, axes=(0, 1), PROJECTION=False, ORIGIN=True) -> ResoEllipse

return ResoEllipse(mat, centers, angle, axes_labels)

def set_labels(self):
def _set_labels(self):
"""Set axes labels based on the frame"""
match self.frame:
case "q":
Expand All @@ -250,45 +254,27 @@ def set_labels(self):
"E (meV)",
)

def plot(self):
def plot_ellipses(self):
"""Plot all 2D ellipses"""

# fig = plt.figure()
fig = plt.figure(figsize=(10, 6))
elps_qx_en = self.get_ellipse(axes=(0, 3), PROJECTION=False)
ax = fig.add_subplot(231, axes_class=Axes, grid_helper=elps_qx_en.grid_helper)
elps_qx_en.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qx_en = self.get_ellipse(axes=(0, 3), PROJECTION=True)
elps_proj_qx_en.generate_plot(ax, c="black", linestyle="dashed")

elps_qy_en = self.get_ellipse(axes=(1, 3), PROJECTION=False)
ax = fig.add_subplot(232, axes_class=Axes, grid_helper=elps_qy_en.grid_helper)
elps_qy_en.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qy_en = self.get_ellipse(axes=(1, 3), PROJECTION=True)
elps_proj_qy_en.generate_plot(ax, c="black", linestyle="dashed")

elps_qz_en = self.get_ellipse(axes=(2, 3), PROJECTION=False)
ax = fig.add_subplot(233, axes_class=Axes, grid_helper=elps_qz_en.grid_helper)
elps_qz_en.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qz_en = self.get_ellipse(axes=(2, 3), PROJECTION=True)
elps_proj_qz_en.generate_plot(ax, c="black", linestyle="dashed")

elps_qx_qy = self.get_ellipse(axes=(0, 1), PROJECTION=False)
ax = fig.add_subplot(234, axes_class=Axes, grid_helper=elps_qx_qy.grid_helper)
elps_qx_qy.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qx_qy = self.get_ellipse(axes=(0, 1), PROJECTION=True)
elps_proj_qx_qy.generate_plot(ax, c="black", linestyle="dashed")

elps_qy_qz = self.get_ellipse(axes=(1, 2), PROJECTION=False)
ax = fig.add_subplot(235, axes_class=Axes, grid_helper=elps_qy_qz.grid_helper)
elps_qy_qz.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qy_qz = self.get_ellipse(axes=(1, 2), PROJECTION=True)
elps_proj_qy_qz.generate_plot(ax, c="black", linestyle="dashed")

elps_qx_qz = self.get_ellipse(axes=(0, 2), PROJECTION=False)
ax = fig.add_subplot(236, axes_class=Axes, grid_helper=elps_qx_qz.grid_helper)
elps_qx_qz.generate_plot(ax, c="black", linestyle="solid")
elps_proj_qx_qz = self.get_ellipse(axes=(0, 2), PROJECTION=True)
elps_proj_qx_qz.generate_plot(ax, c="black", linestyle="dashed")

for i, indices in enumerate([(0, 3), (1, 3), (2, 3), (0, 1), (1, 2), (0, 2)]):

ellipse_co = self.get_ellipse(axes=indices, PROJECTION=False)
ellipse_inco = self.get_ellipse(axes=indices, PROJECTION=True)

p = Plot2D()
if indices == (2, 3):
p.add_reso(ellipse_co, c="k", linestyle="solid", label="Coherent")
p.add_reso(ellipse_inco, c="k", linestyle="dashed", label="Incoherent")

else:
p.add_reso(ellipse_co, c="k", linestyle="solid")
p.add_reso(ellipse_inco, c="k", linestyle="dashed")

ax = fig.add_subplot(int(f"23{i+1}"), axes_class=Axes, grid_helper=p.grid_helper)
p.plot(ax)

fig.tight_layout(pad=2)
2 changes: 1 addition & 1 deletion src/tavi/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def plot(self, ax):
ax.grid(alpha=0.6)
for data in self.contour_data + self.reso_data + self.curve_data:
if "label" in data.fmt.keys():
ax.legend()
ax.legend(loc=1)
break

if not self.contour_data:
Expand Down
Binary file modified test_data/scan_to_nexus_test.h5
Binary file not shown.
Binary file modified test_data/spice_to_nxdict_test_all.h5
Binary file not shown.
Binary file modified test_data/spice_to_nxdict_test_empty.h5
Binary file not shown.
Binary file modified test_data/spice_to_nxdict_test_scan0034.h5
Binary file not shown.
Binary file modified test_data/tavi_test_exp424.h5
Binary file not shown.
16 changes: 14 additions & 2 deletions tests/test_ellipse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,24 @@ def test_local_q(tas_params):
def test_hkl(tas_params):
tas, ei, ef, hkl, _, R0 = tas_params
rez = tas.cooper_nathans(hkl_list=hkl, ei=ei, ef=ef, R0=R0)

e01_co = rez.get_ellipse(axes=(0, 1), PROJECTION=False)

assert np.allclose(e01_co.angle, 60)
assert e01_co.xlabel == "H (r.l.u.)"
assert e01_co.ylabel == "K (r.l.u.)"


def test_plotting(tas_params):
tas, ei, ef, hkl, _, R0 = tas_params
rez = tas.cooper_nathans(hkl_list=hkl, ei=ei, ef=ef, R0=R0)

e01_co = rez.get_ellipse(axes=(0, 1), PROJECTION=False)
e01_inco = rez.get_ellipse(axes=(0, 1), PROJECTION=True)

e03_co = rez.get_ellipse(axes=(0, 3), PROJECTION=False)
e03_inco = rez.get_ellipse(axes=(0, 3), PROJECTION=True)

assert np.allclose(e01_co.angle, 60)

p1 = Plot2D()
p1.add_reso(e01_co, c="k", linestyle="solid")
p1.add_reso(e01_inco, c="k", linestyle="dashed")
Expand All @@ -45,6 +56,7 @@ def test_hkl(tas_params):
ax2 = fig.add_subplot(122, axes_class=Axes)
p2.plot(ax2)

fig.tight_layout(pad=2)
plt.show()


Expand Down
5 changes: 2 additions & 3 deletions tests/test_ellipsoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ def test_projection(tas_params):
assert rez.axes_labels == ("(1, 1, 0)", "(0, 0, 1)", "(1, -1, 0)", "E (meV)")


def test_plot(tas_params):
def test_plotting(tas_params):
tas, ei, ef, hkl, _, R0 = tas_params
rez = tas.cooper_nathans(hkl_list=hkl, ei=ei, ef=ef, R0=R0)
rez.plot()

rez.plot_ellipses()
plt.show()


Expand Down
29 changes: 29 additions & 0 deletions tests/test_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*

import matplotlib.pyplot as plt
from mpl_toolkits.axisartist import Axes

from tavi.data.tavi import TAVI
from tavi.plotter import Plot2D


# TODO overplot resolution
def test_plot2d():
tavi = TAVI("./test_data/tavi_exp424.h5")
scan_list = list(range(42, 49, 1)) + list(range(70, 76, 1))

sg = tavi.combine_scans(scan_list, name="dispH")
scan_data_2d = sg.get_data(
axes=("qh", "en", "detector"),
norm_to=(1, "mcu"),
grid=(0.025, (-0.5, 4.5, 0.1)),
)

p = Plot2D()
p.add_contour(scan_data_2d, cmap="turbo", vmax=1)

fig = plt.figure()
ax = fig.add_subplot(111, axes_class=Axes, grid_helper=p.grid_helper)

p.plot(ax)
plt.show()

0 comments on commit 89b732c

Please sign in to comment.