diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index d4247cb93..e215bad56 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -24,7 +24,7 @@ import numpy.linalg as la from abc import ABC, abstractmethod -from typing import Sequence, Optional, List, Tuple +from typing import Generic, Sequence, Optional, List, Tuple from pytools import memoize_method import loopy as lp @@ -33,12 +33,11 @@ DiscretizationElementAxisTag, DiscretizationDOFAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( - ArrayContext, NotAnArrayContainerError, + ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container, make_loopy_program, tag_axes ) from arraycontext.metadata import NameHint -from arraycontext.container import ArrayT, ArrayOrContainerT from meshmode.discretization import Discretization, ElementGroupBase from meshmode.dof_array import DOFArray @@ -54,7 +53,7 @@ def _reshape_and_preserve_tags( # {{{ interpolation batch @dataclass -class InterpolationBatch: +class InterpolationBatch(Generic[ArrayT]): """One interpolation batch captures how a batch of elements *within* an element group should be an interpolated. Note that while it's possible that an interpolation batch takes care of interpolating an entire element group @@ -178,7 +177,7 @@ def _global_from_element_indices( # {{{ _FromGroupPickData @dataclass -class _FromGroupPickData: +class _FromGroupPickData(Generic[ArrayT]): """Represents information needed to pick DOFs from one source element group to a target element group. Note that the connection between these groups must be such that the information transfer can occur by indirect diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py index 027551e6d..7e141113d 100644 --- a/meshmode/dof_array.py +++ b/meshmode/dof_array.py @@ -37,12 +37,11 @@ from meshmode.transform_metadata import ( ConcurrentElementInameTag, ConcurrentDOFInameTag) from arraycontext import ( - ArrayContext, NotAnArrayContainerError, + ArrayContext, ArrayOrContainerT, NotAnArrayContainerError, make_loopy_program, with_container_arithmetic, serialize_container, deserialize_container, with_array_context, rec_map_array_container, rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers) -from arraycontext.container import ArrayOrContainerT __doc__ = """ .. autoclass:: DOFArray