Skip to content

Commit

Permalink
Ensure compile_ufl output has correct free indices
Browse files Browse the repository at this point in the history
When an expression given to compile_ufl does not have points in the
input PointSet in its expression tree, the resulting gem expression has
missing free indices. This attempts to fix that.
  • Loading branch information
ReubenHill committed Feb 4, 2021
1 parent 6cdeb3c commit 63fe989
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ def compile_expression_dual_evaluation(expression, to_element, *,
quad_rule = QuadratureRule(point_set, to_element._weights)
config["quadrature_rule"] = quad_rule

# evaluate expression at the points specified by point_set
expr, = fem.compile_ufl(expression, **config, point_sum=False)
# the point set free indices should now be free indices of the complied
# expression
assert all(i in expr.free_indices for i in point_set.indices)
shape_indices = tuple(gem.Index() for _ in expr.shape)
basis_indices = point_set.indices
ir = gem.Indexed(expr, shape_indices)
Expand All @@ -395,7 +399,11 @@ def compile_expression_dual_evaluation(expression, to_element, *,
point_set = PointSet(pts)
config = kernel_cfg.copy()
config.update(point_set=point_set)
# evaluate expression at the points specified by point_set
expr, = fem.compile_ufl(expression, **config, point_sum=False)
# the point set free indices should now be free indices of the
# complied expression
assert all(i in expr.free_indices for i in point_set.indices)
expr = gem.partial_indexed(expr, shape_indices)
expr_cache[pts] = expr, point_set
weights = collections.defaultdict(list)
Expand Down
9 changes: 9 additions & 0 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,15 @@ def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs):

# Translate UFL to GEM, lowering finite element specific nodes
result = map_expr_dags(context.translator, expressions)
# Check that any indices from the PointSet have made it into the resulting expression
import functools
import operator
for i in range(len(result)):
expansion_indices = tuple(j for j in context.point_indices if j not in result[i].free_indices)
# blow up shape to include expansion_indices using deltas - these will
# eventually be cancelled out but ensure the our result has the expected
# free_indices and shape.
result[i] = functools.reduce(operator.mul, (gem.Delta(j, j) for j in expansion_indices), result[i])
if point_sum:
result = [gem.index_sum(expr, context.point_indices) for expr in result]
return result

0 comments on commit 63fe989

Please sign in to comment.