Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure compile_ufl output has correct free indices #237

Closed
wants to merge 2 commits into from

Conversation

ReubenHill
Copy link
Contributor

@ReubenHill ReubenHill commented Feb 2, 2021

When an expression given to compile_ufl does not have points in the
input PointSet in its expression tree, the resulting gem expression has
missing free indices. This attempts to fix that.

Todo:

  • Fix failing tests

tsfc/fem.py Outdated Show resolved Hide resolved
tsfc/fem.py Outdated Show resolved Hide resolved
@ReubenHill
Copy link
Contributor Author

Sadly it seems that quite a lot of gem assumes that gem.Delta(j,j) returns Literal(1.) :/

tsfc/fem.py Outdated Show resolved Hide resolved
This removes a shortcut in delta which would return gem.Literal(1.) for
Delta(j, j). This shortcut stopped IndexSum(Delta(i, i), i) doing the
right thing since it stopped i from being a free index of the expression
produced by Delta.

Deltas with repeated indices are now replaced with gem.Literal(1.) by
gem.optimise.replace_indices_delta.
tsfc/fem.py Outdated Show resolved Hide resolved
tsfc/driver.py Outdated Show resolved Hide resolved
tsfc/fem.py Outdated Show resolved Hide resolved
tsfc/driver.py Outdated Show resolved Hide resolved
tsfc/driver.py Outdated Show resolved Hide resolved
tsfc/driver.py Outdated Show resolved Hide resolved
Copy link
Contributor

@wence- wence- left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor fixes

@wence-
Copy link
Contributor

wence- commented Feb 5, 2021

Very minor, your commit message for the compile_ufl change:

Ensure compile_ufl output has correct free indices  …
When an expression given to compile_ufl does not have points in the
input PointSet in its expression tree, the resulting gem expression has
missing free indices. This attempts to fix that.

I think it does fix it, so remove "This attempts to fix that" :).

@wence-
Copy link
Contributor

wence- commented Feb 5, 2021

Can you add a test that this is doing TRT? e.g. that compile_ufl(expr_that_doesn_depend_on_point_set) has the right free indices now?

@ReubenHill
Copy link
Contributor Author

Can you add a test that this is doing TRT? e.g. that compile_ufl(expr_that_doesn_depend_on_point_set) has the right free indices now?

I'm not quite sure of a nice way to set up such a test. Because of the weird interface that fem.compile_ufl has I need to first set up a whole kernel config which takes me close to having to reimplement much of compile_expression_dual_evaluation. Could the assertions do?

When an expression given to compile_ufl does not have points in the
input PointSet in its expression tree, the resulting gem expression has
missing free indices. This fixes that.
@wence-
Copy link
Contributor

wence- commented Feb 5, 2021

I'm not quite sure of a nice way to set up such a test. Because of the weird interface that fem.compile_ufl has I need to first set up a whole kernel config which takes me close to having to reimplement much of compile_expression_dual_evaluation. Could the assertions do?

OK, I see. yes I think that's ok.

@miklos1
Copy link
Member

miklos1 commented Feb 5, 2021

This feels to me like scratching your face with your elbow behind your back. Surely, the right way would be that the caller calmly accepts the fact, that the compiled expression may be constant along some of the point set indices, and then it deals with the situation accordingly. Leaving "mathematical one" nodes in the expression with fake index dependencies does not sound like a clean solution.

@wence-
Copy link
Contributor

wence- commented Feb 5, 2021

This feels to me like scratching your face with your elbow behind your back. Surely, the right way would be that the caller calmly accepts the fact, that the compiled expression may be constant along some of the point set indices, and then it deals with the situation accordingly. Leaving "mathematical one" nodes in the expression with fake index dependencies does not sound like a clean solution.

So there are two parts here. The early simplification of delta(i, i) => one I think is kind of wrong (because then as noted IndexSum(delta(i, i), (i, )) produces the wrong thing.

I agree that using delta(i, i) as a fake way of producing a "broadcast over these indices" is a bit messy. Perhaps a better way would be to introduce a Broadcast(expr, indices) node and handle it later. Then there's a philosophical point of whether this should be handled inside compile_ufl or outside.

@ReubenHill
Copy link
Contributor Author

This feels to me like scratching your face with your elbow behind your back. Surely, the right way would be that the caller calmly accepts the fact, that the compiled expression may be constant along some of the point set indices, and then it deals with the situation accordingly. Leaving "mathematical one" nodes in the expression with fake index dependencies does not sound like a clean solution.

Are you advocating not changing the behaviour of fem.compile_ufl? The trouble is that if I want to do anything with the expression after calling fem.compile_ufl then then the number of free indices it has is unpredictable without this fix - it's silently dependent on whether the points in the point set turn up in the expression

@ReubenHill
Copy link
Contributor Author

@dham was the person who first strongly advocated that this was a bug - he may want to add his thoughts

@miklos1
Copy link
Member

miklos1 commented Feb 5, 2021

The early simplification of delta(i, i) => one I think is kind of wrong (because then as noted IndexSum(delta(i, i), (i, )) produces the wrong thing.

I suppose IndexSum(one, (i,)) should simplify to Literal(i.extent), though I admit in other cases it is useful to have an index_sum helper function which just ignores indices not present in the integrand.

I agree that using delta(i, i) as a fake way of producing a "broadcast over these indices" is a bit messy. Perhaps a better way would be to introduce a Broadcast(expr, indices) node and handle it later.

Perhaps. Where is it used? See more in another answer below.

@miklos1
Copy link
Member

miklos1 commented Feb 5, 2021

The trouble is that if I want to do anything with the expression after calling fem.compile_ufl then then the number of free indices it has is unpredictable without this fix - it's silently dependent on whether the points in the point set turn up in the expression

I would need to see how you exactly use compile_ufl, but my preliminary opinion is that this is indeed the correct behaviour.
Cases like this have been dealt with in TSFC: for example, when interpolating grad(f) to a vector-DG1 space, it might happen that grad(f) is cellwise constant because f was linear. In this case the same values are written in each target DG node of the cell. impero_utils.Assignment accepts right-hand sides whose free indices are a proper subset of the free indices of the left-hand side.

@dham
Copy link
Member

dham commented Feb 5, 2021

The early simplification of delta(i, i) => one I think is kind of wrong (because then as noted IndexSum(delta(i, i), (i, )) produces the wrong thing.

I suppose IndexSum(one, (i,)) should simplify to Literal(i.extent), though I admit in other cases it is useful to have an index_sum helper function which just ignores indices not present in the integrand.

I agree that using delta(i, i) as a fake way of producing a "broadcast over these indices" is a bit messy. Perhaps a better way would be to introduce a Broadcast(expr, indices) node and handle it later.

Perhaps. Where is it used? See more in another answer below.

I think the IndexSum is a red herring here. The question to ask yourself is, what are the shape and free indices of delta(i, i) and one. I think that the former is scalar with one free index and the latter is scalar with no free indices. This makes the simplification invalid.

@ReubenHill
Copy link
Contributor Author

The trouble is that if I want to do anything with the expression after calling fem.compile_ufl then then the number of free indices it has is unpredictable without this fix - it's silently dependent on whether the points in the point set turn up in the expression

I would need to see how you exactly use compile_ufl, but my preliminary opinion is that this is indeed the correct behaviour.
Cases like this have been dealt with in TSFC: for example, when interpolating grad(f) to a vector-DG1 space, it might happen that grad(f) is cellwise constant because f was linear. In this case the same values are written in each target DG node of the cell. impero_utils.Assignment accepts right-hand sides whose free indices are a proper subset of the free indices of the left-hand side.

See this comment in FInAT/FInAT#57 which I am trying to tidy and land.

I am currently putting together a tensor of weights Q which I right multiply by a vector of function point evaluations (the function of interest here being whatever the ufl expression we are compiling with gem_expr = fem.compile_ufl(ufl_expression,...) is and the points are some FInAT.PointSet point_set).

If the ufl_expression function/expression happens to not be dependent on the points given to it, instead of getting gem_expr to be a vector of point evaluations (where some of the free indices are equal to point_set.indices) I get a scalar. To do my tensor contraction I then need to have a special case to deal with this quirk of fem.compile_ufl.

It seems to me better to get fem.compile_ufl to return something with the expected free indices.

@miklos1
Copy link
Member

miklos1 commented Feb 5, 2021

I think the IndexSum is a red herring here. The question to ask yourself is, what are the shape and free indices of delta(i, i) and one. I think that the former is scalar with one free index and the latter is scalar with no free indices. This makes the simplification invalid.

Well, yes and no. From a formal point of view, you are right, since Delta(i, i) is an expression which contains i without i being bound, so it does havei as a free index, while one does not have i as a free index. However, they are substantially equivalent, since for all valid values of i: Delta(i, i) = one. So, considering the purpose of filling the result tensor with the correct values, replacing Delta(i, i) with one is legit.

To apply optimisations, that invariably simplify the expression, as early as possible, is usually the right thing to do. The benefits are potentially unlocking further optimisations after simplification (which means better optimised output code), and less work for further traverals to do, due to the smaller resulting expression (which means faster compilation).

Currently, for interpolation, we just let assignment do the broadcasting. Form assembly with fancy optimisations can be even more complicated. For example, the mass matrix of a vector-CG element is the outer product of the scalar mass matrix with an identity matrix, which means we only broadcast along the diagonal loop, but not onto the entire (dense) result tensor. Is this new use case somehow more special than the previous ones?

@miklos1
Copy link
Member

miklos1 commented Feb 5, 2021

I am currently putting together a tensor of weights Q which I right multiply by a vector of function point evaluations (the function of interest here being whatever the ufl expression we are compiling with gem_expr = fem.compile_ufl(ufl_expression,...) is and the points are some FInAT.PointSet point_set).

If the ufl_expression function/expression happens to not be dependent on the points given to it, instead of getting gem_expr to be a vector of point evaluations (where some of the free indices are equal to point_set.indices) I get a scalar. To do my tensor contraction I then need to have a special case to deal with this quirk of fem.compile_ufl.

It seems to me better to get fem.compile_ufl to return something with the expected free indices.

I don't really see why some missing indices would get in the way of some tensor contraction. Since you say in the linked comment that Q is compile-time constant, I would go even further: check which point indices are not being used in the result of compile_ufl, and pre-contract the Q matrix along those indices. Then you have spared some loops at run time.

@dham
Copy link
Member

dham commented Feb 5, 2021

I think the IndexSum is a red herring here. The question to ask yourself is, what are the shape and free indices of delta(i, i) and one. I think that the former is scalar with one free index and the latter is scalar with no free indices. This makes the simplification invalid.

Well, yes and no. From a formal point of view, you are right, since Delta(i, i) is an expression which contains i without i being bound, so it does havei as a free index, while one does not have i as a free index. However, they are substantially equivalent, since for all valid values of i: Delta(i, i) = one. So, considering the purpose of filling the result tensor with the correct values, replacing Delta(i, i) with one is legit.

To apply optimisations, that invariably simplify the expression, as early as possible, is usually the right thing to do. The benefits are potentially unlocking further optimisations after simplification (which means better optimised output code), and less work for further traverals to do, due to the smaller resulting expression (which means faster compilation).

Currently, for interpolation, we just let assignment do the broadcasting. Form assembly with fancy optimisations can be even more complicated. For example, the mass matrix of a vector-CG element is the outer product of the scalar mass matrix with an identity matrix, which means we only broadcast along the diagonal loop, but not onto the entire (dense) result tensor. Is this new use case somehow more special than the previous ones?

Hmmm. The problem is that you get intermediate values that don't have the expected shape, and you then have to somehow implicitly know what has disappeared. In this case, compile_ufl is being used to evaluate an expression at a bunch of points. The expression might also have shape because it's vector or tensor-valued. If some of the shape/indices disappear because of this sort of index cancellation, how do I safely know which indices have disappeared? What if it's a 3D vector valued expression being evaluated at 3 points?

I think for this to be safe you have to have explicit broadcasting of some sort. You can then drop the broadcasts before you go to Impero in order to avoid loops, but while you're manipulating GEM I think you need to maintain consistent shapes.

@miklos1
Copy link
Member

miklos1 commented Feb 6, 2021 via email

@ReubenHill
Copy link
Contributor Author

We have come to the conclusion that this PR is not necessary, but that the behaviour of GEM with respect to shape being sacred and indices being discardable needs to be documented in a design document for TSFC. I will make an issue.

expr, = fem.compile_ufl(expression, **config, point_sum=False)
# the point set free indices should now be free indices of the compiled
# expression
assert set(point_set.indices) <= set(expr.free_indices)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wence- there was some discussion yesterday of adding some kind of assertion along these lines, but I can't work out what it should be. Given that the point set indices can disappear, and we don't know what other free indices there might be at this point, what's there left to assert?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other free indices might there be?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think expr.free_indices \in point_set.indices?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if your expression has an argument there is a free index for that. Might be remembering wrong though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, so what we're expecting is expr.free_indices <= set(chain(point_set.indices, *argument_multiindices))) I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #241

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants