
Optionally encode floating point operations as uninterpreted functions #1057

Closed · 31 commits

Conversation

@can-leh-emmtrix (Contributor) commented Jun 18, 2024

In our ongoing work with alive2, we often encountered cases where the floating-point operations remain largely unchanged while other parts of the function differ. Since Z3's floating-point theory is quite slow, I am looking at using uninterpreted functions to (over-)approximate these floating-point operations.

This PR adds a new command line option --uf-float that encodes all floating-point operations using uninterpreted functions. The encoding is a conservative over-approximation, meaning that if a counterexample is found for the approximation, it is not necessarily a valid counterexample for the original query.
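
For example (assuming the usual alive-tv invocation with a source and a target module):

alive-tv --uf-float src.ll tgt.ll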

Commutativity for FPBinOp and FCmp is encoded using the scheme proposed by Seongwon Bang, Seunghyeon Nam, Inwhan Chun, Ho Young Jhoo, and Juneyoung Lee in "SMT-based Translation Validation for Machine Learning Compiler". This PR does not use the other parts of their encoding, though it might be interesting to encode properties other than commutativity as well.
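
Concretely (this is spelled out further down in the thread), a commutative binary operation $\mathrm{op}$ is replaced by a fresh uninterpreted function $\mathrm{op}'$ whose result is reconstructed with a bitwise and:

$$\mathrm{op}(x, y) = \mathrm{op}'(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}'(y, x)$$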

I decided to open this PR as a draft to see whether there is interest in upstreaming this feature and to get some feedback on the direction.

Related: #916

@regehr (Contributor) commented Jun 18, 2024

so my (now former) student Zhengyang has been using floats-as-UF in his Alive-based superoptimizer. @zhengyang92 do you have feedback on this?

@can-leh-emmtrix marked this pull request as ready for review on July 2, 2024, 09:35
@can-leh-emmtrix (Contributor Author)

I'd consider this ready for review now.

As the diff is quite large, I'd recommend reviewing the first three commits individually and then looking at the diff between the third and the last commit.

The main reason for the large diff is that implementing the --uf-float option required indenting the existing code:

if (is_uf_float()) {
  <create UF>
} else {
  <existing code>
}

@can-leh-emmtrix (Contributor Author)

@nunoplopes Do you have any feedback on this?

@nunoplopes (Member)

> @nunoplopes Do you have any feedback on this?

Sorry for the delay. I still have a few bugs to fix on my queue. I'll get back to this afterwards.

ir/instr.cpp (outdated diff excerpt):
}
};

fast_math_flag(FastMathFlags::NNaN, "nnan");
Member:

I don't understand this part. Also, where are the other flags?

Contributor Author:

It handles flags which may result in the operation returning poison. As far as I can tell, the only flags which can result in poison are nnan and ninf. They are handled by creating additional uninterpreted functions op.np_nnan(x, y): bool and op.np_ninf(x, y): bool (in the binary case) whose results are conjoined onto the non_poison expression.

The nsz flag is not relevant, as in uf-float mode we do not really have a zero value.

I don't think I entirely understand how the remaining fast math flags are handled in the existing code, though it looks like the result of the operation is wrapped in a function call to another UF.
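
A minimal sketch of what I mean, using alive2's expr::mkUF as it appears elsewhere in this thread (the variables a, b, res, non_poison, the fmath flag test, and the FastMathFlags::NInf spelling are illustrative placeholders, not the exact PR code):

// Sketch: each poison-relevant flag gets its own boolean UF over the same
// arguments; its result is conjoined onto the non_poison expression.
expr value = expr::mkUF("fadd", {a, b}, res);    // over-approximated result
if (fmath.flags & FastMathFlags::NNaN)           // nnan may produce poison
  non_poison = non_poison && expr::mkUF("fadd.np_nnan", {a, b}, true);
if (fmath.flags & FastMathFlags::NInf)           // ninf may produce poison
  non_poison = non_poison && expr::mkUF("fadd.np_ninf", {a, b}, true);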

auto value = expr::mkUF(name, arg_values, res);
if (is_commutative) {
  assert(args.size() == 2);
  value = value & expr::mkUF(name, {arg_values[1], arg_values[0]}, res);
Member:

I don't think this is a sound over-approximation. Do you need commutativity in your examples?
Would an axiom stating that $\forall x, y.\ \mathrm{fadd}(x, y) = \mathrm{fadd}(y, x)$ suffice?

Contributor Author:

The encoding comes from "SMT-based Translation Validation for Machine Learning Compiler" by Seongwon Bang, Seunghyeon Nam, Inwhan Chun, Ho Young Jhoo, and Juneyoung Lee. The idea is to encode commutativity as

$\mathrm{op}(x, y) = \mathrm{op}'(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}'(y, x)$

where $\mathrm{and}$ is the bitwise and operator and $\mathrm{op}'$ is an uninterpreted function. There is a proof of correctness in the supplementary material for the paper.

Contributor Author:

I added a comment referencing the paper.

Member:

The proof is very vague and doesn't show equisatisfiability clearly.
Why not use x <= y ? f(x,y) : f(y, x)?
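That is, canonicalize the argument order and apply a single UF, e.g. (a sketch; assumes alive2's expr::mkIf and ule, with illustrative names):

// Pick an arbitrary total order on the operands, then call one UF:
expr value = expr::mkIf(a.ule(b),                  // canonical order: a <= b
                        expr::mkUF("fadd", {a, b}, res),
                        expr::mkUF("fadd", {b, a}, res));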

Contributor Author:

I would prefer to keep the current encoding, as I know that it works well on our test cases. While I could try other encodings, that would of course take some time. So here is my attempt at constructing a proof. It is of course inspired by the proof given in the supplementary material, but uses a simpler construction for $\mathrm{op}'$ / $f'$.

Please correct me if I am wrong, but since we are already over-approximating, we only need to prove that for every commutative function $\mathrm{op}: \mathrm{BV}^n \times \mathrm{BV}^n \rightarrow \mathrm{BV}^n$ there is a function $\mathrm{op}': \mathrm{BV}^n \times \mathrm{BV}^n \rightarrow \mathrm{BV}^n$ such that

$$\mathrm{op}(x, y) = \mathrm{op}'(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}'(y, x)$$

(In the context of SMT solving, this ensures that no model/counterexample can get lost.)

Proof: Let $\mathrm{BV}^n$ be the set of all bit vectors of size $n$ and let $\mathrm{op}: \mathrm{BV}^n \times \mathrm{BV}^n \rightarrow \mathrm{BV}^n$ be a commutative function.

Let $\mathrm{op}'(x, y) := \mathrm{op}(x, y)$. Then

$$ \mathrm{op}'(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}'(y, x) = \mathrm{op}(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}(y, x) $$

By commutativity of $\mathrm{op}$

$$ = \mathrm{op}(x, y) \mspace{0.3em} \mathrm{and} \mspace{0.3em} \mathrm{op}(x, y) $$

By $\forall x \in \mathrm{BV}^n: x \mspace{0.3em} \mathrm{and} \mspace{0.3em} x = x$

$$ = \mathrm{op}(x, y) $$

Thus we have shown that $\mathrm{op}'$ fulfills the requirement.

@nunoplopes (Member)

I was just thinking about this again and I think there's a better approach, which is much less invasive.
We could implement this in smt/expr.cpp instead and have a global mode to switch to UF encoding. That way, instr.cpp wouldn't be changed, all fmath flags would work, and we have a single place to change things, instead of duplicating semantics in instr.cpp.
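
Roughly like this (a sketch; the global flag and its wiring are hypothetical, while the fadd body mirrors the existing code and the fsub diff shown later in this thread):

// smt/expr.cpp -- hypothetical global mode switch
static bool uf_float_mode = false;  // would be set from a command-line option

expr expr::fadd(const expr &rhs, const expr &rm) const {
  C(rhs, rm);
  if (uf_float_mode)                // UF over-approximation; rm is dropped
    return expr::mkUF("fadd", {*this, rhs}, *this);
  return simplify_const(Z3_mk_fpa_add(ctx(), rm(), ast(), rhs()), *this, rhs);
}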

@can-leh-emmtrix (Contributor Author) commented Aug 7, 2024

I played a bit with the idea, but I was unable to achieve the same performance as when implementing the feature in instr.cpp. I used the uf-float/select test case as my benchmark.

I tried a few different things, but here are the most successful approaches I could come up with.

  • No Approximation: uses Z3's floating-point theory directly
  • expr.cpp / Floating Point: implemented in expr.cpp, using uninterpreted functions that take floating-point values as arguments
  • expr.cpp / Bit Vector: implemented in expr.cpp, using uninterpreted functions that take bit-vector values as arguments
  • instr.cpp: this PR

Here are the results. I measured the total runtime of alive-tv on uf-float/select over 8 runs per approach.

[Chart: total alive-tv runtime on uf-float/select, all four approaches, 8 runs each]

Here is a chart that shows only the approximations.

[Chart: the same measurements, approximation approaches only]

So at least on this benchmark the approach taken by this PR has a clear advantage over inserting the uninterpreted functions in expr.cpp. This also matches my prior observations in our internal examples.

I think the reason for this difference is the other operations applied by fm_poison: it uses FloatType::fromFloat, any_fp_zero, and handle_subnormal, none of which are necessary for the approach taken by this PR. Additionally, I conjecture that examples using instructions which are implemented as a more complex series of Z3 operations (e.g. FMin) will exhibit the same problem. This is something that cannot be fixed in expr.cpp.

Regarding the remaining fast math flags: implementing the same logic as in fm_poison to over-approximate them should work for uf_float too, shouldn't it?

@can-leh-emmtrix (Contributor Author)

Here is the relevant part of the diff for expr.cpp / Floating Point. This is of course very much a prototype and only implements the operations I needed for benchmarking.

diff --git a/smt/expr.cpp b/smt/expr.cpp
index 70353e8a..0af924a2 100644
--- a/smt/expr.cpp
+++ b/smt/expr.cpp
@@ -1205,7 +1205,8 @@ expr expr::fadd(const expr &rhs, const expr &rm) const {
 
 expr expr::fsub(const expr &rhs, const expr &rm) const {
   C(rhs, rm);
-  return simplify_const(Z3_mk_fpa_sub(ctx(), rm(), ast(), rhs()), *this, rhs);
+  //return simplify_const(Z3_mk_fpa_sub(ctx(), rm(), ast(), rhs()), *this, rhs);
+  return expr::mkUF("fsub", {*this, rhs}, *this);
 }
 
 expr expr::fmul(const expr &rhs, const expr &rm) const {
@@ -1282,7 +1283,8 @@ expr expr::foge(const expr &rhs) const {
 }
 
 expr expr::folt(const expr &rhs) const {
-  return binop_fold(rhs, Z3_mk_fpa_lt);
+  //return binop_fold(rhs, Z3_mk_fpa_lt);
+  return expr::mkUF("flt", {*this, rhs}, true);
 }
 
 expr expr::fole(const expr &rhs) const {
@@ -1310,7 +1312,8 @@ expr expr::fuge(const expr &rhs) const {
 }
 
 expr expr::fult(const expr &rhs) const {
-  return funo(rhs) || binop_fold(rhs, Z3_mk_fpa_lt);
+  //return funo(rhs) || binop_fold(rhs, Z3_mk_fpa_lt);
+  return funo(rhs) || expr::mkUF("flt", {*this, rhs}, true);
 }
 
 expr expr::fule(const expr &rhs) const {

Here is the relevant part of the diff for expr.cpp / Bit Vector:

diff --git a/smt/expr.cpp b/smt/expr.cpp
index 70353e8a..729c8ac3 100644
--- a/smt/expr.cpp
+++ b/smt/expr.cpp
@@ -1205,7 +1205,8 @@ expr expr::fadd(const expr &rhs, const expr &rm) const {
 
 expr expr::fsub(const expr &rhs, const expr &rm) const {
   C(rhs, rm);
-  return simplify_const(Z3_mk_fpa_sub(ctx(), rm(), ast(), rhs()), *this, rhs);
+  //return simplify_const(Z3_mk_fpa_sub(ctx(), rm(), ast(), rhs()), *this, rhs);
+  return expr::mkUF("fsub", {float2BV(), rhs.float2BV()}, float2BV()).BV2float(rhs);
 }
 
 expr expr::fmul(const expr &rhs, const expr &rm) const {
@@ -1282,7 +1283,8 @@ expr expr::foge(const expr &rhs) const {
 }
 
 expr expr::folt(const expr &rhs) const {
-  return binop_fold(rhs, Z3_mk_fpa_lt);
+  //return binop_fold(rhs, Z3_mk_fpa_lt);
+  return expr::mkUF("flt", {float2BV(), rhs.float2BV()}, true);
 }
 
 expr expr::fole(const expr &rhs) const {
@@ -1310,7 +1312,8 @@ expr expr::fuge(const expr &rhs) const {
 }
 
 expr expr::fult(const expr &rhs) const {
-  return funo(rhs) || binop_fold(rhs, Z3_mk_fpa_lt);
+  //return funo(rhs) || binop_fold(rhs, Z3_mk_fpa_lt);
+  return funo(rhs) || expr::mkUF("flt", {float2BV(), rhs.float2BV()}, true);
 }
 
 expr expr::fule(const expr &rhs) const {

@nunoplopes (Member)

I think those results are very encouraging and really suggest we should go the expr.cpp way.
Note that I don't think we should be using float2BV() at all. Floats should not exist; only BVs. We can also have the bit-width be user-specified.

This solution is superior in terms of maintainability. And getting rid of floats altogether should close the performance gap.
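
With floats gone, the UF calls need no conversions at all, e.g. (a sketch; assumes float-typed values are already carried as plain bit-vectors of a user-chosen width):

// Hypothetical pure-BV encoding: operands and results are raw bit-vectors,
// so no float2BV()/BV2float() round-trips remain.
expr value = expr::mkUF("fadd", {a, b}, a);     // range sort = a's BV sort
expr lt    = expr::mkUF("flt",  {a, b}, true);  // comparisons return bool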

@can-leh-emmtrix (Contributor Author)

It would still require some changes in instr.cpp or type.cpp to disable checking for NaN and inserting nondeterministic values, though. Do you have a solution for instructions implemented in terms of multiple floating-point and non-floating-point Z3 operations? Ideally they would be reduced to a single UF as well.

@nunoplopes (Member)

> It would still require some changes in instr.cpp or type.cpp to disable checking for NaN and inserting nondeterministic values, though. Do you have a solution for instructions implemented in terms of multiple floating-point and non-floating-point Z3 operations? Ideally they would be reduced to a single UF as well.

I don't remember cases that are not a 1-to-1 mapping with Z3. Do you have an example in mind?

(btw, I'm taking off for vacation tomorrow; I'm back in September)

@can-leh-emmtrix (Contributor Author) commented Aug 19, 2024

> I don't remember cases that are not a 1-to-1 mapping with Z3. Do you have an example in mind?

Looking through the source code I found the following cases:

For the FpConversionOp instructions LRInt, LRound, FPToSInt, and FPToUInt, the non_poison value is also computed using multiple floating-point operations: https://github.com/AliveToolkit/alive2/blob/master/ir/instr.cpp#L1788-L1832

> (btw, I'm taking off for vacation tomorrow; I'm back in September)

Have a nice vacation!

Edit: I will be on vacation until September 23.

@nunoplopes (Member)

Closing this as we don't want to have to duplicate semantics; a solution in smt/expr.cpp would avoid that.
Plus, it's a better solution in the longer term once we start supporting abstraction refinement.
