Skip to content

Commit

Permalink
Add json loader, saver
Browse files Browse the repository at this point in the history
  • Loading branch information
nghia-vo committed Nov 13, 2024
1 parent 08ec8dc commit d62478e
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 58 deletions.
190 changes: 144 additions & 46 deletions discorpy/losa/loadersaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
"""

import json
import platform
from pathlib import Path
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import font_manager
from collections import OrderedDict


Expand Down Expand Up @@ -422,8 +424,7 @@ def save_image(file_path, mat, overwrite=True):
str
Updated file path.
"""
file_path = __get_path(file_path, check_exist=False)
file_path = file_path.resolve()
file_path = __get_path(file_path, check_exist=False).resolve()
file_ext = file_path.suffix
if not ((file_ext == ".tif") or (file_ext == ".tiff")):
if mat.dtype != np.uint8:
Expand Down Expand Up @@ -494,8 +495,29 @@ def save_plot_image(file_path, list_lines, height, width, overwrite=True,
return file_path


def __check_font(font_family):
"""
Check if a specific font is available in Matplotlib.
Parameters
----------
font_family : str
Name of the font to check.
Returns
-------
bool
True if font is available, False otherwise.
"""
try:
font_manager.findfont(font_family, fallback_to_default=False)
return True
except:
return False


def save_residual_plot(file_path, list_data, height, width, overwrite=True,
dpi=100):
dpi=100, font_family='Times New Roman'):
"""
Save the plot of residual against radius to an image. Useful to check the
accuracy of unwarping results.
Expand All @@ -514,6 +536,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True,
Overwrite the existing file if True.
dpi : int, optional
The resolution in dots per inch.
font_family : str, optional
To set the font family
Returns
-------
Expand All @@ -528,7 +552,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True,
fig.set_size_inches(width / dpi, height / dpi)
m_size = 0.5 * min(height / dpi, width / dpi)
plt.rc('font', size=np.int16(m_size * 4))
plt.rcParams['font.family'] = 'Times New Roman'
if __check_font(font_family):
plt.rcParams['font.family'] = font_family
plt.rcParams['font.weight'] = 'bold'
plt.xlabel('Radius', fontweight='bold')
plt.ylabel('Residual', fontweight='bold')
Expand Down Expand Up @@ -562,8 +587,7 @@ def save_hdf_file(file_path, idata, key_path='entry', overwrite=True):
str
Updated file path.
"""
file_path = __get_path(file_path, check_exist=False)
file_path = file_path.resolve()
file_path = __get_path(file_path, check_exist=False).resolve()
if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
file_path = file_path.with_suffix('.hdf')
_create_folder(str(file_path))
Expand Down Expand Up @@ -605,8 +629,7 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data',
object
hdf object.
"""
file_path = __get_path(file_path, check_exist=False)
file_path = file_path.resolve()
file_path = __get_path(file_path, check_exist=False).resolve()
if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
file_path = file_path.with_suffix('.hdf')
_create_folder(str(file_path))
Expand All @@ -631,6 +654,60 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data',
return data_out


def save_plot_points(file_path, list_points, height, width, overwrite=True,
dpi=100, marker="o", color="blue"):
"""
Save the plot of dot-centroids to an image. Useful to check if the dots
are arranged properly where dots on the same line having the same color.
Parameters
----------
file_path : str
Output file path.
list_points : list of 1D-array
List of the (y-x)-coordinates of points.
height : int
Height of the image.
width : int
Width of the image.
overwrite : bool, optional
Overwrite the existing file if True.
dpi : int, optional
The resolution in dots per inch.
marker : str
Plot marker. Full list is at:
https://matplotlib.org/stable/api/markers_api.html
color : str
Marker color. Full list is at:
https://matplotlib.org/stable/tutorials/colors/colors.html
Returns
-------
str
Updated file path.
"""
file_path = __get_path(file_path, check_exist=False).resolve()
_create_folder(str(file_path))
if not overwrite:
file_path = _create_file_name(str(file_path))
fig = plt.figure(frameon=False)
fig.set_size_inches(width / dpi, height / dpi)
ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
ax.set_axis_off()
fig.add_axes(ax)
plt.axis((0, width, 0, height))
m_size = 0.5 * min(height / dpi, width / dpi)
for point in list_points:
plt.plot(point[1], height - point[0], marker, color=color,
markersize=m_size)
try:
plt.savefig(file_path, dpi=dpi)
except IOError:
raise ValueError("Couldn't write to file {}".format(file_path))
plt.close()
return file_path


def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True):
"""
Write metadata to a text file.
Expand All @@ -653,8 +730,7 @@ def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True):
str
Updated file path.
"""
file_path = __get_path(file_path, check_exist=False)
file_path = file_path.resolve()
file_path = __get_path(file_path, check_exist=False).resolve()
if file_path.suffix.lower() not in {'.txt', '.dat'}:
file_path = file_path.with_suffix('.txt')
_create_folder(str(file_path))
Expand Down Expand Up @@ -698,56 +774,78 @@ def load_metadata_txt(file_path):
return xcenter, ycenter, list_fact


def save_plot_points(file_path, list_points, height, width, overwrite=True,
dpi=100, marker="o", color="blue"):
def __numpy_encoder(obj):
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
return float(obj)
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()
raise TypeError(f"Object of type '{type(obj).__name__}' "
f"is not JSON serializable")


def save_metadata_json(file_path, xcenter, ycenter, list_fact, overwrite=True):
"""
Save the plot of dot-centroids to an image. Useful to check if the dots
are arranged properly where dots on the same line having the same color.
Write metadata to a JSON file.
Parameters
----------
file_path : str
Output file path.
list_points : list of 1D-array
List of the (y-x)-coordinates of points.
height : int
Height of the image.
width : int
Width of the image.
xcenter : float
Center of distortion in x-direction.
ycenter : float
Center of distortion in y-direction.
list_fact : list of float
Coefficients of a polynomial.
overwrite : bool, optional
Overwrite the existing file if True.
dpi : int, optional
The resolution in dots per inch.
marker : str
Plot marker. Full list is at:
https://matplotlib.org/stable/api/markers_api.html
color : str
Marker color. Full list is at:
https://matplotlib.org/stable/tutorials/colors/colors.html
Overwrite an existing file if True.
Returns
-------
str
Updated file path.
"""
file_path = __get_path(file_path, check_exist=False)
file_path = file_path.resolve()
# Get resolved file path and set to JSON suffix
file_path = __get_path(file_path, check_exist=False).resolve()
if file_path.suffix.lower() != '.json':
file_path = file_path.with_suffix('.json')
_create_folder(str(file_path))

if not overwrite:
file_path = _create_file_name(str(file_path))
fig = plt.figure(frameon=False)
fig.set_size_inches(width / dpi, height / dpi)
ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
ax.set_axis_off()
fig.add_axes(ax)
plt.axis((0, width, 0, height))
m_size = 0.5 * min(height / dpi, width / dpi)
for point in list_points:
plt.plot(point[1], height - point[0], marker, color=color,
markersize=m_size)
try:
plt.savefig(file_path, dpi=dpi)
except IOError:
raise ValueError("Couldn't write to file {}".format(file_path))
plt.close()

# Create metadata dictionary
metadata = {
'xcenter': float(xcenter),
'ycenter': float(ycenter),
'list_fact': list_fact
}
with open(file_path, "w") as f:
json.dump(metadata, f, indent=4, default=__numpy_encoder)
return file_path


def load_metadata_json(file_path):
"""
Load distortion coefficients from a JSON file.
Parameters
----------
file_path : str
Path to a JSON file.
Returns
-------
tuple of floats and list
Tuple of (xcenter, ycenter, list_fact).
"""
with open(__get_path(file_path), 'r') as f:
metadata = json.load(f)
xcenter = metadata['xcenter']
ycenter = metadata['ycenter']
list_fact = metadata['list_fact']
return xcenter, ycenter, list_fact
6 changes: 6 additions & 0 deletions examples/example_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@
np.abs(corrected_mat - mat0))
io.save_metadata_txt(output_base + "/coefficients_bw.txt", xcenter, ycenter,
list_fact)
# io.save_metadata_json(output_base + "/coefficients_bw.json", xcenter, ycenter,
# list_fact)

# Check the correction results
list_uhor_lines = post.unwarp_line_backward(list_hor_lines, xcenter, ycenter,
Expand Down Expand Up @@ -205,6 +207,8 @@
np.abs(corrected_mat - mat0))
io.save_metadata_txt(output_base + "coefficients_fw.txt", xcenter, ycenter,
list_fact)
# io.save_metadata_json(output_base + "coefficients_fw.json", xcenter, ycenter,
# list_fact)

# Check the correction results
list_uhor_lines = post.unwarp_line_forward(list_hor_lines, xcenter, ycenter,
Expand Down Expand Up @@ -247,6 +251,8 @@
np.abs(corrected_mat - mat0))
io.save_metadata_txt(
output_base + "/coefficients_bwfw.txt", xcenter, ycenter, list_bfact)
# io.save_metadata_json(
# output_base + "/coefficients_bwfw.json", xcenter, ycenter, list_bfact)
# Check the correction results
list_uhor_lines = post.unwarp_line_backward(
list_hor_lines, xcenter, ycenter, list_bfact)
Expand Down
42 changes: 30 additions & 12 deletions tests/test_loadersaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,19 @@ def test_open_hdf_stream(self):
self.assertRaises(ValueError, f_alias, "./tmp/data/data4.hdf",
(64, 64), options={"energy/entry/data": 25.0})

def test_save_plot_points(self):
f_alias = losa.save_plot_points
list_data = np.ones((64, 2), dtype=np.float32)
file_path = "./tmp/data/plot1.png"
list_data[:, 0] = 0.5 * np.random.rand(64)
list_data[:, 1] = np.arange(64)
f_alias(file_path, list_data, 64, 64, dpi=100)
self.assertTrue(os.path.isfile(file_path))

path = f_alias(file_path, list_data, 64, 64,
dpi=100, overwrite=False)
self.assertTrue(os.path.isfile(path))

def test_save_metadata_txt(self):
f_alias = losa.save_metadata_txt
file_path = "./tmp/data/coef.txt"
Expand All @@ -253,21 +266,26 @@ def test_save_metadata_txt(self):
self.assertTrue(os.path.isfile(file_path + ".txt"))

def test_load_metadata_txt(self):
f_alias = losa.load_metadata_txt
file_path = "./tmp/data/coef1.txt"
losa.save_metadata_txt(file_path, 31.0, 32.0, [1.0, 0.0])
(x, y, facts) = f_alias(file_path)
(x, y, facts) = losa.load_metadata_txt(file_path)
self.assertTrue(((x == 31.0) and (y == 32.0)) and facts == [1.0, 0.0])

def test_save_plot_points(self):
f_alias = losa.save_plot_points
list_data = np.ones((64, 2), dtype=np.float32)
file_path = "./tmp/data/plot1.png"
list_data[:, 0] = 0.5 * np.random.rand(64)
list_data[:, 1] = np.arange(64)
f_alias(file_path, list_data, 64, 64, dpi=100)
def test_save_metadata_json(self):
f_alias = losa.save_metadata_json
file_path = "./tmp/data/coef.json"
f_alias(file_path, 31, 32, [1.0, 0.0])
self.assertTrue(os.path.isfile(file_path))

path = f_alias(file_path, list_data, 64, 64,
dpi=100, overwrite=False)
self.assertTrue(os.path.isfile(path))
path = f_alias(file_path, 31, 32, [1.0, 0.0], overwrite=False)
self.assertTrue(path != file_path)

file_path_no_ext = "./tmp/data/coef1"
f_alias(file_path_no_ext, 31, 32, [1.0, 0.0])
self.assertTrue(os.path.isfile(file_path_no_ext + ".json"))

def test_load_metadata_json(self):
file_path = "./tmp/data/coef1.json"
losa.save_metadata_json(file_path, 31.0, 32.0, [1.0, 0.0])
x, y, facts = losa.load_metadata_json(file_path)
self.assertTrue((x == 31.0) and (y == 32.0) and facts == [1.0, 0.0])

0 comments on commit d62478e

Please sign in to comment.