Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite bit-blaster multiplier #5496

Closed
wants to merge 10 commits into from
1 change: 1 addition & 0 deletions src/ast/rewriter/bit_blaster/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
z3_add_component(bit_blaster
SOURCES
bit_blaster.cpp
bit_blaster_adder.cpp
bit_blaster_rewriter.cpp
COMPONENT_DEPENDENCIES
rewriter
Expand Down
112 changes: 112 additions & 0 deletions src/ast/rewriter/bit_blaster/bit_blaster_adder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include "ast/rewriter/bit_blaster/bit_blaster_adder.h"
jameysharp marked this conversation as resolved.
Show resolved Hide resolved

bit_blaster_adder::bit_blaster_adder(bool_rewriter & rewriter, unsigned sz, numeral const & value):
m_rewriter(rewriter),
m_constant(value)
{
reduce();
expr_ref_vector empty(m());
m_variable.resize(sz, empty);
}

bit_blaster_adder::bit_blaster_adder(bool_rewriter & rewriter, unsigned sz, expr * const * bits):
bit_blaster_adder(rewriter, sz)
{
for (unsigned i = 0; i < sz; i++)
add_bit(i, bits[i]);
}

expr_ref bit_blaster_adder::sum_bits(vector< expr_ref_vector > & columns, expr_ref_vector & out_bits) const {
SASSERT(out_bits.empty());

expr_ref_vector carries(m());
expr_ref tmp1(m()), tmp2(m()), tmp3(m());

for (auto & column : columns) {
column.append(carries);
carries.reset();

while (column.size() >= 3) {
expr * a = column.back();
jameysharp marked this conversation as resolved.
Show resolved Hide resolved
column.pop_back();
expr * b = column.back();
column.pop_back();
expr * c = column.back();
column.pop_back();

m_rewriter.mk_xor(a, b, tmp1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this format is tedious. I finally added ability to tmp1 = m_rewriter.mk_or(a, b);

m_rewriter.mk_xor(tmp1, c, tmp2);
column.push_back(tmp2);

m_rewriter.mk_and(a, b, tmp2);
m_rewriter.mk_and(tmp1, c, tmp3);
// tmp2 and tmp3 can't be true at the same time, so use
// whichever of mk_or vs mk_xor makes the most sense here.
m_rewriter.mk_or(tmp2, tmp3, tmp1);
carries.push_back(tmp1);
}

if (column.size() == 2) {
expr * a = column.back();
column.pop_back();
expr * b = column.back();
column.pop_back();

m_rewriter.mk_xor(a, b, tmp1);
column.push_back(tmp1);

m_rewriter.mk_and(a, b, tmp1);
carries.push_back(tmp1);
}

out_bits.push_back(column.get(0, m().mk_false()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, but subtle.
Why not explicit case on sz?

switch (sz) {
case 0:
case 1:
case 2:
default:
....
}

}

SASSERT(out_bits.size() == size());

// We return the carry separately in case the caller wants it.
tmp1 = m().mk_false();
for (auto & carry : carries) {
m_rewriter.mk_xor(tmp1, carry, tmp2);
tmp1 = tmp2;
}
return tmp1;
}

expr_ref bit_blaster_adder::variable_bits(expr_ref_vector & out_bits) const {
vector< expr_ref_vector > columns(m_variable);
return sum_bits(columns, out_bits);
}

expr_ref bit_blaster_adder::total_bits(expr_ref_vector & out_bits) const {
vector< expr_ref_vector > columns(m_variable);

expr_ref one(m());
one = m().mk_true();
for (unsigned i = 0; i < size(); i++)
if (m_constant.get_bit(i))
columns[i].push_back(one);

expr_ref carry(m());
carry = sum_bits(columns, out_bits);
if (m_constant.get_bit(size())) {
m_rewriter.mk_not(carry, one);
carry = one;
}
return carry;
}

bit_blaster_adder & bit_blaster_adder::add_shifted(bit_blaster_adder const & other, unsigned shift) {
add_shifted(other.m_constant, shift);
for (unsigned i = 0; shift + i < size(); i++)
m_variable[shift + i].append(other.m_variable[i]);
return *this;
}

bit_blaster_adder & bit_blaster_adder::add_shifted(unsigned sz, expr * const * bits, unsigned shift) {
if (sz > size() - shift)
sz = size() - shift;
for (unsigned i = 0; i < sz; i++)
add_bit(shift + i, bits[i]);
return *this;
}
98 changes: 98 additions & 0 deletions src/ast/rewriter/bit_blaster/bit_blaster_adder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once

#include "ast/ast.h"
#include "ast/rewriter/bool_rewriter.h"
#include "util/rational.h"

class bit_blaster_adder {
public:
typedef rational numeral;

bit_blaster_adder(bool_rewriter & rewriter, unsigned sz, numeral const & value = numeral::zero());
bit_blaster_adder(bool_rewriter & rewriter, unsigned sz, expr * const * bits);

bit_blaster_adder(bool_rewriter & rewriter, expr_ref_vector & value):
bit_blaster_adder(rewriter, value.size(), value.data()) {}

bit_blaster_adder(bit_blaster_adder const & other):
bit_blaster_adder(other.m_rewriter, other.size()) {
*this += other;
}

bit_blaster_adder(bit_blaster_adder &&) noexcept = default;

// The *_bits functions all return the final carry bit out of the addition;
// just ignore it if you don't need it.

// Return the sum of the known-constant inputs to this adder.
bool constant_bits(numeral & value) const {
value = m_constant % power(size());
return m_constant.get_bit(size());
}

// Return the sum of the non-constant inputs to this adder.
expr_ref variable_bits(expr_ref_vector & out_bits) const;

// Return the sum of all inputs to this adder.
expr_ref total_bits(expr_ref_vector & out_bits) const;

unsigned size() const {
return m_variable.size();
}

ast_manager & m() const {
return m_rewriter.m();
}

bit_blaster_adder & operator+=(bit_blaster_adder const & other) {
SASSERT(size() == other.size());
return add_shifted(other, 0);
}

bit_blaster_adder & add_bit(unsigned idx, bool bit) {
SASSERT(idx < size());
if (bit)
m_constant += power(idx);
return *this;
}

bit_blaster_adder & add_bit(unsigned idx, expr * bit) {
SASSERT(idx < size());
if (m().is_true(bit))
m_constant += power(idx);
else if (!m().is_false(bit))
m_variable[idx].push_back(bit);
return *this;
}

bit_blaster_adder & add_bit(unsigned idx, expr_ref & bit) {
return add_bit(idx, bit.get());
}

bit_blaster_adder & add_shifted(bit_blaster_adder const & other, unsigned shift);
bit_blaster_adder & add_shifted(unsigned sz, expr * const * bits, unsigned shift);

bit_blaster_adder & add_shifted(expr_ref_vector const & bits, unsigned shift) {
return add_shifted(bits.size(), bits.data(), shift);
}

bit_blaster_adder & add_shifted(numeral const & value, unsigned shift) {
m_constant += value * power(shift);
reduce();
return *this;
}

protected:
bool_rewriter & m_rewriter;
jameysharp marked this conversation as resolved.
Show resolved Hide resolved
numeral m_constant;
vector< expr_ref_vector > m_variable;

numeral power(unsigned n) const { return numeral::power_of_two(n); }

void reduce() {
// keep one extra bit in case somebody wants the final carry bit
m_constant %= power(size() + 1);
}

expr_ref sum_bits(vector< expr_ref_vector > & columns, expr_ref_vector & out_bits) const;
};
9 changes: 2 additions & 7 deletions src/ast/rewriter/bit_blaster/bit_blaster_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Revision History:
--*/
#pragma once

#include "ast/rewriter/bit_blaster/bit_blaster_adder.h"
#include "util/rational.h"

template<typename Cfg>
Expand Down Expand Up @@ -73,9 +74,7 @@ class bit_blaster_tpl : public Cfg {
//


bool is_numeral(unsigned sz, expr * const * bits) const;
bool is_numeral(unsigned sz, expr * const * bits, numeral & r) const;
bool is_minus_one(unsigned sz, expr * const * bits) const;
void num2bits(numeral const & v, unsigned sz, expr_ref_vector & out_bits) const;

void mk_half_adder(expr * a, expr * b, expr_ref & out, expr_ref & cout);
Expand Down Expand Up @@ -120,12 +119,8 @@ class bit_blaster_tpl : public Cfg {
void mk_smul_no_underflow(unsigned sz, expr * const * a_bits, expr * const * b_bits, expr_ref & out);
void mk_comp(unsigned sz, expr * const * a_bits, expr * const * b_bits, expr_ref_vector & out_bits);

void mk_carry_save_adder(unsigned sz, expr * const * a_bits, expr * const * b_bits, expr * const * c_bits, expr_ref_vector & sum_bits, expr_ref_vector & carry_bits);
bool mk_const_multiplier(unsigned sz, expr * const * a_bits, expr * const * b_bits, expr_ref_vector & out_bits);
bool mk_const_case_multiplier(unsigned sz, expr * const * a_bits, expr * const * b_bits, expr_ref_vector & out_bits);
void mk_const_case_multiplier(bool is_a, unsigned i, unsigned sz, ptr_buffer<expr, 128>& a_bits, ptr_buffer<expr, 128>& b_bits, expr_ref_vector & out_bits);
void mk_const_multiplier(numeral & a, expr_ref_vector & b_bits, bit_blaster_adder & result);

bool is_bool_const(expr* e) const { return m().is_true(e) || m().is_false(e); }
void mk_abs(unsigned sz, expr * const * a_bits, expr_ref_vector & out_bits);
};

Loading