Skip to content

Commit

Permalink
refc: replace assert statements
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Jul 15, 2024
1 parent 8d31f3e commit d5dcc25
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 20 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ select = [
"C", # flake8-comprehensions
"B", # flake8-bugbear
"UP",
"S101", # do not use asserts
"NPY201", # numpy 2.* compatibility check
]
ignore = [
Expand Down
1 change: 1 addition & 0 deletions tests/ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ extend-ignore = [
"E402", # module-level import not at top of file
"E731", # lambda assignment
"F841", # unused local variable
"S101", # asserts allowed in tests
]
6 changes: 4 additions & 2 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ def parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]:
Index into xyz axis (0,1,2) and position along that axis.
"""
xyz_filtered = {k: v for k, v in xyz.items() if v is not None}
assert len(xyz_filtered) == 1, "exactly one kwarg in [x,y,z] must be specified."
if len(xyz_filtered) != 1:
raise ValueError("exactly one kwarg in [x,y,z] must be specified.")
axis_label, position = list(xyz_filtered.items())[0]
axis = "xyz".index(axis_label)
return axis, position
Expand Down Expand Up @@ -3216,7 +3217,8 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
)
vjp_dict_geo = geo.compute_derivatives(geo_info)
grad_vjp_values = list(vjp_dict_geo.values())
assert len(grad_vjp_values) == 1, "Got multiple gradients for single geometry field."
if len(grad_vjp_values) != 1:
raise AssertionError("Got multiple gradients for single geometry field.")
grad_vjps[field_path] = grad_vjp_values[0]

return grad_vjps
Expand Down
8 changes: 4 additions & 4 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,9 +1373,8 @@ def _surface_area(self, bounds: Bound) -> float:
def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
"""Compute the adjoint derivatives for this object."""

assert derivative_info.paths == [
("vertices",)
], "only support derivative wrt 'PolySlab.vertices'."
if derivative_info.paths != [("vertices",)]:
raise ValueError("only support derivative wrt 'PolySlab.vertices'.")

vjp_vertices = self.compute_derivative_vertices(derivative_info=derivative_info)

Expand All @@ -1397,7 +1396,8 @@ def compute_derivative_vertices(self, derivative_info: DerivativeInfo) -> Traced
edge_centers_axis = self.center_axis * np.ones(num_vertices)
edge_centers_xyz = self.unpop_axis_vect(edge_centers_axis, edge_centers_plane)

assert edge_centers_xyz.shape == (num_vertices, 3), "something bad happened"
if edge_centers_xyz.shape != (num_vertices, 3):
raise AssertionError("something bad happened")

# compute the E and D fields at the edge centers
E_der_at_edges = self.der_at_centers(
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int
# for each dimension
for axis, (pt_min, pt_max) in enumerate(zip(pts_min, pts_max)):
bound_coords = np.array(boundaries.to_list[axis])
assert pt_min <= pt_max, "min point was greater than max point"
if pt_min > pt_max:
raise AssertionError("min point was greater than max point")

# index of smallest coord greater than pt_max
inds_gt_pt_max = np.where(bound_coords > pt_max)[0]
Expand Down
7 changes: 6 additions & 1 deletion tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,12 @@ def get_dls(geom: Geometry, axis: Axis, num_dls: int) -> List[float]:
def snap_to_grid(geom: Geometry, axis: Axis) -> Geometry:
"""Snap a 2D material to the Yee grid."""
center = get_bounds(geom, axis)[0]
assert get_bounds(geom, axis)[0] == get_bounds(geom, axis)[1]
if get_bounds(geom, axis)[0] != get_bounds(geom, axis)[1]:
raise AssertionError(
"Unexpected error encountered while processing 2D material. "
"The upper and lower bounds of the geometry in the normal direction are not equal. "
"If you encounter this error, please create an issue in the Tidy3D github repository."
)
snapped_center = snap_coordinate_to_grid(self.grid, center, axis)
return geom._update_from_bounds(bounds=(snapped_center, snapped_center), axis=axis)

Expand Down
9 changes: 6 additions & 3 deletions tidy3d/plugins/dispersion/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def _unpack_coeffs(coeffs):
Tuple[np.ndarray[complex], np.ndarray[complex]]
"a" and "c" poles for the PoleResidue model.
"""
assert len(coeffs) % 4 == 0, "len(coeffs) must be multiple of 4."
if len(coeffs) % 4 != 0:
raise ValueError(f"len(coeffs) must be multiple of 4, got {len(coeffs)=}.")

a_real = coeffs[0::4]
a_imag = coeffs[1::4]
Expand Down Expand Up @@ -692,8 +693,10 @@ def from_file(cls, fname: str, **loadtxt_kwargs) -> DispersionFitter:
A :class:`DispersionFitter` instance.
"""
data = np.loadtxt(fname, **loadtxt_kwargs)
assert len(data.shape) == 2, "data must contain [wavelength, ndata, kdata] in columns"
assert data.shape[-1] in (2, 3), "data must have either 2 or 3 rows (if k data)"
if len(data.shape) != 2:
raise ValueError("data must contain [wavelength, ndata, kdata] in columns")
if data.shape[-1] not in (2, 3):
raise ValueError("data must have either 2 or 3 rows (if k data)")
if data.shape[-1] == 2:
wvl_um, n_data = data.T
k_data = None
Expand Down
6 changes: 5 additions & 1 deletion tidy3d/plugins/microwave/custom_path_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def compute_integral(
elif isinstance(em_field, FieldTimeData):
return TimeDataArray(data=result.data, coords=result.coords)
else:
assert isinstance(em_field, ModeSolverData)
if not isinstance(em_field, ModeSolverData):
raise TypeError(
f"Unsupported 'em_field' type: {type(em_field)}. "
"Expected one of 'FieldData', 'FieldTimeData', 'ModeSolverData'."
)
return FreqModeDataArray(data=result.data, coords=result.coords)

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion tidy3d/plugins/microwave/path_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def compute_integral(self, scalar_field: EMScalarFieldType) -> IntegralResultTyp
elif isinstance(scalar_field, ScalarFieldTimeDataArray):
return TimeDataArray(data=result.data, coords=result.coords)
else:
assert isinstance(scalar_field, ScalarModeFieldDataArray)
if not isinstance(scalar_field, ScalarModeFieldDataArray):
raise TypeError(
f"Unsupported 'scalar_field' type: {type(scalar_field)}. "
"Expected one of 'ScalarFieldDataArray', 'ScalarFieldTimeDataArray', "
"'ScalarModeFieldDataArray'."
)
return FreqModeDataArray(data=result.data, coords=result.coords)

def _get_field_along_path(self, scalar_field: EMScalarFieldType) -> EMScalarFieldType:
Expand Down
6 changes: 5 additions & 1 deletion tidy3d/plugins/smatrix/component_modelers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,11 @@ def ab_to_s(a_matrix: DataArray, b_matrix: DataArray) -> DataArray:
"""Get the scattering matrix given the power wave matrices."""

# move the input and output port dimensions to the end, for ease of matrix operations
assert a_matrix.dims == b_matrix.dims
if a_matrix.dims != b_matrix.dims:
raise ValueError(
"'a_matrix' and 'b_matrix' must have the same number of dimensions, "
f"got {a_matrix.dims=} and {b_matrix.dims=}"
)
dims = list(a_matrix.dims)
dims.append(dims.pop(dims.index("port_out")))
dims.append(dims.pop(dims.index("port_in")))
Expand Down
9 changes: 6 additions & 3 deletions tidy3d/plugins/smatrix/ports/rectangular_lumped.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,12 @@ def compute_current(self, sim_data: SimulationData) -> FreqDataArray:
inject_center = h_cap_coords_along_injection[orth_index]
# Some sanity checks, tangent H field coordinates should be directly above
# and below the coordinates of the resistive sheet
assert orth_index > 0
assert inject_center < h_coords_along_injection[orth_index]
assert h_coords_along_injection[orth_index - 1] < inject_center
if orth_index <= 0:
raise AssertionError
if inject_center >= h_coords_along_injection[orth_index]:
raise AssertionError
if h_coords_along_injection[orth_index - 1] >= inject_center:
raise AssertionError
# Distance between the h1_field and h2_field, a single cell size
dcap = h_coords_along_injection[orth_index] - h_coords_along_injection[orth_index - 1]

Expand Down
3 changes: 2 additions & 1 deletion tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,8 @@ def postprocess_adj(
# todo: handle multi-frequency, move to a property?
frequencies = {src.source_time.freq0 for src in sim_data_adj.simulation.sources}
frequencies = list(frequencies)
assert len(frequencies) == 1, "Multiple adjoint freqs found"
if len(frequencies) != 1:
raise RuntimeError("Multiple adjoint frequencies found.")
freq_adj = frequencies[0]

eps_in = np.mean(structure.medium.eps_model(freq_adj))
Expand Down
6 changes: 4 additions & 2 deletions tidy3d/web/api/material_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def submit(cls, fitter: DispersionFitter, options: FitterOptions) -> MaterialFit
options: FitterOptions
fitter options
"""
assert fitter
assert options
if not isinstance(fitter, DispersionFitter):
raise TypeError(f"fitter must be an instance of 'DispersionFitter', got {type(fitter)}")
if not isinstance(options, FitterOptions):
raise TypeError(f"options must be an instance of 'FitterOptions', got {type(options)}")
data = np.asarray(list(zip(fitter.wvl_um, fitter.n_data, fitter.k_data)))
with tempfile.NamedTemporaryFile(suffix=".csv") as temp:
np.savetxt(temp, data, delimiter=",", header="Wavelength,n,k")
Expand Down

0 comments on commit d5dcc25

Please sign in to comment.