diff --git a/goalie/adjoint.py b/goalie/adjoint.py index 9ba7d6d..c788412 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -1,6 +1,7 @@ """ Drivers for solving adjoint problems on sequences of meshes. """ + import firedrake from firedrake.petsc import PETSc from firedrake.adjoint import pyadjoint @@ -382,8 +383,8 @@ def wrapped_solver(subinterval, ic, **kwargs): ) elif j * stride + 1 == num_solve_blocks: if i + 1 < num_subintervals: - sols.adjoint_next[i][j].project( - sols.adjoint_next[i + 1][0] + project( + sols.adjoint_next[i + 1][0], sols.adjoint_next[i][j] ) else: raise IndexError( diff --git a/goalie/go_mesh_seq.py b/goalie/go_mesh_seq.py index ee3a4fa..c33245c 100644 --- a/goalie/go_mesh_seq.py +++ b/goalie/go_mesh_seq.py @@ -1,6 +1,7 @@ """ Drivers for goal-oriented error estimation on sequences of meshes. """ + from .adjoint import AdjointMeshSeq from .error_estimation import get_dwr_indicator from .log import pyrint @@ -61,7 +62,9 @@ def get_enriched_mesh_seq( get_qoi=self._get_qoi, get_bcs=self._get_bcs, qoi_type=self.qoi_type, + parameters=self.params, ) + mesh_seq_e.update_function_spaces() # Apply p-refinement if enrichment_method == "p": @@ -312,7 +315,9 @@ def fixed_point_iteration( break # Adapt meshes and log element counts - continue_unconditionally = adaptor(self, self.solutions, self.indicators, **adaptor_kwargs) + continue_unconditionally = adaptor( + self, self.solutions, self.indicators, **adaptor_kwargs + ) if self.params.drop_out_converged: self.check_convergence[:] = np.logical_not( np.logical_or(continue_unconditionally, self.converged) diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index a2d623c..59c4fbb 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -1,6 +1,7 @@ """ Sequences of meshes corresponding to a :class:`~.TimePartition`. """ + import firedrake from firedrake.adjoint import pyadjoint from firedrake.adjoint_utils.solving import get_solve_blocks @@ -235,8 +236,10 @@ def _function_spaces_consistent(self) -> bool: ) return consistent - @property - def function_spaces(self) -> list: + def update_function_spaces(self) -> list: + """ + Update the function space dictionary associated with the :class:`MeshSeq`. + """ if self._fs is None or not self._function_spaces_consistent: self._fs = [self.get_function_spaces(mesh) for mesh in self.meshes] self._fs = AttrDict( @@ -250,6 +253,10 @@ def function_spaces(self) -> list: ), "Meshes and function spaces are inconsistent" return self._fs + @property + def function_spaces(self): + return self.update_function_spaces() + @property def initial_condition(self) -> AttrDict: ic = OrderedDict(self.get_initial_condition()) @@ -653,7 +660,11 @@ def check_element_count_convergence(self): @PETSc.Log.EventDecorator() def fixed_point_iteration( - self, adaptor: Callable, solver_kwargs: dict = {}, adaptor_kwargs: dict = {}, **kwargs + self, + adaptor: Callable, + solver_kwargs: dict = {}, + adaptor_kwargs: dict = {}, + **kwargs, ): r""" Apply goal-oriented mesh adaptation using a fixed point iteration loop.