Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalization of Parsed Cartesian Coordinates #878

Merged
merged 17 commits into from
Sep 18, 2024
Merged
15 changes: 14 additions & 1 deletion benchmarks/mpas_ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,20 @@ def time_nearest_neighbor_remapping(self):
def time_inverse_distance_weighted_remapping(self):
self.uxds_480["bottomDepth"].remap.inverse_distance_weighted(self.uxds_120.uxgrid)


class HoleEdgeIndices(DatasetBenchmark):
def time_construct_hole_edge_indices(self, resolution):
ux.grid.geometry._construct_hole_edge_indices(self.uxds.uxgrid.edge_face_connectivity)

class CheckNorm:
param_names = ['resolution']
params = ['480km', '120km']

def setup(self, resolution):
self.uxgrid = ux.open_grid(file_path_dict[resolution][0])

def teardown(self, resolution):
del self.uxgrid

def time_check_norm(self, resolution):
from uxarray.grid.validation import _check_normalization
_check_normalization(self.uxgrid)
38 changes: 32 additions & 6 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def test_read_scrip(self):

# Test read from scrip and from ugrid for grid class
grid_CSne8 = ux.open_grid(gridfile_CSne8) # tests from scrip
pass


class TestOperators(TestCase):
Expand Down Expand Up @@ -926,7 +925,6 @@ def test_from_dataset(self):
xrds = xr.open_dataset(self.gridfile_scrip)
uxgrid = ux.Grid.from_dataset(xrds)

pass

def test_from_face_vertices(self):
single_face_latlon = [(0.0, 90.0), (-180, 0.0), (0.0, -90)]
Expand Down Expand Up @@ -961,7 +959,35 @@ def test_populate_bounds_GCA_mix(self):
face_bounds = bounds_xarray.values
nt.assert_allclose(grid.bounds.values, expected_bounds, atol=ERROR_TOLERANCE)

def test_opti_bounds(self):
import uxarray
uxgrid = ux.open_grid(gridfile_CSne8)
bounds = uxgrid.bounds
def test_populate_bounds_MPAS(self):
uxgrid = ux.open_grid(self.gridfile_mpas)
bounds_xarray = uxgrid.bounds


class TestNormalizeExistingCoordinates(TestCase):
gridfile_mpas = current_path / "meshfiles" / "mpas" / "QU" / "mesh.QU.1920km.151026.nc"
gridfile_CSne30 = current_path / "meshfiles" / "ugrid" / "outCSne30" / "outCSne30.ug"

def test_non_norm_initial(self):
"""Check the normalization of coordinates that were initially parsed as
non-normalized."""
from uxarray.grid.validation import _check_normalization
uxgrid = ux.open_grid(self.gridfile_mpas)

# Make the coordinates not normalized
uxgrid.node_x.data = 5 * uxgrid.node_x.data
uxgrid.node_y.data = 5 * uxgrid.node_y.data
uxgrid.node_z.data = 5 * uxgrid.node_z.data
assert not _check_normalization(uxgrid)

uxgrid.normalize_cartesian_coordinates()

assert _check_normalization(uxgrid)

def test_norm_initial(self):
"""Coordinates should be normalized for grids that we construct
them."""
from uxarray.grid.validation import _check_normalization
uxgrid = ux.open_grid(self.gridfile_CSne30)

assert _check_normalization(uxgrid)
37 changes: 37 additions & 0 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_set_desired_longitude_range,
_populate_node_latlon,
_populate_node_xyz,
_normalize_xyz,
)
from uxarray.grid.connectivity import (
_populate_edge_node_connectivity,
Expand Down Expand Up @@ -72,6 +73,7 @@
_check_connectivity,
_check_duplicate_nodes,
_check_area,
_check_normalization,
)

from xarray.core.utils import UncachedAccessor
Expand Down Expand Up @@ -175,6 +177,9 @@ def __init__(
self._ball_tree = None
self._kd_tree = None

# flag to track if coordinates are normalized
self._normalized = None

# set desired longitude range to [-180, 180]
_set_desired_longitude_range(self._ds)

Expand Down Expand Up @@ -1420,6 +1425,38 @@ def compute_face_areas(

return self._face_areas, self._face_jacobian

def normalize_cartesian_coordinates(self):
"""Normalizes Cartesian coordinates."""

if _check_normalization(self):
# check if coordinates are already normalized
return

if "node_x" in self._ds:
# normalize node coordinates
node_x, node_y, node_z = _normalize_xyz(
self.node_x.values, self.node_y.values, self.node_z.values
)
self.node_x.data = node_x
self.node_y.data = node_y
self.node_z.data = node_z
if "edge_x" in self._ds:
# normalize edge coordinates
edge_x, edge_y, edge_z = _normalize_xyz(
self.edge_x.values, self.edge_y.values, self.edge_z.values
)
self.edge_x.data = edge_x
self.edge_y.data = edge_y
self.edge_z.data = edge_z
if "face_x" in self._ds:
# normalize face coordinates
face_x, face_y, face_z = _normalize_xyz(
self.face_x.values, self.face_y.values, self.face_z.values
)
self.face_x.data = face_x
self.face_y.data = face_y
self.face_z.data = face_z

def to_xarray(self, grid_format: Optional[str] = "ugrid"):
"""Returns a xarray Dataset representation in a specific grid format
from the Grid object.
Expand Down
61 changes: 52 additions & 9 deletions uxarray/grid/validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import numpy as np
from warnings import warn


from uxarray.constants import ERROR_TOLERANCE


# validation helper functions
def _check_connectivity(self):
def _check_connectivity(grid):
"""Check if all nodes are referenced by at least one element.

If not, the mesh may have hanging nodes and may not a valid UGRID
Expand All @@ -15,28 +14,28 @@ def _check_connectivity(self):

# Check if all nodes are referenced by at least one element
# get unique nodes in connectivity
nodes_in_conn = np.unique(self.face_node_connectivity.values.flatten())
nodes_in_conn = np.unique(grid.face_node_connectivity.values.flatten())
# remove negative indices/fill values from the list
nodes_in_conn = nodes_in_conn[nodes_in_conn >= 0]

# check if the size of unique nodes in connectivity is equal to the number of nodes
if nodes_in_conn.size == self.n_node:
if nodes_in_conn.size == grid.n_node:
print("-All nodes are referenced by at least one element.")
return True
else:
warn(
"Some nodes may not be referenced by any element. {0} and {1}".format(
nodes_in_conn.size, self.n_node
nodes_in_conn.size, grid.n_node
),
RuntimeWarning,
)
return False


def _check_duplicate_nodes(self):
def _check_duplicate_nodes(grid):
"""Check if there are duplicate nodes in the mesh."""

coords1 = np.column_stack((np.vstack(self.node_lon), np.vstack(self.node_lat)))
coords1 = np.column_stack((np.vstack(grid.node_lon), np.vstack(grid.node_lat)))
unique_nodes, indices = np.unique(coords1, axis=0, return_index=True)
duplicate_indices = np.setdiff1d(np.arange(len(coords1)), indices)

Expand All @@ -53,9 +52,9 @@ def _check_duplicate_nodes(self):
return True


def _check_area(self):
def _check_area(grid):
"""Check if each face area is greater than our constant ERROR_TOLERANCE."""
areas = self.face_areas
areas = grid.face_areas
# Check if area of any face is close to zero
if np.any(np.isclose(areas, 0, atol=ERROR_TOLERANCE)):
warn(
Expand All @@ -66,3 +65,47 @@ def _check_area(self):
else:
print("-No face area is close to zero.")
return True


def _check_normalization(grid):
"""Checks whether all the cartesiain coordinates are normalized."""

if grid._normalized is True:
# grid is already normalized, no need to run extra checks
return grid._normalized

if "node_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False
if "edge_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False
if "face_x" in grid._ds:
if not (
np.isclose(
(grid.node_x**2 + grid.node_y**2 + grid.node_z**2),
1.0,
atol=ERROR_TOLERANCE,
)
).all():
grid._normalized = False
return False

# set the grid as normalized
grid._normalized = True

return True
Loading