Skip to content

Commit

Permalink
refactored type annotation to be able to resuse union types
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Jun 30, 2022
1 parent 53816bb commit f32bf08
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 51 deletions.
2 changes: 1 addition & 1 deletion tests/test_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tidy3d.components.boundary import Periodic, PECBoundary, PMCBoundary, BlochBoundary
from tidy3d.components.boundary import PML, StablePML, Absorber
from tidy3d.components.source import GaussianPulse, PlaneWave, PointDipole
from tidy3d.components.base import TYPE_TAG_STR
from tidy3d.components.types import TYPE_TAG_STR
from tidy3d.log import SetupError
from .utils import assert_log_level

Expand Down
2 changes: 1 addition & 1 deletion tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tidy3d as td
from tidy3d.components.grid import Coords, FieldGrid, YeeGrid, Grid
from tidy3d.components.base import TYPE_TAG_STR
from tidy3d.components.types import TYPE_TAG_STR


def test_coords():
Expand Down
5 changes: 1 addition & 4 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,12 @@
import numpy as np
from pydantic.fields import ModelField

from .types import ComplexNumber, Literal
from .types import ComplexNumber, Literal, TYPE_TAG_STR
from ..log import FileError

# default indentation (# spaces) in files
INDENT = 4

# type tag default name
TYPE_TAG_STR = "type"


class Tidy3dBaseModel(pydantic.BaseModel):
"""Base pydantic model that all Tidy3d components inherit from.
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import h5py
import pydantic as pd

from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis
from .base import Tidy3dBaseModel, TYPE_TAG_STR
from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis, TYPE_TAG_STR
from .base import Tidy3dBaseModel
from .simulation import Simulation
from .boundary import Symmetry, BlochBoundary
from .monitor import Monitor
Expand Down
13 changes: 8 additions & 5 deletions tidy3d/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .base import Tidy3dBaseModel, cached_property
from .types import Bound, Size, Coordinate, Axis, Coordinate2D, ArrayLike
from .types import Vertices, Ax, Shapely
from .types import Vertices, Ax, Shapely, annotate_type
from .viz import add_ax_if_none, equal_aspect
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR
from .viz import PlotParams, plot_params_geometry, polygon_patch
Expand Down Expand Up @@ -2089,15 +2089,14 @@ def normalize(v):
return np.swapaxes(vs_orig + shift_total, -2, -1), parallel_shift, (shift_x, shift_y)


# geometries that can be used to define structures.
GeometryFields = (Box, Sphere, Cylinder, PolySlab)
GeometryType = Union[GeometryFields]
# types of geometry including just one Geometry object (exluding group)
SingleGeometryType = Union[Box, Sphere, Cylinder, PolySlab]


class GeometryGroup(Geometry):
"""A collection of Geometry objects that can be called as a single geometry object."""

geometries: Tuple[GeometryType, ...] = pydantic.Field(
geometries: Tuple[annotate_type(SingleGeometryType), ...] = pydantic.Field(
...,
title="Geometries",
description="Tuple of geometries in a single grouping. "
Expand Down Expand Up @@ -2176,3 +2175,7 @@ def inside(self, x, y, z) -> bool:
individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries)

return functools.reduce(lambda a, b: a | b, individual_insides)


# geometries usable to define a structure
GeometryType = Union[SingleGeometryType, GeometryGroup]
4 changes: 2 additions & 2 deletions tidy3d/components/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import pydantic as pd

from ..base import Tidy3dBaseModel, TYPE_TAG_STR, cached_property
from ..types import ArrayLike, Axis, Array
from ..base import Tidy3dBaseModel, cached_property
from ..types import ArrayLike, Axis, Array, TYPE_TAG_STR
from ..geometry import Box

from ...log import SetupError
Expand Down
22 changes: 9 additions & 13 deletions tidy3d/components/monitor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Objects that define how data is recorded from simulation."""
from abc import ABC, abstractmethod
from typing import Union, Tuple
from typing_extensions import Annotated

import pydantic
import numpy as np

from .types import Literal, Ax, EMField, ArrayLike, Bound, FreqArray
from .geometry import Box
from .validators import assert_plane
from .base import cached_property, TYPE_TAG_STR
from .base import cached_property
from .mode import ModeSpec
from .viz import PlotParams, plot_params_monitor, ARROW_COLOR_MONITOR, ARROW_ALPHA
from ..log import SetupError
Expand Down Expand Up @@ -463,15 +462,12 @@ def storage_size(self, num_cells: int, tmesh: int) -> int:


# types of monitors that are accepted by simulation
MonitorType = Annotated[
Union[
FieldMonitor,
FieldTimeMonitor,
PermittivityMonitor,
FluxMonitor,
FluxTimeMonitor,
ModeMonitor,
ModeFieldMonitor,
],
pydantic.Field(discriminator=TYPE_TAG_STR),
MonitorType = Union[
FieldMonitor,
FieldTimeMonitor,
PermittivityMonitor,
FluxMonitor,
FluxTimeMonitor,
ModeMonitor,
ModeFieldMonitor,
]
6 changes: 3 additions & 3 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .validators import assert_unique_names, assert_objects_in_sim_bounds
from .validators import validate_mode_objects_symmetry
from .geometry import Box
from .types import Ax, Shapely, FreqBound, GridSize, Axis
from .types import Ax, Shapely, FreqBound, GridSize, Axis, annotate_type
from .grid import Coords1D, Grid, Coords, GridSpec, UniformGrid
from .medium import Medium, MediumType, AbstractMedium, PECMedium
from .boundary import BoundarySpec, Symmetry, BlochBoundary, PECBoundary, PMCBoundary
Expand Down Expand Up @@ -139,7 +139,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods
"simulation material properties in regions of spatial overlap.",
)

sources: Tuple[SourceType, ...] = pydantic.Field(
sources: Tuple[annotate_type(SourceType), ...] = pydantic.Field(
(),
title="Sources",
description="Tuple of electric current sources injecting fields into the simulation.",
Expand All @@ -151,7 +151,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods
description="Specification of boundary conditions along each dimension.",
)

monitors: Tuple[MonitorType, ...] = pydantic.Field(
monitors: Tuple[annotate_type(MonitorType), ...] = pydantic.Field(
(),
title="Monitors",
description="Tuple of monitors in the simulation. "
Expand Down
21 changes: 9 additions & 12 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Union, Tuple
import logging

from typing_extensions import Annotated, Literal
from typing_extensions import Literal
import pydantic
import numpy as np

from .base import Tidy3dBaseModel, cached_property, TYPE_TAG_STR
from .base import Tidy3dBaseModel, cached_property
from .types import Direction, Polarization, Ax, FreqBound, ArrayLike, Axis, Bound
from .validators import assert_plane, validate_name_str
from .geometry import Box
Expand Down Expand Up @@ -673,14 +673,11 @@ class TFSF(AngledFieldSource, VolumeSource):


# sources allowed in Simulation.sources
SourceType = Annotated[
Union[
UniformCurrentSource,
PointDipole,
GaussianBeam,
AstigmaticGaussianBeam,
ModeSource,
PlaneWave,
],
pydantic.Field(discriminator=TYPE_TAG_STR),
SourceType = Union[
UniformCurrentSource,
PointDipole,
GaussianBeam,
AstigmaticGaussianBeam,
ModeSource,
PlaneWave,
]
18 changes: 10 additions & 8 deletions tidy3d/components/structure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# pylint:disable=too-many-arguments
"""Defines Geometric objects with Medium properties."""
from typing import Union

import pydantic

from .base import Tidy3dBaseModel
from .validators import validate_name_str
from .geometry import GeometryType, Box, GeometryGroup # pylint: disable=unused-import
from .medium import MediumType, Medium # pylint: disable=unused-import
from .types import Ax
from .geometry import GeometryType
from .medium import MediumType
from .types import Ax, TYPE_TAG_STR
from .viz import add_ax_if_none, equal_aspect


Expand All @@ -19,19 +16,24 @@ class Structure(Tidy3dBaseModel):
Example
-------
>>> from tidy3d import Box, Medium
>>> box = Box(center=(0,0,1), size=(2, 2, 2))
>>> glass = Medium(permittivity=3.9)
>>> struct = Structure(geometry=box, medium=glass, name='glass_box')
"""

geometry: Union[GeometryType, GeometryGroup] = pydantic.Field(
..., title="Geometry", description="Defines geometric properties of the structure."
geometry: GeometryType = pydantic.Field(
...,
title="Geometry",
description="Defines geometric properties of the structure.",
discriminator=TYPE_TAG_STR,
)

medium: MediumType = pydantic.Field(
...,
title="Medium",
description="Defines the electromagnetic properties of the structure's medium.",
discriminator=TYPE_TAG_STR,
)

name: str = pydantic.Field(None, title="Name", description="Optional name for the structure.")
Expand Down
10 changes: 10 additions & 0 deletions tidy3d/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Literal
except ImportError:
from typing_extensions import Literal
from typing_extensions import Annotated

import pydantic
import numpy as np
Expand All @@ -18,6 +19,15 @@
# warnings.filterwarnings(action="error", category=np.ComplexWarning)
""" Numpy Arrays """

# type tag default name
TYPE_TAG_STR = "type"


def annotate_type(UnionType): # pylint:disable=invalid-name
"""Annotated union type using TYPE_TAG_STR as discriminator."""
return Annotated[UnionType, pydantic.Field(discriminator=TYPE_TAG_STR)]


# generic numpy array
Numpy = np.ndarray

Expand Down

0 comments on commit f32bf08

Please sign in to comment.