From 63fe989fef1c8c653f235fde60ec8ffc3abeab51 Mon Sep 17 00:00:00 2001 From: Reuben Nixon-Hill Date: Wed, 3 Feb 2021 12:32:29 +0000 Subject: [PATCH] Ensure compile_ufl output has correct free indices When an expression given to compile_ufl does not have points in the input PointSet in its expression tree, the resulting gem expression has missing free indices. This attempts to fix that. --- tsfc/driver.py | 8 ++++++++ tsfc/fem.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/tsfc/driver.py b/tsfc/driver.py index d6db7120..a077226a 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -377,7 +377,11 @@ def compile_expression_dual_evaluation(expression, to_element, *, quad_rule = QuadratureRule(point_set, to_element._weights) config["quadrature_rule"] = quad_rule + # evaluate expression at the points specified by point_set expr, = fem.compile_ufl(expression, **config, point_sum=False) + # the point set free indices should now be free indices of the complied + # expression + assert all(i in expr.free_indices for i in point_set.indices) shape_indices = tuple(gem.Index() for _ in expr.shape) basis_indices = point_set.indices ir = gem.Indexed(expr, shape_indices) @@ -395,7 +399,11 @@ def compile_expression_dual_evaluation(expression, to_element, *, point_set = PointSet(pts) config = kernel_cfg.copy() config.update(point_set=point_set) + # evaluate expression at the points specified by point_set expr, = fem.compile_ufl(expression, **config, point_sum=False) + # the point set free indices should now be free indices of the + # complied expression + assert all(i in expr.free_indices for i in point_set.indices) expr = gem.partial_indexed(expr, shape_indices) expr_cache[pts] = expr, point_set weights = collections.defaultdict(list) diff --git a/tsfc/fem.py b/tsfc/fem.py index 2411a1dd..93ac33ee 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -697,6 +697,15 @@ def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs): # Translate UFL to GEM, lowering finite element specific nodes result = map_expr_dags(context.translator, expressions) + # Check that any indices from the PointSet have made it into the resulting expression + import functools + import operator + for i in range(len(result)): + expansion_indices = tuple(j for j in context.point_indices if j not in result[i].free_indices) + # blow up shape to include expansion_indices using deltas - these will + # eventually be cancelled out but ensure the our result has the expected + # free_indices and shape. + result[i] = functools.reduce(operator.mul, (gem.Delta(j, j) for j in expansion_indices), result[i]) if point_sum: result = [gem.index_sum(expr, context.point_indices) for expr in result] return result