diff --git a/seaduck/eulerian.py b/seaduck/eulerian.py index 04dca3d5..faa2a5e4 100644 --- a/seaduck/eulerian.py +++ b/seaduck/eulerian.py @@ -5,10 +5,11 @@ from seaduck.get_masks import get_masked from seaduck.kernel_weight import KnW, _find_pk_4d, _translate_to_tendency -from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlRel, VRel +from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlLinRel, VlRel, VRel from seaduck.smart_read import smart_read from seaduck.utils import ( _general_len, + _read_h, find_px_py, get_key_by_value, local_to_latlon, @@ -193,13 +194,15 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None): self.rel.update(VRel._make(None for i in range(4))) if self.ocedata.readiness["Zl"]: self.rel.update(self.ocedata._find_rel_vl(z)) + self.rel.update(self.ocedata._find_rel_vl_lin(z)) else: self.rel.update(VlRel._make(None for i in range(4))) + self.rel.update(VlLinRel._make(None for i in range(4))) else: self.rel.update(VRel._make(None for i in range(4))) self.rel.update(VlRel._make(None for i in range(4))) + self.rel.update(VlLinRel._make(None for i in range(4))) self.dep = None - if t is not None: self.t = copy.deepcopy(t) if self.ocedata.readiness["time"]: @@ -212,6 +215,109 @@ def from_latlon(self, x=None, y=None, z=None, t=None, data=None): # (self.it, self.rt, self.dt, self.bt, self.t) = (None for i in range(5)) return self + def from_bool_array( + self, t=None, data=None, bool_array=None, num=None, random_seed=None + ): + """Update/Generate new object with random points in given grid boxes. + + Use the methods from the ocedata to transform + from lat-lon-dep-time coords to rel-coords + store the output in the Position object. + + Parameters + ---------- + t: numpy.ndarray, float or None, default None + 1D array of the time coords + data: OceData object + The field where the Positions are defined on. + bool_array: numpy.ndarray, or xr.DataArray + Points are generated where it is True. + num: int + Total number of particles to seed (approximately). + random_seed: int optional + The random seed used for reproducible results. + """ + if random_seed is not None: + np.random.seed(random_seed) + + if isinstance(data, OceData): + if data.readiness["h"] != "oceanparcel": + raise NotImplementedError( + "This method only support datasets that has XG, YG in it." + ) + if not (data.readiness["Z"] and data.readiness["Zl"]): + raise NotImplementedError( + "This method only support datasets that has all Z coords." + ) + self.ocedata = data + else: + raise ValueError("Input data must be OceData") + try: + self.ocedata["Vol"] + except KeyError: + self.ocedata._add_missing_vol() + self.tp = self.ocedata.tp + + inds = np.where(bool_array) + try: + data["Vol"] + except KeyError: + data._add_missing_vol() + np_vol = np.array(data["Vol"]) + vols = np.array(np_vol[inds]) + num_each = np.round(vols * num / np.sum(vols)).astype(int) + num = np.sum(num_each) + self.N = num + inds = tuple(np.repeat(np.array(i), num_each) for i in inds) + if len(inds) == 3: + iz, iy, ix = inds + ind = (iy, ix) + face = None + elif len(inds) == 4: + iz, face, iy, ix = inds + ind = (face, iy, ix) + else: + raise ValueError("bool_array must be 3 or 4 dimensional") + cs, sn, dx, dy, bx, by = _read_h( + data.XC, data.YC, data.dX, data.dY, data.CS, data.SN, ind + ) + rx = np.random.random(num) - 0.5 + ry = np.random.random(num) - 0.5 + rzl_lin = np.random.random(num) + self.rel.update(HRel(face, iy, ix, rx, ry, cs, sn, dx, dy, bx, by)) + + izl_lin = iz + 1 + bzl_lin = data.Zl[izl_lin] + dzl_lin = data.dZl[izl_lin] + self.rel.update(VlLinRel(izl_lin, rzl_lin, dzl_lin, bzl_lin)) + + # set lon temporaily + self.lon = bx + px, py = self.get_px_py() + w = self.get_f_node_weight() + self.lon = np.einsum("nj,nj->n", w, px.T) + self.lat = np.einsum("nj,nj->n", w, py.T) + + self.dep = self.bzl_lin + self.dzl_lin * self.rzl_lin + self.rel.update(self.ocedata._find_rel_v(self.dep)) + self.rel.update(self.ocedata._find_rel_vl(self.dep)) + + if isinstance(t, (int, float, np.floating)): + t = np.ones(self.N, float) * t + elif isinstance(t, np.ndarray): + if len(t) != self.N: + raise ValueError("Mismatch between input t and final particle number") + if t is not None: + self.t = copy.deepcopy(t) + if self.ocedata.readiness["time"]: + self.rel.update(self.ocedata._find_rel_t(t)) + else: + self.rel.update(TRel._make(None for i in range(4))) + else: + self.rel.update(TRel._make(None for i in range(4))) + self.t = None + return self + def subset(self, which): """Create a subset of the Position object. @@ -317,10 +423,11 @@ def _fatten_h(self, knw, ind_moves_kwarg={}): x_disp, y_disp = node n_iys[:, i] = self.iy + y_disp n_ixs[:, i] = self.ix + x_disp + cuvwg = ind_moves_kwarg.get("cuvwg", "C") if self.face is not None: - illegal = tp.check_illegal((n_faces, n_iys, n_ixs)) + illegal = tp.check_illegal((n_faces, n_iys, n_ixs), cuvwg=cuvwg) else: - illegal = tp.check_illegal((n_iys, n_ixs)) + illegal = tp.check_illegal((n_iys, n_ixs), cuvwg=cuvwg) redo = np.array(np.where(illegal)).T for loc in redo: @@ -405,7 +512,7 @@ def _fatten_t(self, knw): return copy.deepcopy(self.it) elif knw.tkernel in ["dt", "linear"]: try: - self.izl_lin + self.it_lin except AttributeError: self.rel.update(self.ocedata._find_rel_t_lin(self.t)) return np.vstack([self.it_lin, self.it_lin + 1]).T @@ -748,9 +855,9 @@ def _fatten_required_index_and_register(self, hash_index, main_dict): main_key = get_key_by_value(hash_index, hs) var_name, dims, knw = main_dict[main_key] if isinstance(var_name, str): - old_dims = dims + old_dims = copy.deepcopy(dims) elif isinstance(var_name, tuple): - old_dims = dims[0] + old_dims = copy.deepcopy(dims[0]) dims = [] for i in old_dims: if i in ["Xp1", "Yp1"]: @@ -758,11 +865,11 @@ def _fatten_required_index_and_register(self, hash_index, main_dict): else: dims.append(i) dims = tuple(dims) + if "Xp1" in old_dims or "Yp1" in old_dims: + cuvwg = "G" + else: + cuvwg = "C" if isinstance(var_name, str): - if "Xp1" in old_dims and "Yp1" in old_dims: - cuvwg = "G" - else: - cuvwg = "C" ind = self.fatten( knw, required=dims, four_d=True, ind_moves_kwarg={"cuvwg": cuvwg} ) diff --git a/seaduck/eulerian_budget.py b/seaduck/eulerian_budget.py index 0a52ced0..eb9d294e 100644 --- a/seaduck/eulerian_budget.py +++ b/seaduck/eulerian_budget.py @@ -19,7 +19,8 @@ def _raise_if_no_xgcm(): ) -def create_ecco_grid(ds): +def create_ecco_grid(ds, for_outer=False): + _raise_if_no_xgcm() # pragma: no cover face_connections = { "face": { 0: {"X": ((12, "Y", False), (3, "X", False)), "Y": (None, (1, "Y", False))}, @@ -67,29 +68,43 @@ def create_ecco_grid(ds): }, } } + coords = { + "X": {"center": "X", "left": "Xp1"}, + "Y": {"center": "Y", "left": "Yp1"}, + "Z": {"center": "Z", "left": "Zl"}, + "time": {"center": "time", "inner": "time_midp"}, + } + if for_outer: + coords["Z"] = {"center": "Z", "outer": "Zl"} + xgcmgrd = xgcm.Grid( + ds, periodic=False, face_connections=face_connections, coords=coords + ) + return xgcmgrd + - grid = xgcm.Grid( +def create_periodic_grid(ds): + _raise_if_no_xgcm() # pragma: no cover + xgcmgrd = xgcm.Grid( ds, - periodic=False, - face_connections=face_connections, + periodic=["X"], coords={ - "X": {"center": "X", "left": "Xp1"}, - "Y": {"center": "Y", "left": "Yp1"}, + "X": {"center": "X", "outer": "Xp1"}, + "Y": {"center": "Y", "outer": "Yp1"}, "Z": {"center": "Z", "left": "Zl"}, - "time": {"center": "time", "inner": "time_midp"}, + "time": {"center": "time", "outer": "time_outer"}, }, ) - return grid + return xgcmgrd -def hor_div(tub, grid, xfluxname, yfluxname): +def hor_div(tub, xgcmgrd, xfluxname, yfluxname): """Calculate horizontal divergence using xgcm. Parameters ---------- tub: sd.OceData or xr.Dataset The dataset to calculate data from - grid: xgcm.Grid + xgcmgrd: xgcm.Grid The Grid of the dataset xfluxname, yfluxname: string The name of the variables corresponding to the horizontal fluxes @@ -99,7 +114,7 @@ def hor_div(tub, grid, xfluxname, yfluxname): tub["Vol"] except KeyError: tub._add_missing_vol() - xy_diff = grid.diff_2d_vector( + xy_diff = xgcmgrd.diff_2d_vector( {"X": tub[xfluxname].fillna(0), "Y": tub[yfluxname].fillna(0)}, boundary="fill", fill_value=0.0, @@ -110,14 +125,14 @@ def hor_div(tub, grid, xfluxname, yfluxname): return hConv -def ver_div(tub, grid, zfluxname): +def ver_div(tub, xgcmgrd, zfluxname): """Calculate horizontal divergence using xgcm. Parameters ---------- tub: sd.OceData or xr.Dataset The dataset to calculate data from - grid: xgcm.Grid + xgcmgrd: xgcm.Grid The Grid of the dataset xfluxname, yfluxname, zfluxname: string The name of the variables corresponding to the fluxes @@ -128,30 +143,51 @@ def ver_div(tub, grid, zfluxname): except KeyError: tub._add_missing_vol() vConv = ( - grid.diff(tub[zfluxname].fillna(0), "Z", boundary="fill", fill_value=0.0) + xgcmgrd.diff(tub[zfluxname].fillna(0), "Z", boundary="fill", fill_value=0.0) / tub["Vol"] ) return vConv -def total_div(tub, grid, xfluxname, yfluxname, zfluxname): +def total_div(tub, xgcmgrd, xfluxname, yfluxname, zfluxname): """Calculate 3D divergence using xgcm. Parameters ---------- tub: sd.OceData or xr.Dataset The dataset to calculate data from - grid: xgcm.Grid + xgcmgrd: xgcm.Grid The Grid of the dataset zfluxname: string The name of the variables corresponding to the vertical flux in concentration m^3/s """ - hConv = hor_div(tub, grid, xfluxname, yfluxname) - vConv = ver_div(tub, grid, zfluxname) + hConv = hor_div(tub, xgcmgrd, xfluxname, yfluxname) + vConv = ver_div(tub, xgcmgrd, zfluxname) return hConv + vConv +def bolus_vel_from_psi(tub, xgcmgrd, psixname="GM_PsiX", psiyname="GM_PsiY"): + strmx = tub[psixname].fillna(0) + strmy = tub[psiyname].fillna(0) + + u = xgcmgrd.diff(strmx, "Z", boundary="fill", fill_value=0.0) / tub["drF"] + v = xgcmgrd.diff(strmy, "Z", boundary="fill", fill_value=0.0) / tub["drF"] + + vstrmx = strmx * tub["dyG"] + vstrmy = strmy * tub["dxG"] + + xy_diff = xgcmgrd.diff_2d_vector( + {"X": vstrmx, "Y": vstrmy}, boundary="fill", fill_value=0.0 + ) + x_diff = xy_diff["X"] + y_diff = xy_diff["Y"] + hDiv = x_diff + y_diff + + w = hDiv / tub["rA"] + return u, v, w + + def _slice_corner(array, fc, iy1, iy2, ix1, ix2): left = np.minimum(ix1, ix2) righ = np.maximum(ix1, ix2) @@ -263,11 +299,115 @@ def buffer_y_withface(s, face, lm, rm, tp): return ybuffer +def buffer_x_periodic(s, lm, rm): + shape = list(s.shape) + shape[-1] += lm + rm + xbuffer = np.zeros(shape) + xbuffer[..., lm : shape[-1] - rm] = s + if lm > 0: + xbuffer[..., :lm] = s[..., -lm:] + if rm > 0: + xbuffer[..., -rm:] = s[..., :rm] + return xbuffer + + +def buffer_y_periodic(s, lm, rm): + shape = list(s.shape) + shape[-2] += lm + rm + ybuffer = np.zeros(shape) + ybuffer[..., lm : shape[-2] - rm, :] = s + if lm > 0: + ybuffer[..., :lm, :] = s[..., -lm:, :] + if rm > 0: + ybuffer[..., -rm:, :] = s[..., :rm, :] + return ybuffer + + +def buffer_z_nearest_withoutface(s, lm, rm): + shape = list(s.shape) + shape[-3] += lm + rm + zbuffer = np.zeros(shape) + zbuffer[..., lm : shape[-3] - rm, :, :] = s + if lm > 0: + zbuffer[..., :lm, :, :] = s[..., :1, :, :] + if rm > 0: + zbuffer[..., -rm:, :, :] = s[..., -1:, :, :] + return zbuffer + + +def _slope_ratio(Rjm, Rj, Rjp, u_sign, not_z=1): + """Calculate slope ratio for flux limiter.""" + cr_max = 1e6 # doesn't matter + cr = np.zeros_like(u_sign) + pos = not_z * u_sign > 0 + neg = not_z * u_sign <= 0 + cr[pos] = Rjm[pos] + cr[neg] = Rjp[neg] + zero_divide = np.abs(Rj) * cr_max <= np.abs(cr) + cr[zero_divide] = np.sign(cr[zero_divide]) * np.sign(u_sign[zero_divide]) * cr_max + cr[~zero_divide] = cr[~zero_divide] / Rj[~zero_divide] + return cr + + +def superbee_fluxlimiter(cr): + return np.maximum(0.0, np.maximum(np.minimum(1.0, 2 * cr), np.minimum(2.0, cr))) + + +def second_order_flux_limiter_x(s_center, u_cfl): + xbuffer = buffer_x_periodic(s_center, 2, 2) + deltas = np.nan_to_num(np.diff(xbuffer, axis=-1), 0) + Rjp = deltas[..., 2:] + Rj = deltas[..., 1:-1] + Rjm = deltas[..., :-2] + + cr = _slope_ratio(Rjm, Rj, Rjp, u_cfl) + limiter = superbee_fluxlimiter(cr) + swall = ( + np.nan_to_num(xbuffer[..., 1:-2] + xbuffer[..., 2:-1]) * 0.5 + - np.sign(u_cfl) * ((1 - limiter) + u_cfl * limiter) * Rj * 0.5 + ) + return swall + + +def second_order_flux_limiter_y(s_center, u_cfl): + ybuffer = buffer_y_periodic(s_center, 2, 2) + deltas = np.nan_to_num(np.diff(ybuffer, axis=-2), 0) + Rjp = deltas[..., 2:, :] + Rj = deltas[..., 1:-1, :] + Rjm = deltas[..., :-2, :] + + cr = _slope_ratio(Rjm, Rj, Rjp, u_cfl) + limiter = superbee_fluxlimiter(cr) + swall = ( + np.nan_to_num(ybuffer[..., 1:-2, :] + ybuffer[..., 2:-1, :]) * 0.5 + - np.sign(u_cfl) * ((1 - limiter) + u_cfl * limiter) * Rj * 0.5 + ) + return swall + + +def second_order_flux_limiter_z_withoutface(s_center, u_cfl): + zbuffer = buffer_z_nearest_withoutface(s_center, 2, 1) + deltas = np.nan_to_num(np.diff(zbuffer, axis=-3), 0) + Rjp = deltas[..., 2:, :, :] + Rj = deltas[..., 1:-1, :, :] + Rjm = deltas[..., :-2, :, :] + + cr = _slope_ratio(Rjm, Rj, Rjp, u_cfl, not_z=-1) + limiter = superbee_fluxlimiter(cr) + # swall = np.nan_to_num(zbuffer[...,1:-2,:,:]+zbuffer[...,2:-1,:,:])*0.5- np.sign(u_cfl)**Rj*0.5 + swall = ( + np.nan_to_num(zbuffer[..., 1:-2, :, :] + zbuffer[..., 2:-1, :, :]) * 0.5 + + np.sign(u_cfl) * ((1 - limiter) + u_cfl * limiter) * Rj * 0.5 + ) + return swall + + def third_order_upwind_z(s, w): """Get interpolated salinity in the vertical. for more info, see https://mitgcm.readthedocs.io/en/latest/algorithm/adv-schemes.html#third-order-upwind-bias-advection + This function currently only work when there is no through surface flux. Parameters ---------- diff --git a/seaduck/lagrangian.py b/seaduck/lagrangian.py index 353cfe93..3eeb2d26 100644 --- a/seaduck/lagrangian.py +++ b/seaduck/lagrangian.py @@ -68,10 +68,12 @@ class Particle(Position): If transport is true, pass in names of the volume/mass transport across cell wall in m^3/3 else, just pass something that is in m/s - dont_fly: Boolean + free_surface: string Sometimes there is non-zero vertical velocity at sea surface. - dont_fly = True set that to zero. - An error may occur depends on the situation if set otherwise. + free_surface = "noflux" set that to zero. + free_surface = "kick_back" move particles trying to cross back + to the middle of the cell. + There could be errors if neither is used. save_raw: Boolean Whether to record the analytical history of all particles in an unstructured list. @@ -93,7 +95,7 @@ def __init__( uname="UVELMASS", vname="VVELMASS", wname="WVELMASS", - dont_fly=True, + free_surface="noflux", save_raw=False, transport=False, callback=None, @@ -101,13 +103,11 @@ def __init__( **kwarg, ): Position.__init__(self) - self.from_latlon(**kwarg) - if self.ocedata.readiness["Zl"] and kwarg.get("z") is not None: - self.rel.update(self.ocedata._find_rel_vl_lin(self.dep)) + if "bool_array" in kwarg.keys(): + self.from_bool_array(**kwarg) else: - (self.izl_lin, self.rzl_lin, self.dzl_lin, self.bzl_lin) = ( - None for i in range(4) - ) + self.from_latlon(**kwarg) + try: self.px, self.py = self.get_px_py() except AttributeError: @@ -139,19 +139,20 @@ def __init__( self.transport = transport if self.transport: try: - self.ocedata["Vol"] - except KeyError: + assert isinstance(self.ocedata["Vol"], np.ndarray) + except (AssertionError, KeyError): self.ocedata._add_missing_vol(as_numpy=True) # whether or not setting the w at the surface # just to prevent particles taking off - self.dont_fly = dont_fly - if dont_fly: + self.free_surface = free_surface + if free_surface == "noflux": if wname is not None: logging.warning( - "Setting the surface velocity to zero. " "Dataset modified. " + "Setting the surface velocity to zero. " + "Dataset might be modified. " ) - self.ocedata[wname].loc[{"Zl": 0}] = 0 + self.ocedata[self.wname].loc[{"Zl": 0}] = 0 self.too_large = self.ocedata.too_large self.max_iteration = max_iteration @@ -625,7 +626,15 @@ def _cross_cell_wall_index(self, tend): type2 = tend == 4 tiz[type2] += 1 type3 = tend == 5 - tiz[type3] -= 1 + if self.free_surface != "kick_back": + tiz[type3] -= 1 + else: + going2fly = self.izl_lin == 1 + type3a = np.logical_and(going2fly, type3) + type3b = np.logical_and(~going2fly, type3) + tiz[type3b] -= 1 + self.dep[type3a] = self.bzl_lin[type3a] + self.dzl_lin[type3a] / 2 + self.izl_lin = tiz def _cross_cell_wall_read(self): @@ -867,7 +876,7 @@ def to_list_of_time( to_return = [] for i, tl in enumerate(stops): timestr = str(np.datetime64(round(tl), "s")) - # logging.info(timestr) + logging.info(timestr) print(timestr) if self.save_raw: # save the very start of everything. diff --git a/seaduck/lagrangian_budget.py b/seaduck/lagrangian_budget.py index d2a92eab..28eb391d 100644 --- a/seaduck/lagrangian_budget.py +++ b/seaduck/lagrangian_budget.py @@ -34,8 +34,11 @@ def read_from_ds(particle_ds, oce): temp.tp = temp.ocedata.tp # it = np.array(particle_ds.it) + if oce.tp.typ in ["LLC"]: + temp.face = np.array(particle_ds.fc).astype(int) + else: + temp.face = None izl = np.array(particle_ds.iz) - fc = np.array(particle_ds.fc) iy = np.array(particle_ds.iy) ix = np.array(particle_ds.ix) rzl = np.array(particle_ds.rz) @@ -45,7 +48,6 @@ def read_from_ds(particle_ds, oce): # temp.it = it .astype(int) temp.izl_lin = izl.astype(int) temp.iz = (izl - 1).astype(int) - temp.face = fc.astype(int) temp.iy = iy.astype(int) temp.ix = ix.astype(int) temp.rzl_lin = rzl @@ -148,8 +150,6 @@ def tres_update(tres0, temp, first, last, fraction_first, fraction_last): tres = tres0 * fracs tres[temp.vs > 6] = 0.0 - # mask = np.logical_and(tres==0, temp.vs<7) - # assert (tres[temp.vs<7]>0).all(), (tres0[mask],fracs[mask], fracs_a[mask], np.where(mask)) return tres @@ -182,40 +182,46 @@ def deepcopy_inds(temp): iz = copy.deepcopy(temp.izl_lin) iy = copy.deepcopy(temp.iy) ix = copy.deepcopy(temp.ix) - face = copy.deepcopy(temp.face) - # assert (iz>=1).all(),iz - return iz, face, iy, ix + if temp.face is not None: + face = copy.deepcopy(temp.face) + return iz, face, iy, ix + else: + return iz, iy, ix def wall_index(inds, iwall, tp): iw = iwall // 2 - iz, face, iy, ix = copy.deepcopy(inds) - # assert (iz>=1).all(),iz + iz = copy.deepcopy(inds[0]) + iy = copy.deepcopy(inds[-2]) + ix = copy.deepcopy(inds[-1]) - ind = copy.deepcopy(np.array([face, iy, ix])) + ind = np.array(inds[1:]) old_ind = copy.deepcopy(ind) naive_move = np.array([MOVE_DIC[i] for i in iwall], dtype=int).T + ind[-2] += naive_move[0] # iy iy += naive_move[0] + ind[-1] += naive_move[1] # ix ix += naive_move[1] - ind = np.array([face, iy, ix]) - illegal = tp.check_illegal(ind, cuvwg="C") - redo = np.array(np.where(illegal)).T - for num, loc in enumerate(redo): - j = loc[0] - ind = (iw[j],) + tuple(old_ind[:, j]) - new_ind = ind_tend_uv(ind, tp) - iw[j], face[j], iy[j], ix[j] = new_ind iz[iwall == 4] += 1 iz -= 1 - return np.array([iw, iz, face, iy, ix]).astype(int) + if tp.typ in ["LLC"]: + face = copy.deepcopy(inds[-3]) + illegal = tp.check_illegal(ind, cuvwg="G") + redo = np.array(np.where(illegal)).T + for num, loc in enumerate(redo): + j = loc[0] + ind = (iw[j],) + tuple(old_ind[:, j]) + new_ind = ind_tend_uv(ind, tp) + iw[j], face[j], iy[j], ix[j] = new_ind + + return np.array([iw, iz, face, iy, ix]).astype(int) + else: + return np.array([iw, iz, ind[-2], ind[-1]]).astype(int) def redo_index(pt): - # assert (pt.izl_lin>=1).all() inds = deepcopy_inds(pt) - iz, face, iy, ix = inds - # assert (iz>=1).all(),iz tendf, tf, tendb, tb = pseudo_motion(pt) funderflow = np.where(tendf == 6) @@ -224,20 +230,19 @@ def redo_index(pt): tendb[bunderflow] = 0 vf = wall_index(inds, tendf, pt.ocedata.tp) vb = wall_index(inds, tendb, pt.ocedata.tp) - # vf[:,funderflow] = vb[:,funderflow] - # vb[:,bunderflow] = vf[:,bunderflow] tim = tf - tb frac = -tb / tim assert (~np.isnan(tim)).any(), [ i[np.isnan(tim)] for i in [pt.rx, pt.ry, pt.rzl_lin - 1 / 2] ] - assert (tim != 0).all(), [i[tim == 0] for i in [pt.rx, pt.ry, pt.rzl_lin - 1 / 2]] - at_corner = np.where(tim == 0) - frac[at_corner] = 1 + # assert (tim != 0).all(), [i[tim == 0] for i in [pt.rx, pt.ry, pt.rzl_lin - 1 / 2]] + # at_corner = np.where(tim == 0) + # frac[at_corner] = 1 + frac = np.nan_to_num(frac, nan=1) return vf, vb, frac -def find_ind_frac_tres(neo, oce, region_names=False, region_polys=None): +def find_ind_frac_tres(neo, oce, region_names=False, region_polys=None, by_type=True): temp = read_from_ds(neo, oce) temp.shapes = list(temp.shapes) if region_names: @@ -246,23 +251,32 @@ def find_ind_frac_tres(neo, oce, region_names=False, region_polys=None): mask = parallelpointinpolygon(temp.lon, temp.lat, reg) # mask = np.where(mask)[0] masks.append(mask) + first, last, neither = first_last_neither(np.array(temp.shapes)) + if temp.face is not None: + num_ind = 5 + else: + num_ind = 4 - ind1 = np.zeros((5, temp.N), "int16") - ind2 = np.ones((5, temp.N), "int16") - frac = np.ones(temp.N) + if by_type: + ind1 = np.zeros((num_ind, temp.N), "int16") + ind2 = np.ones((num_ind, temp.N), "int16") + frac = np.ones(temp.N) - # ind1[:, wrong_ind] = lookup[:, lookup_ind] + # ind1[:, wrong_ind] = lookup[:, lookup_ind] - neithers = temp.subset(neither) - neither_inds = deepcopy_inds(neithers) - iwalls = which_wall(neithers) - ind1[:, neither] = wall_index(neither_inds, iwalls, temp.ocedata.tp) + if len(neither > 0): + neithers = temp.subset(neither) + neither_inds = deepcopy_inds(neithers) + iwalls = which_wall(neithers) + ind1[:, neither] = wall_index(neither_inds, iwalls, temp.ocedata.tp) - firsts = temp.subset(first) - lasts = temp.subset(last) - ind1[:, first], ind2[:, first], frac[first] = redo_index(firsts) - ind1[:, last], ind2[:, last], frac[last] = redo_index(lasts) + firsts = temp.subset(first) + lasts = temp.subset(last) + ind1[:, first], ind2[:, first], frac[first] = redo_index(firsts) + ind1[:, last], ind2[:, last], frac[last] = redo_index(lasts) + else: + ind1, ind2, frac = redo_index(temp) tres = tres_fraction(temp, first, last, frac[first], frac[last]) if region_names: @@ -285,7 +299,8 @@ def flatten(lstoflst, shapes=None): def particle2xarray(p): shapes = [len(i) for i in p.xxlist] # it = flatten(p.itlist,shapes = shapes) - fc = flatten(p.fclist, shapes=shapes) + if p.face is not None: + fc = flatten(p.fclist, shapes=shapes) iy = flatten(p.iylist, shapes=shapes) iz = flatten(p.izlist, shapes=shapes) ix = flatten(p.ixlist, shapes=shapes) @@ -308,7 +323,6 @@ def particle2xarray(p): coords=dict(shapes=(["shapes"], shapes), nprof=(["nprof"], np.arange(len(xx)))), data_vars=dict( # it = (['nprof'],it), - fc=(["nprof"], fc), iy=(["nprof"], iy), iz=(["nprof"], iz), ix=(["nprof"], ix), @@ -328,10 +342,14 @@ def particle2xarray(p): vs=(["nprof"], vs), ), ) + if p.face is not None: + ds["fc"] = xr.DataArray(fc, dims="nprof") return ds -def dump_to_zarr(neo, oce, filename, region_names=False, region_polys=None): +def dump_to_zarr( + neo, oce, filename, region_names=False, region_polys=None, preserve_checks=False +): if region_names: (ind1, ind2, frac, masks, tres, last, first) = find_ind_frac_tres( neo, oce, region_names=region_names, region_polys=region_polys @@ -339,7 +357,11 @@ def dump_to_zarr(neo, oce, filename, region_names=False, region_polys=None): else: ind1, ind2, frac, tres, last, first = find_ind_frac_tres(neo, oce) - neo["five"] = xr.DataArray(["iw", "iz", "face", "iy", "ix"], dims="five") + if oce.tp.typ in ["LLC"]: + neo["face"] = neo["fc"].astype("int16") + neo["five"] = xr.DataArray(["iw", "iz", "face", "iy", "ix"], dims="five") + else: + neo["five"] = xr.DataArray(["iw", "iz", "iy", "ix"], dims="five") if region_names: for ir, reg in enumerate(region_names): neo[reg] = xr.DataArray(masks[ir].astype(bool), dims="nprof") @@ -360,22 +382,29 @@ def dump_to_zarr(neo, oce, filename, region_names=False, region_polys=None): # neo['last'] = xr.DataArray(last.astype('int64'), dims = 'shapes') # neo['first'] = xr.DataArray(first.astype('int64'), dims = 'shapes') - neo["face"] = neo["fc"].astype("int16") neo["ix"] = neo["ix"].astype("int16") neo["iy"] = neo["iy"].astype("int16") neo["iz"] = neo["iz"].astype("int16") neo["vs"] = neo["vs"].astype("int16") - neo = neo.drop_vars(["rx", "ry", "rz", "uu", "vv", "ww", "du", "dv", "dw", "fc"]) + if not preserve_checks: + neo = neo.drop_vars(["rx", "ry", "rz", "uu", "vv", "ww", "du", "dv", "dw"]) + if "fc" in neo.data_vars: + neo = neo.drop_vars(["fc"]) neo.to_zarr(filename, mode="w") zarr.consolidate_metadata(filename) -def store_lists(pt, name, region_names=False, region_polys=None): +def store_lists(pt, name, region_names=False, region_polys=None, **kwarg): neo = particle2xarray(pt) dump_to_zarr( - neo, pt.ocedata, name, region_names=region_names, region_polys=region_polys + neo, + pt.ocedata, + name, + region_names=region_names, + region_polys=region_polys, + **kwarg ) @@ -387,8 +416,131 @@ def prefetch_scalar(ds_slc, scalar_names): return prefetch -def prefetch_vector(ds_slc, xname="sxprime", yname="syprime", zname="szprime"): - return np.array(tuple(np.array(ds_slc[i]) for i in [xname, yname, zname])) +def read_wall_list(neo, tp, prefetch=None, scalar=True): + if "face" not in neo.data_vars: + ind = (neo.iz - 1, neo.iy, neo.ix) + deep_ind = (neo.iz, neo.iy, neo.ix) + right_ind = tuple( + [neo.iz - 1] + + list( + tp.ind_tend_vec( + (neo.iy, neo.ix), np.ones(len(neo.nprof)) * 3, cuvwg="G" + ) + ) + ) + up_ind = tuple( + [neo.iz - 1] + + list(tp.ind_tend_vec((neo.iy, neo.ix), np.ones(len(neo.nprof)) * 0)) + ) + uarray, varray, warray = prefetch + ur = uarray[right_ind] + vr = varray[up_ind] + else: + ind = (neo.iz - 1, neo.face, neo.iy, neo.ix) + deep_ind = (neo.iz, neo.face, neo.iy, neo.ix) + right_ind = tuple( + [neo.iz - 1] + + list( + tp.ind_tend_vec((neo.face, neo.iy, neo.ix), np.ones(len(neo.nprof)) * 3) + ) + ) + up_ind = tuple( + [neo.iz - 1] + + list( + tp.ind_tend_vec((neo.face, neo.iy, neo.ix), np.ones(len(neo.nprof)) * 0) + ) + ) + uarray, varray, warray = prefetch + + ur_temp = np.nan_to_num(uarray[right_ind]) + vr_temp = np.nan_to_num(varray[right_ind]) + uu_temp = np.nan_to_num(uarray[up_ind]) + vu_temp = np.nan_to_num(varray[up_ind]) + right_faces = np.vstack([ind[1], right_ind[1]]).T + up_faces = np.vstack([ind[1], up_ind[1]]).T + ufromu, ufromv, _, _ = tp.four_matrix_for_uv(right_faces) + _, _, vfromu, vfromv = tp.four_matrix_for_uv(up_faces) + if scalar: + ufromu, ufromv, vfromu, vfromv = ( + np.abs(i) for i in [ufromu, ufromv, vfromu, vfromv] + ) + + ur = ur_temp * ufromu[:, 1] + vr_temp * ufromv[:, 1] + vr = uu_temp * vfromu[:, 1] + vu_temp * vfromv[:, 1] + ul = uarray[ind] + vl = varray[ind] + wr = warray[ind] + wl = warray[deep_ind] + return np.array([np.nan_to_num(i) for i in (ul, ur, vl, vr, wl, wr)]) + + +def crude_convergence(u_list): + conv = np.array(u_list).T * np.array([1, -1, 1, -1, 1, -1]) + conv = np.sum(conv, axis=-1) + return conv + + +def check_particle_data_compat( + xrpt, + xrslc, + tp, + use_tracer_name=None, + wall_names=("sx", "sy", "sz"), + conv_name="divus", + debug=False, +): + if "iz" not in xrpt.data_vars: + raise NotImplementedError( + "This functionality only support 3D simulation at the moment." + ) + if isinstance(use_tracer_name, str): + wall_names = tuple(use_tracer_name + i for i in ["x", "y", "z"]) + conv_name = "divu" + use_tracer_name + elif use_tracer_name is not None: + raise ValueError("use_tracer_name has to be a string.") + prefetch = [] + for var in wall_names: + prefetch.append(np.array(xrslc[var])) + c_list = read_wall_list(xrpt, tp, prefetch) + + ul, ur = _uleftright_from_udu(xrpt.uu, xrpt.du, xrpt.rx) + vl, vr = _uleftright_from_udu(xrpt.vv, xrpt.dv, xrpt.ry) + wl, wr = _uleftright_from_udu(xrpt.ww, xrpt.dw, xrpt.rz - 0.5) + u_list = np.array([np.array(i) for i in [ul, ur, vl, vr, wl, wr]]) + + flux_list = c_list * u_list + lagrangian_conv = crude_convergence(flux_list) + + if "face" in xrpt.data_vars: + ind = (xrpt.iz - 1, xrpt.face, xrpt.iy, xrpt.ix) + else: + ind = (xrpt.iz - 1, xrpt.iy, xrpt.ix) + eulerian_conv = np.array(xrslc[conv_name])[ind] + if debug: + extra = (u_list, c_list, lagrangian_conv, eulerian_conv) + else: + extra = None + return np.allclose(lagrangian_conv, eulerian_conv), extra + + +def prefetch_vector( + ds_slc, xname="sxprime", yname="syprime", zname="szprime", same_size=True +): + if same_size: + return np.array(tuple(np.array(ds_slc[i]) for i in [xname, yname, zname])) + else: + xx = np.array(ds_slc[xname]) + yy = np.array(ds_slc[yname]) + zz = np.array(ds_slc[zname]) + shape = (3,) + tuple( + int(np.max([ar.shape[j] for ar in [xx, yy, zz]])) + for j in range(len(xx.shape)) + ) + larger = np.empty(shape) + larger[(0,) + tuple(slice(i) for i in xx.shape)] = xx + larger[(1,) + tuple(slice(i) for i in yy.shape)] = yy + larger[(2,) + tuple(slice(i) for i in zz.shape)] = zz + return larger def read_prefetched_scalar(ind, scalar_names, prefetch): @@ -406,12 +558,12 @@ def lhs_contribution(t, scalar_dic, last, lhs_name="lhs"): return correction -def contr_p_relaxed(deltas, tres, step_dic, termlist, p=1): +def contr_p_relaxed(deltas, tres, step_dic, termlist, p=1, error_prefix=""): nds = len(deltas) # if len(wrong_ind)>0: # if wrong_ind[-1] == len(deltas): # wrong_ind = wrong_ind[:-1] - dic = {"error": np.zeros(nds)} + dic = {error_prefix + "error": np.zeros(nds)} # dic['error'][wrong_ind] = deltas[wrong_ind] # deltas[wrong_ind] = 0 # tres[wrong_ind] = 0 @@ -430,6 +582,6 @@ def contr_p_relaxed(deltas, tres, step_dic, termlist, p=1): dic[var] = step_dic[var][:-1] * tres + ratio * disparity total += dic[var] final_correction = deltas - total - assert np.allclose(final_correction, 0) - dic["error"] += final_correction + # assert np.allclose(final_correction, 0) + dic[error_prefix + "error"] += final_correction return dic diff --git a/seaduck/ocedata.py b/seaduck/ocedata.py index c8752001..497ba7dd 100644 --- a/seaduck/ocedata.py +++ b/seaduck/ocedata.py @@ -162,8 +162,22 @@ class OceData: """ def __init__(self, data, alias=None, memory_limit=1e7): - self._ds = data - self.tp = Topology(data) + self._ds = data.transpose( + "time", + "time_midp", + "time_outer", + "Z", + "Zl", + "Zp1", + "face", + "Y", + "Yp1", + "X", + "Xp1", + ..., + missing_dims="ignore", + ) + self.tp = Topology(self._ds) if alias is None: self.alias = NO_ALIAS elif alias == "auto": @@ -184,6 +198,20 @@ def __init__(self, data, alias=None, memory_limit=1e7): f"use add_missing_variables or set_alias to create {missing}," "then call OceData.grid2array." ) + if self.readiness["Zl"]: + # make a more consistent vector definition + with_zl = [i for i in self._ds.data_vars if "Zl" in self._ds[i].dims] + if len(with_zl) > 0: + without_zl = [i for i in self._ds.data_vars if i not in with_zl] + bottom_buffer = xr.zeros_like(self._ds[with_zl].isel(Zl=slice(1))) + bottom_buffer["Zl"] = [self.Zl[-1]] + neods = xr.merge( + [ + xr.concat([self._ds[with_zl], bottom_buffer], dim="Zl"), + self._ds[without_zl], + ] + ) + self._ds = neods def __setitem__(self, key, item): if isinstance(item, xr.DataArray): @@ -275,16 +303,18 @@ def _add_missing_cs_sn(self): assert self["SN"] is not None assert self["CS"] is not None except (AttributeError, AssertionError): - cs, sn = missing_cs_sn(self) + cs, sn = missing_cs_sn(self._ds) self["CS"] = cs self["SN"] = sn def _add_missing_vol(self, as_numpy=False): if self.readiness["Zl"]: vol = self._ds["drF"] * self._ds["rA"] + if "HFacC" in self._ds.data_vars: + vol *= self._ds["HFacC"] else: vol = self._ds["rA"] - + vol = vol.fillna(0) if as_numpy: self["Vol"] = np.array(vol) else: @@ -351,19 +381,21 @@ def _vgrid2array(self): self.dZ = np.array(self["dZ"], dtype="float32") except KeyError: self.dZ = np.diff(self.Z) - self.dZ = np.append(self.dZ, self.dZ[-1]) + self.dZ = np.append(self.dZ, self.dZ[-1]).astype("float32") def _vlgrid2array(self): """Extract the vertical staggered point grid data into numpy arrays.""" - self.Zl = np.array(self["Zl"], dtype="float32") + if "Zp1" in self._ds.variables: + self.Zl = np.array(self["Zp1"], dtype="float32") + else: + self.Zl = np.zeros(len(self["Zl"]) + 1) + self.Zl[:-1] = np.array(self._ds["Zl"]) + self.Zl[-1] = 2 * self.Zl[-2] - self.Zl[-3] + self.Zl = self.Zl.astype("float32") try: self.dZl = np.array(self["dZl"], dtype="float32") except KeyError: - if "Zp1" in self._ds.variables: - self.dZl = np.diff(np.array(self["Zp1"])) - else: - self.dZl = np.diff(self.Zl) - self.dZl = np.append(self.dZl, self.dZl[-1]) + self.dZl = np.diff(self.Zl).astype("float32") # special treatment for dZl # self.dZl = np.roll(self.dZl,1) @@ -373,7 +405,8 @@ def _tgrid2array(self): """Extract the temporal grid data into numpy arrays.""" self.t_base = 0 self.ts = np.array(self["time"]) - self.ts = (self.ts).astype(float) / 1e9 + if self["time"].dtype != "float": + self.ts = (self.ts).astype(float) / 1e9 try: self.time_midp = np.array(self["time_midp"]) self.time_midp = (self.time_midp).astype(float) / 1e9 diff --git a/seaduck/topology.py b/seaduck/topology.py index d2182e33..02e7a8fa 100644 --- a/seaduck/topology.py +++ b/seaduck/topology.py @@ -279,7 +279,7 @@ def __init__(self, od, typ=None): if "XG" in od.variables: self.g_shape = od["XG"].shape else: - self.g_shape = None + self.g_shape = self.h_shape try: self.itmax = len(od["time"]) - 1 except (KeyError, TypeError): @@ -402,7 +402,7 @@ def ind_tend(self, ind, tend, cuvwg="C", **kwarg): else: raise ValueError("The type of grid point should be among C,U,V,G") elif self.typ == "x_periodic": - to_return = _box_ind_tend(ind, tend, self.iymax, self.ixmax, **kwarg) + to_return = _x_per_ind_tend(ind, tend, self.iymax, self.ixmax, **kwarg) elif self.typ == "box": to_return = _box_ind_tend(ind, tend, self.iymax, self.ixmax, **kwarg) else: @@ -486,8 +486,7 @@ def check_illegal(self, ind, cuvwg="C"): result = True return result else: # for numpy ndarray - result = np.zeros_like(ind[0]) - result = False # make it cleaner + result = False for i, z in enumerate(ind): max_pos = the_shape[i] result = np.logical_or( diff --git a/seaduck/utils.py b/seaduck/utils.py index 682de80a..9e0d956b 100644 --- a/seaduck/utils.py +++ b/seaduck/utils.py @@ -671,13 +671,10 @@ def missing_cs_sn(ds, return_xr=False): cs[-1], sn[-1] = find_cs_sn(yc[-2], xc[-2], yc[-1], xc[-1]) cs[1:-1], sn[1:-1] = find_cs_sn(yc[:-2], xc[:-2], yc[2:], xc[2:]) if return_xr: - ds["CS"] = ds["XC"] - ds["CS"].values = cs + cs = xr.DataArray(cs, dims=ds["XC"].dims) + sn = xr.DataArray(sn, dims=ds["XC"].dims) - ds["SN"] = ds["XC"] - ds["SN"].values = sn - - return ds + return cs, sn else: return cs, sn diff --git a/tests/test_eulerian.py b/tests/test_eulerian.py index 55f990c1..83f8f14e 100644 --- a/tests/test_eulerian.py +++ b/tests/test_eulerian.py @@ -222,9 +222,8 @@ def test_partial_flatten(): assert flattened[0].shape == (3, 1) -@pytest.mark.parametrize("ds", ["ecco"], indirect=True) @pytest.mark.parametrize("od", ["ecco"], indirect=True) -def test_wvel_quant_deepest(ds, od): +def test_wvel_quant_deepest(od): ind = (11, 75, 73) face, iy, ix = ind @@ -238,16 +237,15 @@ def test_wvel_quant_deepest(ds, od): assert vert_p.iy[0] == iy, "horizontal index does not match" seaduck_ans = vert_p.interpolate("WVELMASS1", sd.lagrangian.wknw) - wvel = interp1d(od.Zl, ds.WVELMASS1[:, face, iy, ix]) + wvel = interp1d(od.Zl, od._ds.WVELMASS1[:, face, iy, ix]) scipy_ans = wvel(z) assert np.allclose(scipy_ans, seaduck_ans) -@pytest.mark.parametrize("ds", ["ecco"], indirect=True) @pytest.mark.parametrize("od", ["ecco"], indirect=True) @pytest.mark.parametrize("seed", list(range(5))) -def test_wvel_quant_random_place(ds, od, seed): +def test_wvel_quant_random_place(od, seed): np.random.seed(seed) z = np.random.uniform(od.Zl[-1], 0, 50) x = np.random.uniform(-180, 180, 1) * np.ones_like(z) @@ -260,15 +258,14 @@ def test_wvel_quant_random_place(ds, od, seed): iy = vert_p.iy[0] ix = vert_p.ix[0] - wvel = interp1d(od.Zl, ds.WVELMASS1[:, face, iy, ix]) + wvel = interp1d(od.Zl, od._ds.WVELMASS1[:, face, iy, ix]) scipy_ans = wvel(z) assert np.allclose(scipy_ans, seaduck_ans) -@pytest.mark.parametrize("ds", ["ecco"], indirect=True) @pytest.mark.parametrize("od", ["ecco"], indirect=True) -def test_dw_quant_deepest(ds, od): +def test_dw_quant_deepest(od): ind = (11, 75, 73) face, iy, ix = ind @@ -284,7 +281,7 @@ def test_dw_quant_deepest(ds, od): # dw is a stepwise function. small_offset = 1e-12 - dw = -np.diff(np.array(ds.WVELMASS1[:, face, iy, ix])) + dw = -np.diff(np.array(od._ds.WVELMASS1[:, face, iy, ix])) zinterp = [0] dwinterp = [dw[0]] for i, zl in enumerate(od.Zl[1:-1]): @@ -300,10 +297,9 @@ def test_dw_quant_deepest(ds, od): assert np.allclose(scipy_ans, seaduck_ans) -@pytest.mark.parametrize("ds", ["ecco"], indirect=True) @pytest.mark.parametrize("od", ["ecco"], indirect=True) @pytest.mark.parametrize("seed", list(range(7, 12))) -def test_dw_quant_random(ds, od, seed): +def test_dw_quant_random(od, seed): np.random.seed(seed) z = np.random.uniform(od.Zl[-1], 0, 50) x = np.random.uniform(-180, 180, 1) * np.ones_like(z) @@ -318,7 +314,7 @@ def test_dw_quant_random(ds, od, seed): # dw is a stepwise function. small_offset = 1e-12 - dw = -np.diff(np.array(ds.WVELMASS1[:, face, iy, ix])) + dw = -np.diff(np.array(od._ds.WVELMASS1[:, face, iy, ix])) zinterp = [0] dwinterp = [dw[0]] for i, zl in enumerate(od.Zl[1:-1]): diff --git a/tests/test_eulerian_budget.py b/tests/test_eulerian_budget.py new file mode 100644 index 00000000..3d6290ea --- /dev/null +++ b/tests/test_eulerian_budget.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +from seaduck.eulerian_budget import ( + buffer_x_periodic, + buffer_y_periodic, + buffer_z_nearest_withoutface, + second_order_flux_limiter_x, + second_order_flux_limiter_y, + second_order_flux_limiter_z_withoutface, + superbee_fluxlimiter, +) + + +@pytest.fixture +def random_4d(): + np.random.seed(401) + return np.random.random((3, 4, 5, 4)) + + +def test_superbee(): + cr = np.array([-1, 0.25, 0.5, 1, 2, 100]) + res = superbee_fluxlimiter(cr) + assert np.allclose(res, np.array([0.0, 0.5, 1.0, 1.0, 2.0, 2.0])) + + +@pytest.mark.parametrize(["lm", "rm"], [(0, 2), (2, 1), (1, 0)]) +def test_buffer_x_periodic(random_4d, lm, rm): + buffer = buffer_x_periodic(random_4d, lm, rm) + if rm != 0: + assert np.allclose(buffer[..., -1], random_4d[..., rm - 1]) + if lm != 0: + assert np.allclose(buffer[..., 0], random_4d[..., -lm]) + + +@pytest.mark.parametrize(["lm", "rm"], [(0, 2), (2, 1), (1, 0)]) +def test_buffer_y_periodic(random_4d, lm, rm): + buffer = buffer_y_periodic(random_4d, lm, rm) + if rm != 0: + assert np.allclose(buffer[..., -1, :], random_4d[..., rm - 1, :]) + if lm != 0: + assert np.allclose(buffer[..., 0, :], random_4d[..., -lm, :]) + + +@pytest.mark.parametrize(["lm", "rm"], [(0, 2), (2, 1), (1, 0)]) +def test_buffer_z_nearest_withoutface(random_4d, lm, rm): + buffer = buffer_z_nearest_withoutface(random_4d, lm, rm) + if rm != 0: + assert np.allclose(buffer[..., -1, :, :], random_4d[..., -1, :, :]) + if lm != 0: + assert np.allclose(buffer[..., 0, :, :], random_4d[..., 0, :, :]) + + +# The next three tests are not very quantitative, +# but they has been tested with real datasets, +# still there needs to be some improvements. +def test_second_order_flux_limiter_x(random_4d): + np.random.seed(401) + u_cfl = np.random.random((3, 4, 5, 5)) * 2 - 1 + ans = second_order_flux_limiter_x(random_4d, u_cfl) + assert ans.shape == u_cfl.shape + assert ans.dtype == "float64" + + +def test_second_order_flux_limiter_y(random_4d): + np.random.seed(401) + v_cfl = np.random.random((3, 4, 6, 4)) * 2 - 1 + ans = second_order_flux_limiter_y(random_4d, v_cfl) + assert ans.shape == v_cfl.shape + assert ans.dtype == "float64" + + +def test_second_order_flux_limiter_z_withoutface(random_4d): + np.random.seed(401) + w_cfl = np.random.random((3, 4, 5, 4)) * 2 - 1 + ans = second_order_flux_limiter_z_withoutface(random_4d, w_cfl) + assert ans.shape == w_cfl.shape + assert ans.dtype == "float64" diff --git a/tests/test_lagrangian.py b/tests/test_lagrangian.py index e8c84d00..50165dec 100644 --- a/tests/test_lagrangian.py +++ b/tests/test_lagrangian.py @@ -52,6 +52,16 @@ def ecco_p(): return sd.Particle(x=x, y=y, z=zz, t=t, data=od, transport=True) +@pytest.fixture +def kick_back_p(): + x = np.array([-38.594593, -37.512672, -36.42936, -34.08329, -35.06443]) + y = np.array([-77.95619, -77.97306, -77.98856, -77.25903, -76.86412]) + z = np.ones_like(x) * (-0.01) + t = utils.convert_time(start_time) * np.ones_like(x) + od = sd.OceData(utils.get_dataset("ecco")) + return sd.Particle(x=x, y=y, z=z, t=t, data=od, free_surface="kick_back") + + normal_stops = np.linspace(t[0], tf, 5) @@ -129,11 +139,10 @@ def test_multidim_uvw_array(ecco_p): assert ecco_p.uarray.shape[0] == 2 -@pytest.mark.parametrize("od", ["ecco"], indirect=True) -def test_update_w_array(ecco_p, od): - od["u0"] = od["UVELMASS"].isel(time=0) - od["v0"] = od["VVELMASS"].isel(time=0) - od["w0"] = od["WVELMASS"].isel(time=0) +def test_update_w_array(ecco_p): + ecco_p.ocedata._ds["u0"] = ecco_p.ocedata["UVELMASS"].isel(time=0) + ecco_p.ocedata._ds["v0"] = ecco_p.ocedata["VVELMASS"].isel(time=0) + ecco_p.ocedata._ds["w0"] = ecco_p.ocedata["WVELMASS"].isel(time=0) delattr(ecco_p, "warray") ecco_p.uname = "u0" ecco_p.vname = "v0" @@ -261,3 +270,9 @@ def test_get_u_du_quant(seed, od): assert np.allclose(ushould, u, atol=1e-18) assert np.allclose(vshould, v, atol=1e-18) assert np.allclose(wshould, w, atol=1e-18) + + +def test_kick_back(kick_back_p): + tend = kick_back_p.analytical_step(-1e10) + kick_back_p.cross_cell_wall(tend) + assert np.isclose(kick_back_p.rzl_lin, 0.5).any() diff --git a/tests/test_lagrangian_budget.py b/tests/test_lagrangian_budget.py index ed1b30f2..fc97c492 100644 --- a/tests/test_lagrangian_budget.py +++ b/tests/test_lagrangian_budget.py @@ -16,7 +16,7 @@ @pytest.fixture def custom_pt(): - x = np.linspace(-50, -15, 200) + x = np.linspace(-50, -15, 5) y = np.ones_like(x) * 52.0 z = np.ones_like(x) * (-9) t = np.ones_like(x) @@ -37,6 +37,24 @@ def custom_pt(): return pt +@pytest.fixture +def curv_pt(): + od = sd.OceData(utils.get_dataset("curv")) + curv_p = sd.Particle( + y=np.array([70.5]), + x=np.array([-14.0]), + z=np.array([-10.0]), + t=np.array([od.ts[0]]), + data=od, + uname="U", + vname="V", + wname="W", + save_raw=True, + ) + curv_p.to_next_stop(od.ts[0] + 1e4) + return curv_p + + @pytest.fixture def region_info(): GULF = np.array( @@ -110,6 +128,16 @@ def test_ind_frac_find(custom_pt, od): assert (tres >= 0).all() +@pytest.mark.parametrize("od", ["curv"], indirect=True) +def test_ind_frac_find_noface(curv_pt, od): + particle_datasets = particle2xarray(curv_pt) + tub = od + ind1, ind2, frac, tres, last, first = find_ind_frac_tres(particle_datasets, tub) + assert ind1.shape[0] == 4 + assert (frac != 1).any() + assert (tres >= 0).all() + + def test_store_lists(custom_pt): store_lists(custom_pt, "PleaseIgnore_dump.zarr") diff --git a/tests/test_ocedata.py b/tests/test_ocedata.py index e596b2d8..6b0141b2 100644 --- a/tests/test_ocedata.py +++ b/tests/test_ocedata.py @@ -11,10 +11,10 @@ def incomplete_data(request): ds = utils.get_dataset("curv") if request.param == "drop_YG": od = sd.OceData(ds) - ds_out = ds.drop_vars(["YG"]) + ds_out = od._ds.drop_vars(["YG"]) od._add_missing_cs_sn() - ds_out["CS"] = xr.DataArray(od["CS"], dims=ds["XC"].dims) - ds_out["SN"] = xr.DataArray(od["SN"], dims=ds["XC"].dims) + ds_out["CS"] = xr.DataArray(od["CS"], dims=od._ds["XC"].dims) + ds_out["SN"] = xr.DataArray(od["SN"], dims=od._ds["XC"].dims) return ds_out elif request.param == "drop_dyG": return ds.drop_vars(["dyG"]) diff --git a/tests/test_utils.py b/tests/test_utils.py index d0ec0096..3ec1038a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -59,8 +59,8 @@ def test_none_in(): @pytest.mark.parametrize("ds", ["curv"], indirect=True) def test_cs_sn(ds): - sd.utils.missing_cs_sn(ds, return_xr=True) - assert isinstance(ds["CS"], xr.DataArray) + cs, sn = sd.utils.missing_cs_sn(ds, return_xr=True) + assert isinstance(cs, xr.DataArray) def test_covert_time():