From b2cecdccab7aaf300b7204294316bb0bf9735f0a Mon Sep 17 00:00:00 2001 From: Wenrui Jiang Date: Thu, 4 Apr 2024 18:58:11 -0400 Subject: [PATCH] new initialization method online --- seaduck/eulerian.py | 25 +++++++++++++++---------- seaduck/lagrangian.py | 5 ++++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/seaduck/eulerian.py b/seaduck/eulerian.py index 29721663..6381e214 100644 --- a/seaduck/eulerian.py +++ b/seaduck/eulerian.py @@ -258,19 +258,17 @@ def from_bool_array( except KeyError: self.ocedata._add_missing_vol() - inds = (xr.DataArray(i, dim="dummy") for i in np.where(bool_array)) + inds = np.where(bool_array) try: - assert isinstance(data["Vol"], xr.DataArray) - xr_vol = data["Vol"] - except (KeyError, AssertionError): - xr_vol = data._ds["drF"] * data._ds["rA"] - if "HFacC" in data._ds.data_vars: - xr_vol *= data._ds["HFacC"] - vols = np.array(data._ds["Vol"][inds]) + data["Vol"] + except KeyError: + data._add_missing_vol() + np_vol = np.array(data["Vol"]) + vols = np.array(np_vol[inds]) num_each = np.round(vols * num / np.sum(vols)).astype(int) num = np.sum(num_each) self.N = num - inds = (np.repeat(np.array(i), num_each) for i in inds) + inds = tuple(np.repeat(np.array(i), num_each) for i in inds) if len(inds) == 3: iz, iy, ix = inds ind = (iy, ix) @@ -281,7 +279,7 @@ def from_bool_array( else: raise ValueError("bool_array must be 3 or 4 dimensional") cs, sn, dx, dy, bx, by = _read_h( - self.XC, self.YC, self.dX, self.dY, self.CS, self.SN, ind + data.XC, data.YC, data.dX, data.dY, data.CS, data.SN, ind ) rx = np.random.random(num) - 0.5 ry = np.random.random(num) - 0.5 @@ -293,6 +291,8 @@ def from_bool_array( dzl_lin = data.dZl[izl_lin] self.rel.update(VlLinRel(izl_lin, rzl_lin, dzl_lin, bzl_lin)) + # set lon temporaily + self.lon = bx px, py = self.get_px_py() w = self.get_f_node_weight() self.lon = np.einsum("nj,nj->n", w, px.T) @@ -302,6 +302,11 @@ def from_bool_array( self.rel.update(self.ocedata._find_rel_v(self.dep)) self.rel.update(self.ocedata._find_rel_vl(self.dep)) + if isinstance(t, (int, float, np.floating)): + t = np.ones(self.N, float) * t + elif isinstance(t, np.ndarray): + if len(t)!=self.N: + raise ValueError("Mismatch between input t and final particle number") if t is not None: self.t = copy.deepcopy(t) if self.ocedata.readiness["time"]: diff --git a/seaduck/lagrangian.py b/seaduck/lagrangian.py index b838576e..738f0c48 100644 --- a/seaduck/lagrangian.py +++ b/seaduck/lagrangian.py @@ -101,7 +101,10 @@ def __init__( **kwarg, ): Position.__init__(self) - self.from_latlon(**kwarg) + if "bool_array" in kwarg.keys(): + self.from_bool_array(kwarg) + else: + self.from_latlon(**kwarg) try: self.px, self.py = self.get_px_py()