Skip to content

Commit

Permalink
new initialization method online
Browse files Browse the repository at this point in the history
  • Loading branch information
MaceKuailv committed Apr 4, 2024
1 parent 77c7306 commit b2cecdc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
25 changes: 15 additions & 10 deletions seaduck/eulerian.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,17 @@ def from_bool_array(
except KeyError:
self.ocedata._add_missing_vol()

inds = (xr.DataArray(i, dim="dummy") for i in np.where(bool_array))
inds = np.where(bool_array)
try:
assert isinstance(data["Vol"], xr.DataArray)
xr_vol = data["Vol"]
except (KeyError, AssertionError):
xr_vol = data._ds["drF"] * data._ds["rA"]
if "HFacC" in data._ds.data_vars:
xr_vol *= data._ds["HFacC"]
vols = np.array(data._ds["Vol"][inds])
data["Vol"]
except KeyError:
data._add_missing_vol()
np_vol = np.array(data["Vol"])
vols = np.array(np_vol[inds])
num_each = np.round(vols * num / np.sum(vols)).astype(int)
num = np.sum(num_each)
self.N = num
inds = (np.repeat(np.array(i), num_each) for i in inds)
inds = tuple(np.repeat(np.array(i), num_each) for i in inds)
if len(inds) == 3:
iz, iy, ix = inds
ind = (iy, ix)
Expand All @@ -281,7 +279,7 @@ def from_bool_array(
else:
raise ValueError("bool_array must be 3 or 4 dimensional")
cs, sn, dx, dy, bx, by = _read_h(
self.XC, self.YC, self.dX, self.dY, self.CS, self.SN, ind
data.XC, data.YC, data.dX, data.dY, data.CS, data.SN, ind
)
rx = np.random.random(num) - 0.5
ry = np.random.random(num) - 0.5
Expand All @@ -293,6 +291,8 @@ def from_bool_array(
dzl_lin = data.dZl[izl_lin]
self.rel.update(VlLinRel(izl_lin, rzl_lin, dzl_lin, bzl_lin))

# set lon temporaily
self.lon = bx
px, py = self.get_px_py()
w = self.get_f_node_weight()
self.lon = np.einsum("nj,nj->n", w, px.T)
Expand All @@ -302,6 +302,11 @@ def from_bool_array(
self.rel.update(self.ocedata._find_rel_v(self.dep))
self.rel.update(self.ocedata._find_rel_vl(self.dep))

if isinstance(t, (int, float, np.floating)):
t = np.ones(self.N, float) * t
elif isinstance(t, np.ndarray):
if len(t)!=self.N:
raise ValueError("Mismatch between input t and final particle number")
if t is not None:
self.t = copy.deepcopy(t)
if self.ocedata.readiness["time"]:
Expand Down
5 changes: 4 additions & 1 deletion seaduck/lagrangian.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def __init__(
**kwarg,
):
Position.__init__(self)
self.from_latlon(**kwarg)
if "bool_array" in kwarg.keys():
self.from_bool_array(kwarg)
else:
self.from_latlon(**kwarg)

try:
self.px, self.py = self.get_px_py()
Expand Down

0 comments on commit b2cecdc

Please sign in to comment.