Skip to content

Commit

Permalink
Merge pull request #121 from MaceKuailv/dic
Browse files Browse the repository at this point in the history
Dic
  • Loading branch information
MaceKuailv authored Jul 8, 2024
2 parents 8055b00 + 963603d commit dfd63ec
Show file tree
Hide file tree
Showing 13 changed files with 700 additions and 146 deletions.
129 changes: 118 additions & 11 deletions seaduck/eulerian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from seaduck.get_masks import get_masked
from seaduck.kernel_weight import KnW, _find_pk_4d, _translate_to_tendency
from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlRel, VRel
from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlLinRel, VlRel, VRel
from seaduck.smart_read import smart_read
from seaduck.utils import (
_general_len,
_read_h,
find_px_py,
get_key_by_value,
local_to_latlon,
Expand Down Expand Up @@ -193,13 +194,15 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None):
self.rel.update(VRel._make(None for i in range(4)))
if self.ocedata.readiness["Zl"]:
self.rel.update(self.ocedata._find_rel_vl(z))
self.rel.update(self.ocedata._find_rel_vl_lin(z))
else:
self.rel.update(VlRel._make(None for i in range(4)))
self.rel.update(VlLinRel._make(None for i in range(4)))
else:
self.rel.update(VRel._make(None for i in range(4)))
self.rel.update(VlRel._make(None for i in range(4)))
self.rel.update(VlLinRel._make(None for i in range(4)))
self.dep = None

if t is not None:
self.t = copy.deepcopy(t)
if self.ocedata.readiness["time"]:
Expand All @@ -212,6 +215,109 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None):
# (self.it, self.rt, self.dt, self.bt, self.t) = (None for i in range(5))
return self

def from_bool_array(
self, t=None, data=None, bool_array=None, num=None, random_seed=None
):
"""Update/Generate new object with random points in given grid boxes.
Use the methods from the ocedata to transform
from lat-lon-dep-time coords to rel-coords
store the output in the Position object.
Parameters
----------
t: numpy.ndarray, float or None, default None
1D array of the time coords
data: OceData object
The field where the Positions are defined on.
bool_array: numpy.ndarray, or xr.DataArray
Points are generated where it is True.
num: int
Total number of particles to seed (approximately).
random_seed: int optional
The random seed used for reproducible results.
"""
if random_seed is not None:
np.random.seed(random_seed)

if isinstance(data, OceData):
if data.readiness["h"] != "oceanparcel":
raise NotImplementedError(
"This method only support datasets that has XG, YG in it."
)
if not (data.readiness["Z"] and data.readiness["Zl"]):
raise NotImplementedError(
"This method only support datasets that has all Z coords."
)
self.ocedata = data
else:
raise ValueError("Input data must be OceData")
try:
self.ocedata["Vol"]
except KeyError:
self.ocedata._add_missing_vol()
self.tp = self.ocedata.tp

inds = np.where(bool_array)
try:
data["Vol"]
except KeyError:
data._add_missing_vol()
np_vol = np.array(data["Vol"])
vols = np.array(np_vol[inds])
num_each = np.round(vols * num / np.sum(vols)).astype(int)
num = np.sum(num_each)
self.N = num
inds = tuple(np.repeat(np.array(i), num_each) for i in inds)
if len(inds) == 3:
iz, iy, ix = inds
ind = (iy, ix)
face = None
elif len(inds) == 4:
iz, face, iy, ix = inds
ind = (face, iy, ix)
else:
raise ValueError("bool_array must be 3 or 4 dimensional")
cs, sn, dx, dy, bx, by = _read_h(
data.XC, data.YC, data.dX, data.dY, data.CS, data.SN, ind
)
rx = np.random.random(num) - 0.5
ry = np.random.random(num) - 0.5
rzl_lin = np.random.random(num)
self.rel.update(HRel(face, iy, ix, rx, ry, cs, sn, dx, dy, bx, by))

izl_lin = iz + 1
bzl_lin = data.Zl[izl_lin]
dzl_lin = data.dZl[izl_lin]
self.rel.update(VlLinRel(izl_lin, rzl_lin, dzl_lin, bzl_lin))

# set lon temporaily
self.lon = bx
px, py = self.get_px_py()
w = self.get_f_node_weight()
self.lon = np.einsum("nj,nj->n", w, px.T)
self.lat = np.einsum("nj,nj->n", w, py.T)

self.dep = self.bzl_lin + self.dzl_lin * self.rzl_lin
self.rel.update(self.ocedata._find_rel_v(self.dep))
self.rel.update(self.ocedata._find_rel_vl(self.dep))

if isinstance(t, (int, float, np.floating)):
t = np.ones(self.N, float) * t
elif isinstance(t, np.ndarray):
if len(t) != self.N:
raise ValueError("Mismatch between input t and final particle number")
if t is not None:
self.t = copy.deepcopy(t)
if self.ocedata.readiness["time"]:
self.rel.update(self.ocedata._find_rel_t(t))
else:
self.rel.update(TRel._make(None for i in range(4)))
else:
self.rel.update(TRel._make(None for i in range(4)))
self.t = None
return self

def subset(self, which):
"""Create a subset of the Position object.
Expand Down Expand Up @@ -317,10 +423,11 @@ def _fatten_h(self, knw, ind_moves_kwarg={}):
x_disp, y_disp = node
n_iys[:, i] = self.iy + y_disp
n_ixs[:, i] = self.ix + x_disp
cuvwg = ind_moves_kwarg.get("cuvwg", "C")
if self.face is not None:
illegal = tp.check_illegal((n_faces, n_iys, n_ixs))
illegal = tp.check_illegal((n_faces, n_iys, n_ixs), cuvwg=cuvwg)
else:
illegal = tp.check_illegal((n_iys, n_ixs))
illegal = tp.check_illegal((n_iys, n_ixs), cuvwg=cuvwg)

redo = np.array(np.where(illegal)).T
for loc in redo:
Expand Down Expand Up @@ -405,7 +512,7 @@ def _fatten_t(self, knw):
return copy.deepcopy(self.it)
elif knw.tkernel in ["dt", "linear"]:
try:
self.izl_lin
self.it_lin
except AttributeError:
self.rel.update(self.ocedata._find_rel_t_lin(self.t))
return np.vstack([self.it_lin, self.it_lin + 1]).T
Expand Down Expand Up @@ -748,21 +855,21 @@ def _fatten_required_index_and_register(self, hash_index, main_dict):
main_key = get_key_by_value(hash_index, hs)
var_name, dims, knw = main_dict[main_key]
if isinstance(var_name, str):
old_dims = dims
old_dims = copy.deepcopy(dims)
elif isinstance(var_name, tuple):
old_dims = dims[0]
old_dims = copy.deepcopy(dims[0])
dims = []
for i in old_dims:
if i in ["Xp1", "Yp1"]:
dims.append(i[:1])
else:
dims.append(i)
dims = tuple(dims)
if "Xp1" in old_dims or "Yp1" in old_dims:
cuvwg = "G"
else:
cuvwg = "C"
if isinstance(var_name, str):
if "Xp1" in old_dims and "Yp1" in old_dims:
cuvwg = "G"
else:
cuvwg = "C"
ind = self.fatten(
knw, required=dims, four_d=True, ind_moves_kwarg={"cuvwg": cuvwg}
)
Expand Down
Loading

0 comments on commit dfd63ec

Please sign in to comment.