Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored type annotation to be able to resuse union types #429

Merged
merged 1 commit into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tidy3d.components.boundary import Periodic, PECBoundary, PMCBoundary, BlochBoundary
from tidy3d.components.boundary import PML, StablePML, Absorber
from tidy3d.components.source import GaussianPulse, PlaneWave, PointDipole
from tidy3d.components.base import TYPE_TAG_STR
from tidy3d.components.types import TYPE_TAG_STR
from tidy3d.log import SetupError
from .utils import assert_log_level

Expand Down
2 changes: 1 addition & 1 deletion tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tidy3d as td
from tidy3d.components.grid import Coords, FieldGrid, YeeGrid, Grid
from tidy3d.components.base import TYPE_TAG_STR
from tidy3d.components.types import TYPE_TAG_STR


def test_coords():
Expand Down
5 changes: 1 addition & 4 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,12 @@
import numpy as np
from pydantic.fields import ModelField

from .types import ComplexNumber, Literal
from .types import ComplexNumber, Literal, TYPE_TAG_STR
from ..log import FileError

# default indentation (# spaces) in files
INDENT = 4

# type tag default name
TYPE_TAG_STR = "type"


class Tidy3dBaseModel(pydantic.BaseModel):
"""Base pydantic model that all Tidy3d components inherit from.
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import h5py
import pydantic as pd

from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis
from .base import Tidy3dBaseModel, TYPE_TAG_STR
from .types import Numpy, Direction, Array, Literal, Ax, Coordinate, Axis, TYPE_TAG_STR
from .base import Tidy3dBaseModel
from .simulation import Simulation
from .boundary import Symmetry, BlochBoundary
from .monitor import Monitor
Expand Down
13 changes: 8 additions & 5 deletions tidy3d/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .base import Tidy3dBaseModel, cached_property
from .types import Bound, Size, Coordinate, Axis, Coordinate2D, ArrayLike
from .types import Vertices, Ax, Shapely
from .types import Vertices, Ax, Shapely, annotate_type
from .viz import add_ax_if_none, equal_aspect
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR
from .viz import PlotParams, plot_params_geometry, polygon_patch
Expand Down Expand Up @@ -2089,15 +2089,14 @@ def normalize(v):
return np.swapaxes(vs_orig + shift_total, -2, -1), parallel_shift, (shift_x, shift_y)


# geometries that can be used to define structures.
GeometryFields = (Box, Sphere, Cylinder, PolySlab)
GeometryType = Union[GeometryFields]
# types of geometry including just one Geometry object (exluding group)
SingleGeometryType = Union[Box, Sphere, Cylinder, PolySlab]


class GeometryGroup(Geometry):
"""A collection of Geometry objects that can be called as a single geometry object."""

geometries: Tuple[GeometryType, ...] = pydantic.Field(
geometries: Tuple[annotate_type(SingleGeometryType), ...] = pydantic.Field(
...,
title="Geometries",
description="Tuple of geometries in a single grouping. "
Expand Down Expand Up @@ -2176,3 +2175,7 @@ def inside(self, x, y, z) -> bool:
individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries)

return functools.reduce(lambda a, b: a | b, individual_insides)


# geometries usable to define a structure
GeometryType = Union[SingleGeometryType, GeometryGroup]
4 changes: 2 additions & 2 deletions tidy3d/components/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import pydantic as pd

from ..base import Tidy3dBaseModel, TYPE_TAG_STR, cached_property
from ..types import ArrayLike, Axis, Array
from ..base import Tidy3dBaseModel, cached_property
from ..types import ArrayLike, Axis, Array, TYPE_TAG_STR
from ..geometry import Box

from ...log import SetupError
Expand Down
22 changes: 9 additions & 13 deletions tidy3d/components/monitor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Objects that define how data is recorded from simulation."""
from abc import ABC, abstractmethod
from typing import Union, Tuple
from typing_extensions import Annotated

import pydantic
import numpy as np

from .types import Literal, Ax, EMField, ArrayLike, Bound, FreqArray
from .geometry import Box
from .validators import assert_plane
from .base import cached_property, TYPE_TAG_STR
from .base import cached_property
from .mode import ModeSpec
from .viz import PlotParams, plot_params_monitor, ARROW_COLOR_MONITOR, ARROW_ALPHA
from ..log import SetupError
Expand Down Expand Up @@ -463,15 +462,12 @@ def storage_size(self, num_cells: int, tmesh: int) -> int:


# types of monitors that are accepted by simulation
MonitorType = Annotated[
Union[
FieldMonitor,
FieldTimeMonitor,
PermittivityMonitor,
FluxMonitor,
FluxTimeMonitor,
ModeMonitor,
ModeFieldMonitor,
],
pydantic.Field(discriminator=TYPE_TAG_STR),
MonitorType = Union[
FieldMonitor,
FieldTimeMonitor,
PermittivityMonitor,
FluxMonitor,
FluxTimeMonitor,
ModeMonitor,
ModeFieldMonitor,
]
6 changes: 3 additions & 3 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .validators import assert_unique_names, assert_objects_in_sim_bounds
from .validators import validate_mode_objects_symmetry
from .geometry import Box
from .types import Ax, Shapely, FreqBound, GridSize, Axis
from .types import Ax, Shapely, FreqBound, GridSize, Axis, annotate_type
from .grid import Coords1D, Grid, Coords, GridSpec, UniformGrid
from .medium import Medium, MediumType, AbstractMedium, PECMedium
from .boundary import BoundarySpec, Symmetry, BlochBoundary, PECBoundary, PMCBoundary
Expand Down Expand Up @@ -139,7 +139,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods
"simulation material properties in regions of spatial overlap.",
)

sources: Tuple[SourceType, ...] = pydantic.Field(
sources: Tuple[annotate_type(SourceType), ...] = pydantic.Field(
(),
title="Sources",
description="Tuple of electric current sources injecting fields into the simulation.",
Expand All @@ -151,7 +151,7 @@ class Simulation(Box): # pylint:disable=too-many-public-methods
description="Specification of boundary conditions along each dimension.",
)

monitors: Tuple[MonitorType, ...] = pydantic.Field(
monitors: Tuple[annotate_type(MonitorType), ...] = pydantic.Field(
(),
title="Monitors",
description="Tuple of monitors in the simulation. "
Expand Down
21 changes: 9 additions & 12 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Union, Tuple
import logging

from typing_extensions import Annotated, Literal
from typing_extensions import Literal
import pydantic
import numpy as np

from .base import Tidy3dBaseModel, cached_property, TYPE_TAG_STR
from .base import Tidy3dBaseModel, cached_property
from .types import Direction, Polarization, Ax, FreqBound, ArrayLike, Axis, Bound
from .validators import assert_plane, validate_name_str
from .geometry import Box
Expand Down Expand Up @@ -673,14 +673,11 @@ class TFSF(AngledFieldSource, VolumeSource):


# sources allowed in Simulation.sources
SourceType = Annotated[
Union[
UniformCurrentSource,
PointDipole,
GaussianBeam,
AstigmaticGaussianBeam,
ModeSource,
PlaneWave,
],
pydantic.Field(discriminator=TYPE_TAG_STR),
SourceType = Union[
UniformCurrentSource,
PointDipole,
GaussianBeam,
AstigmaticGaussianBeam,
ModeSource,
PlaneWave,
]
18 changes: 10 additions & 8 deletions tidy3d/components/structure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# pylint:disable=too-many-arguments
"""Defines Geometric objects with Medium properties."""
from typing import Union

import pydantic

from .base import Tidy3dBaseModel
from .validators import validate_name_str
from .geometry import GeometryType, Box, GeometryGroup # pylint: disable=unused-import
from .medium import MediumType, Medium # pylint: disable=unused-import
from .types import Ax
from .geometry import GeometryType
from .medium import MediumType
from .types import Ax, TYPE_TAG_STR
from .viz import add_ax_if_none, equal_aspect


Expand All @@ -19,19 +16,24 @@ class Structure(Tidy3dBaseModel):

Example
-------
>>> from tidy3d import Box, Medium
>>> box = Box(center=(0,0,1), size=(2, 2, 2))
>>> glass = Medium(permittivity=3.9)
>>> struct = Structure(geometry=box, medium=glass, name='glass_box')
"""

geometry: Union[GeometryType, GeometryGroup] = pydantic.Field(
..., title="Geometry", description="Defines geometric properties of the structure."
geometry: GeometryType = pydantic.Field(
...,
title="Geometry",
description="Defines geometric properties of the structure.",
discriminator=TYPE_TAG_STR,
)

medium: MediumType = pydantic.Field(
...,
title="Medium",
description="Defines the electromagnetic properties of the structure's medium.",
discriminator=TYPE_TAG_STR,
)

name: str = pydantic.Field(None, title="Name", description="Optional name for the structure.")
Expand Down
10 changes: 10 additions & 0 deletions tidy3d/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Literal
except ImportError:
from typing_extensions import Literal
from typing_extensions import Annotated

import pydantic
import numpy as np
Expand All @@ -18,6 +19,15 @@
# warnings.filterwarnings(action="error", category=np.ComplexWarning)
""" Numpy Arrays """

# type tag default name
TYPE_TAG_STR = "type"


def annotate_type(UnionType): # pylint:disable=invalid-name
"""Annotated union type using TYPE_TAG_STR as discriminator."""
return Annotated[UnionType, pydantic.Field(discriminator=TYPE_TAG_STR)]


# generic numpy array
Numpy = np.ndarray

Expand Down