diff --git a/discorpy/losa/loadersaver.py b/discorpy/losa/loadersaver.py index 6fbaeb5..966d057 100644 --- a/discorpy/losa/loadersaver.py +++ b/discorpy/losa/loadersaver.py @@ -31,12 +31,14 @@ """ +import json import platform from pathlib import Path import h5py import numpy as np from PIL import Image import matplotlib.pyplot as plt +from matplotlib import font_manager from collections import OrderedDict @@ -422,8 +424,7 @@ def save_image(file_path, mat, overwrite=True): str Updated file path. """ - file_path = __get_path(file_path, check_exist=False) - file_path = file_path.resolve() + file_path = __get_path(file_path, check_exist=False).resolve() file_ext = file_path.suffix if not ((file_ext == ".tif") or (file_ext == ".tiff")): if mat.dtype != np.uint8: @@ -494,8 +495,29 @@ def save_plot_image(file_path, list_lines, height, width, overwrite=True, return file_path +def __check_font(font_family): + """ + Check if a specific font is available in Matplotlib. + + Parameters + ---------- + font_family : str + Name of the font to check. + + Returns + ------- + bool + True if font is available, False otherwise. + """ + try: + font_manager.findfont(font_family, fallback_to_default=False) + return True + except: + return False + + def save_residual_plot(file_path, list_data, height, width, overwrite=True, - dpi=100): + dpi=100, font_family='Times New Roman'): """ Save the plot of residual against radius to an image. Useful to check the accuracy of unwarping results. @@ -514,6 +536,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True, Overwrite the existing file if True. dpi : int, optional The resolution in dots per inch. + font_family : str, optional + To set the font family Returns ------- @@ -528,7 +552,8 @@ def save_residual_plot(file_path, list_data, height, width, overwrite=True, fig.set_size_inches(width / dpi, height / dpi) m_size = 0.5 * min(height / dpi, width / dpi) plt.rc('font', size=np.int16(m_size * 4)) - plt.rcParams['font.family'] = 'Times New Roman' + if __check_font(font_family): + plt.rcParams['font.family'] = font_family plt.rcParams['font.weight'] = 'bold' plt.xlabel('Radius', fontweight='bold') plt.ylabel('Residual', fontweight='bold') @@ -562,8 +587,7 @@ def save_hdf_file(file_path, idata, key_path='entry', overwrite=True): str Updated file path. """ - file_path = __get_path(file_path, check_exist=False) - file_path = file_path.resolve() + file_path = __get_path(file_path, check_exist=False).resolve() if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}: file_path = file_path.with_suffix('.hdf') _create_folder(str(file_path)) @@ -605,8 +629,7 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data', object hdf object. """ - file_path = __get_path(file_path, check_exist=False) - file_path = file_path.resolve() + file_path = __get_path(file_path, check_exist=False).resolve() if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}: file_path = file_path.with_suffix('.hdf') _create_folder(str(file_path)) @@ -631,6 +654,60 @@ def open_hdf_stream(file_path, data_shape, key_path='entry/data', return data_out +def save_plot_points(file_path, list_points, height, width, overwrite=True, + dpi=100, marker="o", color="blue"): + """ + Save the plot of dot-centroids to an image. Useful to check if the dots + are arranged properly where dots on the same line having the same color. + + Parameters + ---------- + file_path : str + Output file path. + list_points : list of 1D-array + List of the (y-x)-coordinates of points. + height : int + Height of the image. + width : int + Width of the image. + overwrite : bool, optional + Overwrite the existing file if True. + dpi : int, optional + The resolution in dots per inch. + marker : str + Plot marker. Full list is at: + https://matplotlib.org/stable/api/markers_api.html + color : str + Marker color. Full list is at: + https://matplotlib.org/stable/tutorials/colors/colors.html + + Returns + ------- + str + Updated file path. + """ + file_path = __get_path(file_path, check_exist=False).resolve() + _create_folder(str(file_path)) + if not overwrite: + file_path = _create_file_name(str(file_path)) + fig = plt.figure(frameon=False) + fig.set_size_inches(width / dpi, height / dpi) + ax = plt.Axes(fig, [0., 0., 1.0, 1.0]) + ax.set_axis_off() + fig.add_axes(ax) + plt.axis((0, width, 0, height)) + m_size = 0.5 * min(height / dpi, width / dpi) + for point in list_points: + plt.plot(point[1], height - point[0], marker, color=color, + markersize=m_size) + try: + plt.savefig(file_path, dpi=dpi) + except IOError: + raise ValueError("Couldn't write to file {}".format(file_path)) + plt.close() + return file_path + + def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True): """ Write metadata to a text file. @@ -653,8 +730,7 @@ def save_metadata_txt(file_path, xcenter, ycenter, list_fact, overwrite=True): str Updated file path. """ - file_path = __get_path(file_path, check_exist=False) - file_path = file_path.resolve() + file_path = __get_path(file_path, check_exist=False).resolve() if file_path.suffix.lower() not in {'.txt', '.dat'}: file_path = file_path.with_suffix('.txt') _create_folder(str(file_path)) @@ -698,56 +774,78 @@ def load_metadata_txt(file_path): return xcenter, ycenter, list_fact -def save_plot_points(file_path, list_points, height, width, overwrite=True, - dpi=100, marker="o", color="blue"): +def __numpy_encoder(obj): + if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, + np.int16, np.int32, np.int64, np.uint8, + np.uint16, np.uint32, np.uint64)): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + return float(obj) + elif isinstance(obj, (np.ndarray,)): + return obj.tolist() + raise TypeError(f"Object of type '{type(obj).__name__}' " + f"is not JSON serializable") + + +def save_metadata_json(file_path, xcenter, ycenter, list_fact, overwrite=True): """ - Save the plot of dot-centroids to an image. Useful to check if the dots - are arranged properly where dots on the same line having the same color. + Write metadata to a JSON file. Parameters ---------- file_path : str Output file path. - list_points : list of 1D-array - List of the (y-x)-coordinates of points. - height : int - Height of the image. - width : int - Width of the image. + xcenter : float + Center of distortion in x-direction. + ycenter : float + Center of distortion in y-direction. + list_fact : list of float + Coefficients of a polynomial. overwrite : bool, optional - Overwrite the existing file if True. - dpi : int, optional - The resolution in dots per inch. - marker : str - Plot marker. Full list is at: - https://matplotlib.org/stable/api/markers_api.html - color : str - Marker color. Full list is at: - https://matplotlib.org/stable/tutorials/colors/colors.html + Overwrite an existing file if True. Returns ------- str Updated file path. """ - file_path = __get_path(file_path, check_exist=False) - file_path = file_path.resolve() + # Get resolved file path and set to JSON suffix + file_path = __get_path(file_path, check_exist=False).resolve() + if file_path.suffix.lower() != '.json': + file_path = file_path.with_suffix('.json') _create_folder(str(file_path)) + if not overwrite: file_path = _create_file_name(str(file_path)) - fig = plt.figure(frameon=False) - fig.set_size_inches(width / dpi, height / dpi) - ax = plt.Axes(fig, [0., 0., 1.0, 1.0]) - ax.set_axis_off() - fig.add_axes(ax) - plt.axis((0, width, 0, height)) - m_size = 0.5 * min(height / dpi, width / dpi) - for point in list_points: - plt.plot(point[1], height - point[0], marker, color=color, - markersize=m_size) - try: - plt.savefig(file_path, dpi=dpi) - except IOError: - raise ValueError("Couldn't write to file {}".format(file_path)) - plt.close() + + # Create metadata dictionary + metadata = { + 'xcenter': float(xcenter), + 'ycenter': float(ycenter), + 'list_fact': list_fact + } + with open(file_path, "w") as f: + json.dump(metadata, f, indent=4, default=__numpy_encoder) return file_path + + +def load_metadata_json(file_path): + """ + Load distortion coefficients from a JSON file. + + Parameters + ---------- + file_path : str + Path to a JSON file. + + Returns + ------- + tuple of floats and list + Tuple of (xcenter, ycenter, list_fact). + """ + with open(__get_path(file_path), 'r') as f: + metadata = json.load(f) + xcenter = metadata['xcenter'] + ycenter = metadata['ycenter'] + list_fact = metadata['list_fact'] + return xcenter, ycenter, list_fact diff --git a/examples/example_01.py b/examples/example_01.py index 2c101ae..5076bf8 100644 --- a/examples/example_01.py +++ b/examples/example_01.py @@ -167,6 +167,8 @@ np.abs(corrected_mat - mat0)) io.save_metadata_txt(output_base + "/coefficients_bw.txt", xcenter, ycenter, list_fact) +# io.save_metadata_json(output_base + "/coefficients_bw.json", xcenter, ycenter, +# list_fact) # Check the correction results list_uhor_lines = post.unwarp_line_backward(list_hor_lines, xcenter, ycenter, @@ -205,6 +207,8 @@ np.abs(corrected_mat - mat0)) io.save_metadata_txt(output_base + "coefficients_fw.txt", xcenter, ycenter, list_fact) +# io.save_metadata_json(output_base + "coefficients_fw.json", xcenter, ycenter, +# list_fact) # Check the correction results list_uhor_lines = post.unwarp_line_forward(list_hor_lines, xcenter, ycenter, @@ -247,6 +251,8 @@ np.abs(corrected_mat - mat0)) io.save_metadata_txt( output_base + "/coefficients_bwfw.txt", xcenter, ycenter, list_bfact) +# io.save_metadata_json( +# output_base + "/coefficients_bwfw.json", xcenter, ycenter, list_bfact) # Check the correction results list_uhor_lines = post.unwarp_line_backward( list_hor_lines, xcenter, ycenter, list_bfact) diff --git a/tests/test_loadersaver.py b/tests/test_loadersaver.py index 35b83b3..4d7cbc5 100644 --- a/tests/test_loadersaver.py +++ b/tests/test_loadersaver.py @@ -239,6 +239,19 @@ def test_open_hdf_stream(self): self.assertRaises(ValueError, f_alias, "./tmp/data/data4.hdf", (64, 64), options={"energy/entry/data": 25.0}) + def test_save_plot_points(self): + f_alias = losa.save_plot_points + list_data = np.ones((64, 2), dtype=np.float32) + file_path = "./tmp/data/plot1.png" + list_data[:, 0] = 0.5 * np.random.rand(64) + list_data[:, 1] = np.arange(64) + f_alias(file_path, list_data, 64, 64, dpi=100) + self.assertTrue(os.path.isfile(file_path)) + + path = f_alias(file_path, list_data, 64, 64, + dpi=100, overwrite=False) + self.assertTrue(os.path.isfile(path)) + def test_save_metadata_txt(self): f_alias = losa.save_metadata_txt file_path = "./tmp/data/coef.txt" @@ -253,21 +266,26 @@ def test_save_metadata_txt(self): self.assertTrue(os.path.isfile(file_path + ".txt")) def test_load_metadata_txt(self): - f_alias = losa.load_metadata_txt file_path = "./tmp/data/coef1.txt" losa.save_metadata_txt(file_path, 31.0, 32.0, [1.0, 0.0]) - (x, y, facts) = f_alias(file_path) + (x, y, facts) = losa.load_metadata_txt(file_path) self.assertTrue(((x == 31.0) and (y == 32.0)) and facts == [1.0, 0.0]) - def test_save_plot_points(self): - f_alias = losa.save_plot_points - list_data = np.ones((64, 2), dtype=np.float32) - file_path = "./tmp/data/plot1.png" - list_data[:, 0] = 0.5 * np.random.rand(64) - list_data[:, 1] = np.arange(64) - f_alias(file_path, list_data, 64, 64, dpi=100) + def test_save_metadata_json(self): + f_alias = losa.save_metadata_json + file_path = "./tmp/data/coef.json" + f_alias(file_path, 31, 32, [1.0, 0.0]) self.assertTrue(os.path.isfile(file_path)) - path = f_alias(file_path, list_data, 64, 64, - dpi=100, overwrite=False) - self.assertTrue(os.path.isfile(path)) + path = f_alias(file_path, 31, 32, [1.0, 0.0], overwrite=False) + self.assertTrue(path != file_path) + + file_path_no_ext = "./tmp/data/coef1" + f_alias(file_path_no_ext, 31, 32, [1.0, 0.0]) + self.assertTrue(os.path.isfile(file_path_no_ext + ".json")) + + def test_load_metadata_json(self): + file_path = "./tmp/data/coef1.json" + losa.save_metadata_json(file_path, 31.0, 32.0, [1.0, 0.0]) + x, y, facts = losa.load_metadata_json(file_path) + self.assertTrue((x == 31.0) and (y == 32.0) and facts == [1.0, 0.0])