From e7161a90979118a85999d525318f349bfea692dc Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Sat, 8 Jun 2024 13:51:14 -0400 Subject: [PATCH] fix deepcopy and polyslab bugs in autograd --- tidy3d/components/base.py | 8 +++++--- tidy3d/components/data/data_array.py | 17 +++++++++++++++++ tidy3d/components/geometry/polyslab.py | 3 +-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index ca7fc6e24..27fc26e3b 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -261,13 +261,15 @@ def updated_copy(self, path: str = None, deep: bool = True, **kwargs) -> Tidy3dB sub_component = sub_component_list[index] sub_path = "/".join(path_components[2:]) - sub_component_list[index] = sub_component.updated_copy(path=sub_path, **kwargs) + sub_component_list[index] = sub_component.updated_copy( + path=sub_path, deep=deep, **kwargs + ) new_component = tuple(sub_component_list) else: sub_path = "/".join(path_components[1:]) - new_component = sub_component.updated_copy(path=sub_path, **kwargs) + new_component = sub_component.updated_copy(path=sub_path, deep=deep, **kwargs) - return self._updated_copy(**{field_name: new_component}) + return self._updated_copy(deep=deep, **{field_name: new_component}) def _updated_copy(self, deep: bool = True, **kwargs) -> Tidy3dBaseModel: """Make copy of a component instance with ``**kwargs`` indicating updated field values.""" diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 5114beeda..768c13d0a 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -79,6 +79,23 @@ def __init__(self, data, *args, **kwargs): # NOTE: this is done because if we pass the traced array directly, it will create a # numpy array of `ArrayBox`, which is extremely slow + def __deepcopy__(self, memo): + """Define the behavior of ``deepcopy()`` a ``xr.DataArray``.""" + + # if we detect that this has tracers, we need to shallow copy + # otherwise it confuses autograd.. + if self.has_tracers: + return self.__copy__() + + return super().__deepcopy__(memo) + + @property + def has_tracers(self) -> bool: + """Whether the ``DataArray`` has ``autograd`` derivative information.""" + traced_data = self.data.dtype == object and isbox(self.data.flat[0]) + traced_attrs = AUTOGRAD_KEY in self.attrs + return traced_data or traced_attrs + @classmethod def __get_validators__(cls): """Validators that get run when :class:`.DataArray` objects are added to pydantic models.""" diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 9ab5df76d..28f7be05b 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -1414,8 +1414,7 @@ def compute_derivative_vertices( # compute center positions between each edge edge_centers_plane = (vertices_next + vertices) / 2.0 - edge_centers_axis = np.mean(self.slab_bounds) * np.ones(num_vertices) - + edge_centers_axis = self.center_axis * np.ones(num_vertices) edge_centers_xyz = self.unpop_axis_vect(edge_centers_axis, edge_centers_plane) assert edge_centers_xyz.shape == (num_vertices, 3), "something bad happened"