Analytical overhaul #58

Merged: 60 commits, Jun 28, 2023
Commits
45f8c92
wrap up issue#49, see if dask[array] works for pip
MaceKuailv Jun 18, 2023
c00e16f
make it possible to use string as input for oceinterp
MaceKuailv Jun 18, 2023
2671915
break-up analytical
MaceKuailv Jun 18, 2023
347e8bf
fix notebooks as of break apart analytical
MaceKuailv Jun 19, 2023
af7cb6b
dict is all ordered, so remove message from oceinterp
MaceKuailv Jun 19, 2023
3c6bc8c
move wall cross to update_after_cell_change, and rename()
MaceKuailv Jun 19, 2023
0950bd6
rename things a little more
MaceKuailv Jun 19, 2023
ffcd4f5
super weird thing about smart_read causing unreproducible face conne…
MaceKuailv Jun 19, 2023
894bb20
have to roll back smart_read after so much struggle.
MaceKuailv Jun 19, 2023
cbfb503
fully renamed and restructured analytical-step and cross-wall
MaceKuailv Jun 19, 2023
85dd910
implement subset operations
MaceKuailv Jun 19, 2023
638d32c
allow analytical step done by subset
MaceKuailv Jun 19, 2023
829bbdf
the part that break cell wall cross narrowed
MaceKuailv Jun 20, 2023
cc17e05
further narrow done
MaceKuailv Jun 20, 2023
9b47293
move beyond cross_cell_wall
MaceKuailv Jun 20, 2023
ea5baaa
every function is already using subset
MaceKuailv Jun 20, 2023
504ad83
remove which from cross_cell, get_vol and get_u_du()
MaceKuailv Jun 20, 2023
9ed970e
finally remove all the masking
MaceKuailv Jun 20, 2023
7fa9b96
make sure we are not behind the previous branch
MaceKuailv Jun 20, 2023
45abb47
replace bool_masking with faster integer one
MaceKuailv Jun 20, 2023
41caca7
properly handle tf in to_next_stop
MaceKuailv Jun 20, 2023
04991a1
make the cross_wall_read more concise
MaceKuailv Jun 22, 2023
af3e444
simplify the rel2latlon step
MaceKuailv Jun 22, 2023
79f1546
some renaming
MaceKuailv Jun 22, 2023
fc83507
update docstring within cross_wall and analytical_step
MaceKuailv Jun 22, 2023
773f353
change some variable names
MaceKuailv Jun 22, 2023
e3826f5
move astype int upstream
MaceKuailv Jun 22, 2023
a9b10f8
be more careful with int conversion, try again
MaceKuailv Jun 22, 2023
4740938
fix index in numba mode
MaceKuailv Jun 22, 2023
64895b7
remove astype int in ocedata
MaceKuailv Jun 22, 2023
99aab44
remove more astype, which is slightly slower
MaceKuailv Jun 22, 2023
a8be2a4
refactor update_uvw_array
MaceKuailv Jun 22, 2023
6ea699c
walk through lagrangian, did some style change
MaceKuailv Jun 22, 2023
58aa7aa
fix save_raw, at least I think it is fixed
MaceKuailv Jun 22, 2023
b78f984
tinker with __new__ to help a few functions make more sense
MaceKuailv Jun 22, 2023
a77c5ac
assert accurate coordinate transform for curvilinear local cartesian …
MaceKuailv Jun 22, 2023
c7ae5be
add test for accurate rectilinear conversion
MaceKuailv Jun 22, 2023
46ceb7e
forgot to save the file before commit
MaceKuailv Jun 22, 2023
74b0153
add coord reproduce tests in vertical and temporal direction
MaceKuailv Jun 22, 2023
de1d5ea
add coordinate reproduce test for oceanparcel scheme.
MaceKuailv Jun 24, 2023
f1f7b62
add vertical interpolation quantitative check
MaceKuailv Jun 25, 2023
f86f114
add randomized test for vertical velocity
MaceKuailv Jun 25, 2023
da94685
apply snake case to oceinterp and add a test
MaceKuailv Jun 25, 2023
2a6d670
cover all of oceinterp
MaceKuailv Jun 25, 2023
6d8f957
add the stream function conservation as test cases
MaceKuailv Jun 25, 2023
33581b7
didn't add the file though
MaceKuailv Jun 25, 2023
621f9e5
try fix type check
MaceKuailv Jun 27, 2023
3e7da79
some more coverage
MaceKuailv Jun 27, 2023
4326eec
rename function get_masks.get_masks to get_masks.get_mask_arrays
MaceKuailv Jun 27, 2023
e392334
add repeated get_mask test
MaceKuailv Jun 27, 2023
d7d5394
improve coverage for masks, but the tests are 2 seconds slower.
MaceKuailv Jun 27, 2023
d45ec50
address all warnings in tests. should be a little more careful with s…
MaceKuailv Jun 27, 2023
df5fc4b
quantitative tests for reading velocity
MaceKuailv Jun 27, 2023
22d8d94
assert in eulerian
MaceKuailv Jun 27, 2023
064600a
assert a lot in topology
MaceKuailv Jun 27, 2023
2e1778e
assert for test_weight
MaceKuailv Jun 27, 2023
1142cf6
minor fixes in test_lagrangian
MaceKuailv Jun 27, 2023
1e82974
add a test for the case in issue 55, but cannot reproduce.
MaceKuailv Jun 28, 2023
1d8098e
move that test to a separate module
MaceKuailv Jun 28, 2023
7bb3283
some more simple test on not out of bound
MaceKuailv Jun 28, 2023
2 changes: 1 addition & 1 deletion Makefile
@@ -9,7 +9,7 @@ qa:
pre-commit run --all-files

unit-tests:
python -m pytest -vv --cov=seaduck --cov-report=$(COV_REPORT) --doctest-glob="*.md" --doctest-glob="*.rst"
python -m pytest -vv --cov=seaduck --cov-report=$(COV_REPORT) --doctest-glob="*.md" --doctest-glob="*.rst" -W ignore::RuntimeWarning

type-check:
python -m mypy .
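The added `-W ignore::RuntimeWarning` flag silences RuntimeWarnings for the whole pytest run. The same filter semantics can be sketched in plain Python (a toy illustration, not code from this PR):

```python
import warnings

def count_surfaced_runtime_warnings(n):
    """Emit n RuntimeWarnings under an 'ignore' filter and count survivors."""
    with warnings.catch_warnings(record=True) as caught:
        # Equivalent in spirit to pytest's -W ignore::RuntimeWarning
        warnings.simplefilter("ignore", RuntimeWarning)
        for _ in range(n):
            warnings.warn("overflow encountered", RuntimeWarning)
    return len(caught)
```

With the filter active, none of the emitted warnings are recorded.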
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
"numpy",
"pandas",
"scipy",
"dask",
"dask[array]",
"xarray"
]
description = "A python package that interpolates data from ocean dataset from both Eulerian and Lagrangian perspective. "
@@ -39,7 +39,8 @@ module = [
"matplotlib.*",
"numba",
"pooch",
"scipy"
"scipy",
"scipy.interpolate"
]

[tool.ruff]
97 changes: 62 additions & 35 deletions seaduck/eulerian.py
@@ -1,4 +1,5 @@
import copy
import logging

import numpy as np

@@ -112,8 +113,10 @@ class Position:
To actually do interpolation, use from_latlon method to tell the ducks where they are.
"""

def __init__(self):
self.rel = RelCoord()
def __new__(cls, *arg, **kwarg):
new_position = object.__new__(cls)
new_position.rel = RelCoord()
return new_position

def __getattr__(self, attr):
if attr == "rel":
@@ -161,30 +164,30 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None):
if any(i != self.N for i in length if i > 1):
raise ValueError("Shapes of input coordinates are not compatible")

if isinstance(x, float):
x = np.array([1.0]) * x
if isinstance(y, float):
y = np.array([1.0]) * y
if isinstance(z, float):
z = np.array([1.0]) * z
if isinstance(t, float):
t = np.array([1.0]) * t
if isinstance(x, (int, float, np.floating)):
x = np.ones(self.N, float) * x
if isinstance(y, (int, float, np.floating)):
y = np.ones(self.N, float) * y
if isinstance(z, (int, float, np.floating)):
z = np.ones(self.N, float) * z
if isinstance(t, (int, float, np.floating)):
t = np.ones(self.N, float) * t

for thing in [x, y, z, t]:
if thing is None:
continue
if len(thing.shape) > 1:
raise ValueError("Input need to be 1D numpy arrays")
if (x is not None) and (y is not None):
self.lon = x
self.lat = y
self.rel.update(self.ocedata.find_rel_h(x, y))
self.lon = copy.deepcopy(x)
self.lat = copy.deepcopy(y)
self.rel.update(self.ocedata.find_rel_h(self.lon, self.lat))
else:
self.rel.update(HRel._make([None for i in range(11)]))
self.lon = None
self.lat = None
if z is not None:
self.dep = z
self.dep = copy.deepcopy(z)
if self.ocedata.readiness["Z"]:
self.rel.update(self.ocedata.find_rel_v(z))
else:
@@ -199,7 +202,7 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None):
self.dep = None

if t is not None:
self.t = t
self.t = copy.deepcopy(t)
if self.ocedata.readiness["time"]:
self.rel.update(self.ocedata.find_rel_t(t))
else:
@@ -225,23 +228,47 @@ def subset(self, which):
the_subset: Position object
The selected Positions.
"""
p = Position()
p = object.__new__(type(self))
vardict = vars(self)
keys = vardict.keys()
for i in keys:
item = vardict[i]
for key in keys:
item = vardict[key]
if isinstance(item, np.ndarray):
if len(item.shape) == 1:
setattr(p, i, item[which])
p.N = len(getattr(p, i))
setattr(p, key, item[which])
p.N = len(getattr(p, key))
elif key in ["px", "py"]:
setattr(p, key, item[:, which])
else:
setattr(p, i, item)
setattr(p, key, item)
elif isinstance(item, RelCoord):
setattr(p, i, item.subset(which))
setattr(p, key, item.subset(which))
else:
setattr(p, i, item)
setattr(p, key, item)
return p

def update_from_subset(self, sub, which):
vardict = vars(sub)
keys = vardict.keys()
for key in keys:
item = vardict[key]
if not hasattr(self, key):
logging.warning(
f"A new attribute {key} defined" "after updating from subset"
)
setattr(self, key, item)
if getattr(self, key) is None:
continue
if isinstance(item, np.ndarray):
if len(item.shape) == 1:
getattr(self, key)[which] = item
elif key in ["px", "py"]:
getattr(self, key)[:, which] = item
elif isinstance(item, RelCoord):
self.rel.update_from_subset(item, which)
elif isinstance(item, list):
setattr(self, key, item)
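The `subset` / `update_from_subset` pair implements a gather-scatter pattern: pull the selected particles out, operate on them, then write 1-D arrays back in place with `arr[which] = item`. A minimal numpy sketch of that round trip (toy data, not seaduck code):

```python
import numpy as np

full = np.arange(6.0)          # state for all particles
which = np.array([1, 3, 5])    # indices of the active subset

sub = full[which] * 10.0       # gather: like Position.subset(which)
full[which] = sub              # scatter: like update_from_subset(sub, which)
# full is now [0., 10., 2., 30., 4., 50.]
```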

def fatten_h(self, knw, ind_moves_kwarg={}):
"""Fatten horizontal indices.

@@ -262,18 +289,18 @@ def fatten_h(self, knw, ind_moves_kwarg={}):
Read Topology.ind_moves for more detail.
"""
# self.ind_h_dict
kernel = knw.kernel
kernel = knw.kernel.astype(int)
kernel_tends = [_translate_to_tendency(k) for k in kernel]
m = len(kernel_tends)
n = len(self.iy)
tp = self.ocedata.tp

# the arrays we are going to return
if self.face is not None:
n_faces = np.zeros((n, m))
n_faces = np.zeros((n, m), int)
n_faces.T[:] = self.face
n_iys = np.zeros((n, m))
n_ixs = np.zeros((n, m))
n_iys = np.zeros((n, m), int)
n_ixs = np.zeros((n, m), int)

# first try to fatten it naively(fast and vectorized)
for i, node in enumerate(kernel):
@@ -302,9 +329,9 @@
else:
n_iys[j, i], n_ixs[j, i] = n_ind
if self.face is not None:
return n_faces.astype("int"), n_iys.astype("int"), n_ixs.astype("int")
return n_faces, n_iys, n_ixs
else:
return None, n_iys.astype("int"), n_ixs.astype("int")
return None, n_iys, n_ixs
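Several hunks here swap trailing `.astype("int")` casts for allocating with an integer dtype in the first place, which avoids a float allocation followed by a converting copy. A quick sketch of the difference:

```python
import numpy as np

n, m = 3, 4
a = np.zeros((n, m), int)            # integer buffer from the start
b = np.zeros((n, m)).astype("int")   # float buffer, then a converting copy
# Same result either way; the first form skips the extra allocation.
```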

def fatten_v(self, knw):
"""Fatten in vertical center coord.
@@ -319,13 +346,13 @@
if self.iz is None:
return None
if knw.vkernel == "nearest":
return copy.deepcopy(self.iz.astype(int))
return copy.deepcopy(self.iz)
elif knw.vkernel in ["dz", "linear"]:
try:
self.iz_lin
except AttributeError:
self.rel.update(self.ocedata.find_rel_v_lin(self.dep))
return np.vstack([self.iz_lin.astype(int), self.iz_lin.astype(int) - 1]).T
return np.vstack([self.iz_lin, self.iz_lin - 1]).T
else:
raise ValueError("vkernel not supported")

@@ -342,13 +369,13 @@
if self.izl is None:
return None
if knw.vkernel == "nearest":
return copy.deepcopy(self.izl.astype(int))
return copy.deepcopy(self.izl)
elif knw.vkernel in ["dz", "linear"]:
try:
self.izl_lin
except AttributeError:
self.rel.update(self.ocedata.find_rel_vl_lin(self.dep))
return np.vstack([self.izl_lin.astype(int), self.izl_lin.astype(int) - 1]).T
return np.vstack([self.izl_lin, self.izl_lin - 1]).T
else:
raise ValueError("vkernel not supported")

@@ -365,13 +392,13 @@
if self.it is None:
return None
if knw.tkernel == "nearest":
return copy.deepcopy(self.it.astype(int))
return copy.deepcopy(self.it)
elif knw.tkernel in ["dt", "linear"]:
try:
self.izl_lin
except AttributeError:
self.rel.update(self.ocedata.find_rel_t_lin(self.t))
return np.vstack([self.it_lin.astype(int), self.it_lin.astype(int) + 1]).T
return np.vstack([self.it_lin, self.it_lin + 1]).T
else:
raise ValueError("tkernel not supported")

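Moving the `rel = RelCoord()` setup from `__init__` into `__new__` guarantees that every instance of the class (and of its subclasses) carries the attribute even when `__init__` is customized or skipped. A toy sketch of the pattern, with hypothetical class names:

```python
class Base:
    def __new__(cls, *args, **kwargs):
        obj = object.__new__(cls)
        obj.rel = {}              # default state attached before __init__ runs
        return obj

class Particle(Base):
    def __init__(self, n):
        self.n = n                # __new__ already provided self.rel

p = Particle(7)
```

Note that `subset` still uses `object.__new__(type(self))` directly, which bypasses this hook and sets its own attributes.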
12 changes: 8 additions & 4 deletions seaduck/get_masks.py
@@ -107,7 +107,7 @@ def mask_w_node(maskC, tp=None):
return maskW


def get_masks(od, tp):
def get_mask_arrays(od):
"""Mask all staggered valocity points.

A wrapper around mask_u_node, mask_v_node, mask_w_node.
@@ -125,6 +125,7 @@ def get_masks(od, tp):
maskC,maskU,maskV,maskW: numpy.ndarray
masks at center points, U-walls, V-walls, W-walls respectively.
"""
tp = od.tp
keys = od._ds.keys()
if "maskC" not in keys:
warnings.warn("no maskC in the dataset, assuming nothing is masked.")
@@ -217,10 +218,13 @@ def get_masked(od, ind, cuvwg="C"):

def which_not_stuck(p):
"""Investigate which points are in land mask."""
ind = []
if p.izl_lin is not None:
ind.append(p.izl_lin - 1)
if p.face is not None:
ind = (p.izl_lin - 1, p.face, p.iy, p.ix)
else:
ind = (p.izl_lin - 1, p.iy, p.ix)
ind.append(p.face)
ind += [p.iy, p.ix]
ind = tuple(ind)
return get_masked(p.ocedata, ind).astype(bool)
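`which_not_stuck` now builds its index tuple incrementally, so the optional vertical and face axes are included only when the grid has them. The core of that logic, isolated with toy scalar values:

```python
def build_index(izl_lin=None, face=None, iy=0, ix=0):
    """Append optional axes only when present, then index with a tuple."""
    ind = []
    if izl_lin is not None:
        ind.append(izl_lin - 1)   # layered vertical index, shifted by one
    if face is not None:
        ind.append(face)
    ind += [iy, ix]
    return tuple(ind)
```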


24 changes: 11 additions & 13 deletions seaduck/kernel_weight.py
@@ -121,8 +121,8 @@ def kernel_weight_x(kernel, ktype="interp", order=0):
function to calculate the hotizontal interpolation/derivative
weight
"""
xs = np.array(list(set(kernel.T[0]))).astype(float)
ys = np.array(list(set(kernel.T[1]))).astype(float)
xs = np.array(list(set(kernel.T[0])), dtype=float)
ys = np.array(list(set(kernel.T[1])), dtype=float)

# if you the kernel is a line rather than a cross
if len(xs) == 1:
@@ -149,8 +149,8 @@
y_poly.append(
list(combinations([i for i in ys if i != ay], len(ys) - 1 - order))
)
x_poly = np.array(x_poly).astype(float)
y_poly = np.array(y_poly).astype(float)
x_poly = np.array(x_poly, dtype=float)
y_poly = np.array(y_poly, dtype=float)

@compileable
def the_interp_func(rx, ry):
@@ -335,8 +335,8 @@ def kernel_weight_s(kernel, xorder=0, yorder=0):
function to calculate the hotizontal interpolation/derivative
weight
"""
xs = np.array(list(set(kernel.T[0]))).astype(float)
ys = np.array(list(set(kernel.T[1]))).astype(float)
xs = np.array(list(set(kernel.T[0])), dtype=float)
ys = np.array(list(set(kernel.T[1])), dtype=float)
xmaxorder = False
ymaxorder = False
if xorder < len(xs) - 1:
@@ -363,8 +363,8 @@
y_poly.append(
list(combinations([i for i in ys if i != ay], len(ys) - 1 - yorder))
)
x_poly = np.array(x_poly).astype(float)
y_poly = np.array(y_poly).astype(float)
x_poly = np.array(x_poly, dtype=float)
y_poly = np.array(y_poly, dtype=float)

@compileable
def the_square_func(rx, ry):
@@ -722,8 +722,8 @@ def same_hsize(self, other):
if not type_same:
raise TypeError("the argument is not a KnW object")
try:
return (self.kernel == other.kernel).all()
except AttributeError:
return np.allclose(self.kernel, other.kernel)
except (ValueError, AttributeError):
return False

def same_size(self, other):
@@ -738,9 +738,7 @@ def __eq__(self, other):
type_same = isinstance(other, type(self))
if not type_same:
return False
shpe_same = (
self.kernel == other.kernel
).all() and self.inheritance == other.inheritance
shpe_same = self.same_hsize(other) and self.inheritance == other.inheritance
diff_same = (
(self.hkernel == other.hkernel)
and (self.vkernel == other.vkernel)
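`same_hsize` now compares kernels with `np.allclose` and catches `ValueError`, because `(a == b).all()` can raise when the two kernels have non-broadcastable shapes. A sketch of the failure mode the fix targets (toy kernels, not the library's):

```python
import numpy as np

def same_kernel(a, b):
    """Shape-tolerant kernel comparison, as in the revised same_hsize."""
    try:
        return bool(np.allclose(a, b))
    except ValueError:            # non-broadcastable shapes: not the same kernel
        return False

k1 = np.array([[0, 0], [1, 0], [-1, 0]])   # 3-point cross kernel
k2 = np.array([[0, 0], [1, 0]])            # 2-point kernel
```

Comparing `k1` with `k2` via `==` would raise on the shape mismatch; `same_kernel` simply reports they differ.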