Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT Cleanup #1007

Merged
merged 4 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/grid_geoflow.exo
Binary file not shown.
3 changes: 0 additions & 3 deletions uxarray/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
# error tolerance, mainly in the intersection calculations.
MACHINE_EPSILON = np.float64(np.finfo(float).eps)

ENABLE_JIT_CACHE = True
ENABLE_JIT = True

ENABLE_FMA = False

GRID_DIMS = ["n_node", "n_edge", "n_face"]
Expand Down
6 changes: 3 additions & 3 deletions uxarray/grid/arcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _to_list(obj):
return obj


@njit
@njit(cache=True)
def _point_within_gca_body(
angle, gca_cart, pt, GCRv0_lonlat, GCRv1_lonlat, pt_lonlat, is_directed
):
Expand Down Expand Up @@ -244,7 +244,7 @@ def point_within_gca(pt, gca_cart, is_directed=False):
return out


@njit
@njit(cache=True)
def in_between(p, q, r) -> bool:
"""Determines whether the number q is between p and r.

Expand All @@ -266,7 +266,7 @@ def in_between(p, q, r) -> bool:
return p <= q <= r or r <= q <= p


@njit
@njit(cache=True)
def _decide_pole_latitude(lat1, lat2):
"""Determine the pole latitude based on the latitudes of two points on a
Great Circle Arc (GCA).
Expand Down
17 changes: 7 additions & 10 deletions uxarray/grid/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

from uxarray.grid.coordinates import _lonlat_rad_to_xyz

from numba import njit, config
from uxarray.constants import ENABLE_JIT_CACHE, ENABLE_JIT
from numba import njit

config.DISABLE_JIT = not ENABLE_JIT


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_face_area(
x, y, z, quadrature_rule="gaussian", order=4, coords_type="spherical"
):
Expand Down Expand Up @@ -98,7 +95,7 @@ def calculate_face_area(
return area, jacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_all_face_area_from_coords(
x,
y,
Expand Down Expand Up @@ -173,7 +170,7 @@ def get_all_face_area_from_coords(
return area, jacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB):
"""Calculate Jacobian of a spherical triangle. This is a helper function
for calculating face area.
Expand Down Expand Up @@ -263,7 +260,7 @@ def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB):
return dJacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, dB):
"""Calculate Jacobian of a spherical triangle. This is a helper function
for calculating face area.
Expand Down Expand Up @@ -342,7 +339,7 @@ def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, d
return 0.5 * dJacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_gauss_quadratureDG(nCount):
"""Gauss Quadrature Points for integration.

Expand Down Expand Up @@ -587,7 +584,7 @@ def get_gauss_quadratureDG(nCount):
return dG, dW


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_tri_quadratureDG(nOrder):
"""Triangular Quadrature Points for integration.

Expand Down
4 changes: 2 additions & 2 deletions uxarray/grid/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _populate_n_nodes_per_face(grid):
)


@njit()
@njit(cache=True)
def _build_n_nodes_per_face(face_nodes, n_face, n_max_face_nodes):
"""Constructs ``n_nodes_per_face``, which contains the number of non-fill-
value nodes for each face in ``face_node_connectivity``"""
Expand Down Expand Up @@ -251,7 +251,7 @@ def _populate_edge_face_connectivity(grid):
)


@njit
@njit(cache=True)
def _build_edge_face_connectivity(face_edges, n_nodes_per_face, n_edge):
"""Helper for (``edge_face_connectivity``) construction."""
edge_faces = np.ones(shape=(n_edge, 2), dtype=face_edges.dtype) * INT_FILL_VALUE
Expand Down
16 changes: 8 additions & 8 deletions uxarray/grid/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _lonlat_rad_to_xyz(
return x, y, z


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_no_norm(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -67,7 +67,7 @@ def _xyz_to_lonlat_rad_no_norm(
return lon, lat


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_scalar(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -217,7 +217,7 @@ def _normalize_xyz(
return x_norm, y_norm, z_norm


@njit
@njit(cache=True)
def _normalize_xyz_scalar(x: float, y: float, z: float):
denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2)
x_norm = x / denom
Expand Down Expand Up @@ -430,7 +430,7 @@ def _smallest_enclosing_circle(points):
return _welzl_recursive(points, np.empty((0, 2)), None)


@njit
@njit(cache=True)
def _circle_from_two_points(p1, p2):
"""Calculate the smallest circle that encloses two points on a unit sphere.

Expand Down Expand Up @@ -459,7 +459,7 @@ def _circle_from_two_points(p1, p2):
return center, radius


@njit
@njit(cache=True)
def _circle_from_three_points(p1, p2, p3):
"""Calculate the smallest circle that encloses three points on a unit
sphere. This is a placeholder implementation.
Expand Down Expand Up @@ -499,7 +499,7 @@ def _circle_from_three_points(p1, p2, p3):
return center, radius


@njit
@njit(cache=True)
def _is_inside_circle(circle, point):
"""Check if a point is inside a given circle on a unit sphere.

Expand Down Expand Up @@ -763,7 +763,7 @@ def _xyz_to_lonlat_rad(
return lon, lat


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_no_norm(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -870,7 +870,7 @@ def _xyz_to_lonlat_deg(
return lon, lat


@njit
@njit(cache=True)
def _normalize_xyz_scalar(x: float, y: float, z: float):
denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2)
x_norm = x / denom
Expand Down
5 changes: 2 additions & 3 deletions uxarray/grid/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uxarray.constants import INT_FILL_VALUE, INT_DTYPE

from numba import njit
from uxarray.constants import ENABLE_JIT_CACHE


def construct_dual(grid):
Expand Down Expand Up @@ -53,7 +52,7 @@ def construct_dual(grid):
return new_node_face_connectivity


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def construct_faces(
n_node,
n_edges,
Expand Down Expand Up @@ -146,7 +145,7 @@ def construct_faces(
return construct_node_face_connectivity


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def _order_nodes(
temp_face,
node_0,
Expand Down
8 changes: 6 additions & 2 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from spatialpandas.geometry import MultiPolygonArray, PolygonArray
import xarray as xr

from uxarray.constants import ERROR_TOLERANCE, INT_DTYPE, INT_FILL_VALUE
from uxarray.constants import (
ERROR_TOLERANCE,
INT_DTYPE,
INT_FILL_VALUE,
)
from uxarray.grid.arcs import extreme_gca_latitude, point_within_gca
from uxarray.grid.intersections import gca_gca_intersection
from uxarray.grid.utils import (
Expand Down Expand Up @@ -80,7 +84,7 @@ def error_radius(p1, p2):
return unique_points


@njit
@njit(cache=True)
def _pad_closed_face_nodes(
face_node_connectivity, n_face, n_max_face_nodes, n_nodes_per_face
):
Expand Down
4 changes: 2 additions & 2 deletions uxarray/grid/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def _populate_edge_node_distances(grid):
)


@njit
@njit(cache=True)
def _construct_edge_node_distances(node_lon, node_lat, edge_nodes):
"""Helper for computing the arc-distance between nodes compose each
edge."""
Expand Down Expand Up @@ -890,7 +890,7 @@ def _populate_edge_face_distances(grid):
)


@njit
@njit(cache=True)
def _construct_edge_face_distances(node_lon, node_lat, edge_faces):
"""Helper for computing the arc-distance between faces that saddle a given
edge."""
Expand Down
2 changes: 1 addition & 1 deletion uxarray/grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numba import njit


@njit
@njit(cache=True)
def _angle_of_2_vectors(u, v):
"""Calculate the angle between two 3D vectors u and v in radians. Can be
used to calcualte the span of a GCR.
Expand Down
12 changes: 6 additions & 6 deletions uxarray/utils/computing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numba import njit


@njit
@njit(cache=True)
def all(a):
"""Numba decorated implementation of ``np.all()``

Expand All @@ -16,7 +16,7 @@ def all(a):
return np.all(a)


@njit
@njit(cache=True)
def isclose(a, b, rtol=1e-05, atol=1e-08):
"""Numba decorated implementation of ``np.isclose()``

Expand All @@ -28,7 +28,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08):
return np.isclose(a, b, rtol=rtol, atol=atol)


@njit
@njit(cache=True)
def allclose(a, b, rtol=1e-05, atol=1e-08):
"""Numba decorated implementation of ``np.allclose()``

Expand All @@ -39,7 +39,7 @@ def allclose(a, b, rtol=1e-05, atol=1e-08):
return np.allclose(a, b, rtol=rtol, atol=atol)


@njit
@njit(cache=True)
def cross(a, b):
"""Numba decorated implementation of ``np.cross()``

Expand All @@ -50,7 +50,7 @@ def cross(a, b):
return np.cross(a, b)


@njit
@njit(cache=True)
def dot(a, b):
"""Numba decorated implementation of ``np.dot()``

Expand All @@ -61,7 +61,7 @@ def dot(a, b):
return np.dot(a, b)


@njit
@njit(cache=True)
def norm(x):
"""Numba decorated implementation of ``np.linalg.norm()``

Expand Down
49 changes: 0 additions & 49 deletions uxarray/utils/numba_settings.py

This file was deleted.

Loading