From 175a406948930dd97477ed0a97d139248f8110b2 Mon Sep 17 00:00:00 2001 From: tylerflex Date: Wed, 29 Jun 2022 11:36:02 -0700 Subject: [PATCH] refactored type annotation to be able to resuse union types --- tests/test_boundaries.py | 2 +- tests/test_grid.py | 2 +- tidy3d/components/base.py | 5 +---- tidy3d/components/data.py | 4 ++-- tidy3d/components/geometry.py | 13 ++++++++----- tidy3d/components/grid/grid.py | 4 ++-- tidy3d/components/monitor.py | 22 +++++++++------------- tidy3d/components/simulation.py | 6 +++--- tidy3d/components/source.py | 21 +++++++++------------ tidy3d/components/structure.py | 18 ++++++++++-------- tidy3d/components/types.py | 10 ++++++++++ 11 files changed, 56 insertions(+), 51 deletions(-) diff --git a/tests/test_boundaries.py b/tests/test_boundaries.py index be23e5e70..5cc2a9296 100644 --- a/tests/test_boundaries.py +++ b/tests/test_boundaries.py @@ -8,7 +8,7 @@ from tidy3d.components.boundary import Periodic, PECBoundary, PMCBoundary, BlochBoundary from tidy3d.components.boundary import PML, StablePML, Absorber from tidy3d.components.source import GaussianPulse, PlaneWave, PointDipole -from tidy3d.components.base import TYPE_TAG_STR +from tidy3d.components.types import TYPE_TAG_STR from tidy3d.log import SetupError from .utils import assert_log_level diff --git a/tests/test_grid.py b/tests/test_grid.py index b8a1f335f..12c60ddb2 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -4,7 +4,7 @@ import tidy3d as td from tidy3d.components.grid import Coords, FieldGrid, YeeGrid, Grid -from tidy3d.components.base import TYPE_TAG_STR +from tidy3d.components.types import TYPE_TAG_STR def test_coords(): diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index da240a335..1b8f72a49 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -9,15 +9,12 @@ import numpy as np from pydantic.fields import ModelField -from .types import ComplexNumber, Literal +from .types import ComplexNumber, Literal, TYPE_TAG_STR from ..log import FileError # default indentation (# spaces) in files INDENT = 4 -# type tag default name -TYPE_TAG_STR = "type" - class Tidy3dBaseModel(pydantic.BaseModel): """Base pydantic model that all Tidy3d components inherit from. diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index 34b757658..1dd2706ba 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -10,8 +10,8 @@ import h5py import pydantic as pd -from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis -from .base import Tidy3dBaseModel, TYPE_TAG_STR +from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis, TYPE_TAG_STR +from .base import Tidy3dBaseModel from .simulation import Simulation from .boundary import Symmetry, BlochBoundary from .monitor import Monitor diff --git a/tidy3d/components/geometry.py b/tidy3d/components/geometry.py index ac1e480fe..fda51fdcc 100644 --- a/tidy3d/components/geometry.py +++ b/tidy3d/components/geometry.py @@ -12,7 +12,7 @@ from .base import Tidy3dBaseModel, cached_property from .types import Bound, Size, Coordinate, Axis, Coordinate2D, ArrayLike -from .types import Vertices, Ax, Shapely +from .types import Vertices, Ax, Shapely, annotate_type from .viz import add_ax_if_none, equal_aspect from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR from .viz import PlotParams, plot_params_geometry, polygon_patch @@ -2089,15 +2089,14 @@ def normalize(v): return np.swapaxes(vs_orig + shift_total, -2, -1), parallel_shift, (shift_x, shift_y) -# geometries that can be used to define structures. -GeometryFields = (Box, Sphere, Cylinder, PolySlab) -GeometryType = Union[GeometryFields] +# types of geometry including just one Geometry object (exluding group) +SingleGeometryType = Union[Box, Sphere, Cylinder, PolySlab] class GeometryGroup(Geometry): """A collection of Geometry objects that can be called as a single geometry object.""" - geometries: Tuple[GeometryType, ...] = pydantic.Field( + geometries: Tuple[annotate_type(SingleGeometryType), ...] = pydantic.Field( ..., title="Geometries", description="Tuple of geometries in a single grouping. " @@ -2176,3 +2175,7 @@ def inside(self, x, y, z) -> bool: individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries) return functools.reduce(lambda a, b: a | b, individual_insides) + + +# geometries usable to define a structure +GeometryType = Union[SingleGeometryType, GeometryGroup] diff --git a/tidy3d/components/grid/grid.py b/tidy3d/components/grid/grid.py index d2c921b83..7e8376144 100644 --- a/tidy3d/components/grid/grid.py +++ b/tidy3d/components/grid/grid.py @@ -5,8 +5,8 @@ import numpy as np import pydantic as pd -from ..base import Tidy3dBaseModel, TYPE_TAG_STR, cached_property -from ..types import ArrayLike, Axis, Array +from ..base import Tidy3dBaseModel, cached_property +from ..types import ArrayLike, Axis, Array, TYPE_TAG_STR from ..geometry import Box from ...log import SetupError diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 0a6cf75c3..4e2ede23e 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -1,7 +1,6 @@ """Objects that define how data is recorded from simulation.""" from abc import ABC, abstractmethod from typing import Union, Tuple -from typing_extensions import Annotated import pydantic import numpy as np @@ -9,7 +8,7 @@ from .types import Literal, Ax, EMField, ArrayLike, Bound, FreqArray from .geometry import Box from .validators import assert_plane -from .base import cached_property, TYPE_TAG_STR +from .base import cached_property from .mode import ModeSpec from .viz import PlotParams, plot_params_monitor, ARROW_COLOR_MONITOR, ARROW_ALPHA from ..log import SetupError @@ -463,15 +462,12 @@ def storage_size(self, num_cells: int, tmesh: int) -> int: # types of monitors that are accepted by simulation -MonitorType = Annotated[ - Union[ - FieldMonitor, - FieldTimeMonitor, - PermittivityMonitor, - FluxMonitor, - FluxTimeMonitor, - ModeMonitor, - ModeFieldMonitor, - ], - pydantic.Field(discriminator=TYPE_TAG_STR), +MonitorType = Union[ + FieldMonitor, + FieldTimeMonitor, + PermittivityMonitor, + FluxMonitor, + FluxTimeMonitor, + ModeMonitor, + ModeFieldMonitor, ] diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 4d0e58a0f..339df8c07 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -14,7 +14,7 @@ from .validators import assert_unique_names, assert_objects_in_sim_bounds from .validators import validate_mode_objects_symmetry from .geometry import Box -from .types import Ax, Shapely, FreqBound, GridSize, Axis +from .types import Ax, Shapely, FreqBound, GridSize, Axis, annotate_type from .grid import Coords1D, Grid, Coords, GridSpec, UniformGrid from .medium import Medium, MediumType, AbstractMedium, PECMedium from .boundary import BoundarySpec, Symmetry, BlochBoundary, PECBoundary, PMCBoundary @@ -139,7 +139,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods "simulation material properties in regions of spatial overlap.", ) - sources: Tuple[SourceType, ...] = pydantic.Field( + sources: Tuple[annotate_type(SourceType), ...] = pydantic.Field( (), title="Sources", description="Tuple of electric current sources injecting fields into the simulation.", @@ -151,7 +151,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods description="Specification of boundary conditions along each dimension.", ) - monitors: Tuple[MonitorType, ...] = pydantic.Field( + monitors: Tuple[annotate_type(MonitorType), ...] = pydantic.Field( (), title="Monitors", description="Tuple of monitors in the simulation. " diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index 90eb72ee1..a366bc1c1 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -4,11 +4,11 @@ from typing import Union, Tuple import logging -from typing_extensions import Annotated, Literal +from typing_extensions import Literal import pydantic import numpy as np -from .base import Tidy3dBaseModel, cached_property, TYPE_TAG_STR +from .base import Tidy3dBaseModel, cached_property from .types import Direction, Polarization, Ax, FreqBound, ArrayLike, Axis, Bound from .validators import assert_plane, validate_name_str from .geometry import Box @@ -673,14 +673,11 @@ class TFSF(AngledFieldSource, VolumeSource): # sources allowed in Simulation.sources -SourceType = Annotated[ - Union[ - UniformCurrentSource, - PointDipole, - GaussianBeam, - AstigmaticGaussianBeam, - ModeSource, - PlaneWave, - ], - pydantic.Field(discriminator=TYPE_TAG_STR), +SourceType = Union[ + UniformCurrentSource, + PointDipole, + GaussianBeam, + AstigmaticGaussianBeam, + ModeSource, + PlaneWave, ] diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 674a08da1..2456035c3 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -1,14 +1,11 @@ -# pylint:disable=too-many-arguments """Defines Geometric objects with Medium properties.""" -from typing import Union - import pydantic from .base import Tidy3dBaseModel from .validators import validate_name_str -from .geometry import GeometryType, Box, GeometryGroup # pylint: disable=unused-import -from .medium import MediumType, Medium # pylint: disable=unused-import -from .types import Ax +from .geometry import GeometryType +from .medium import MediumType +from .types import Ax, TYPE_TAG_STR from .viz import add_ax_if_none, equal_aspect @@ -19,19 +16,24 @@ class Structure(Tidy3dBaseModel): Example ------- + >>> from tidy3d import Box, Medium >>> box = Box(center=(0,0,1), size=(2, 2, 2)) >>> glass = Medium(permittivity=3.9) >>> struct = Structure(geometry=box, medium=glass, name='glass_box') """ - geometry: Union[GeometryType, GeometryGroup] = pydantic.Field( - ..., title="Geometry", description="Defines geometric properties of the structure." + geometry: GeometryType = pydantic.Field( + ..., + title="Geometry", + description="Defines geometric properties of the structure.", + discriminator=TYPE_TAG_STR, ) medium: MediumType = pydantic.Field( ..., title="Medium", description="Defines the electromagnetic properties of the structure's medium.", + discriminator=TYPE_TAG_STR, ) name: str = pydantic.Field(None, title="Name", description="Optional name for the structure.") diff --git a/tidy3d/components/types.py b/tidy3d/components/types.py index ae7670eef..ad7c0d106 100644 --- a/tidy3d/components/types.py +++ b/tidy3d/components/types.py @@ -7,6 +7,7 @@ from typing import Literal except ImportError: from typing_extensions import Literal +from typing_extensions import Annotated import pydantic import numpy as np @@ -18,6 +19,15 @@ # warnings.filterwarnings(action="error", category=np.ComplexWarning) """ Numpy Arrays """ +# type tag default name +TYPE_TAG_STR = "type" + + +def annotate_type(UnionType): # pylint:disable=invalid-name + """Annotated union type using TYPE_TAG_STR as discriminator.""" + return Annotated[UnionType, pydantic.Field(discriminator=TYPE_TAG_STR)] + + # generic numpy array Numpy = np.ndarray