Skip to content

Commit

Permalink
map/mapi/foldl/foldli
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <[email protected]>
  • Loading branch information
NikolajBjorner committed May 4, 2022
1 parent b3e0213 commit 87d2a3b
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/ast/array_decl_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ class array_util : public array_recognizers {
func_decl * mk_array_ext(sort* domain, unsigned i);

// Convenience overloads that build an array sort over one, two, or three index sorts.
// Each packs its index sorts into a local array and forwards to
// mk_array_sort(arity, domain, range).
sort * mk_array_sort(sort* dom, sort* range) { return mk_array_sort(1, &dom, range); }
sort * mk_array_sort(sort* a, sort* b, sort* range) { sort* dom[2] = { a, b }; return mk_array_sort(2, dom, range); }
sort * mk_array_sort(sort* a, sort* b, sort* c, sort* range) { sort* dom[3] = { a, b, c}; return mk_array_sort(3, dom, range); }

sort * mk_array_sort(unsigned arity, sort* const* domain, sort* range);

Expand Down
121 changes: 121 additions & 0 deletions src/ast/rewriter/seq_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,22 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con
SASSERT(num_args == 3);
st = mk_seq_replace_all(args[0], args[1], args[2], result);
break;
case OP_SEQ_MAP:
SASSERT(num_args == 2);
st = mk_seq_map(args[0], args[1], result);
break;
case OP_SEQ_MAPI:
SASSERT(num_args == 3);
st = mk_seq_mapi(args[0], args[1], args[2], result);
break;
case OP_SEQ_FOLDL:
SASSERT(num_args == 3);
st = mk_seq_foldl(args[0], args[1], args[2], result);
break;
case OP_SEQ_FOLDLI:
SASSERT(num_args == 4);
st = mk_seq_foldli(args[0], args[1], args[2], args[3], result);
break;
case OP_SEQ_REPLACE_RE:
SASSERT(num_args == 3);
st = mk_seq_replace_re(args[0], args[1], args[2], result);
Expand Down Expand Up @@ -850,6 +866,14 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) {
result = str().mk_length(x);
return BR_REWRITE1;
}
if (str().is_map(a, x, y)) {
result = str().mk_length(y);
return BR_REWRITE1;
}
if (str().is_mapi(a, x, y, z)) {
result = str().mk_length(z);
return BR_REWRITE1;
}
#if 0
expr* s = nullptr, *offset = nullptr, *length = nullptr;
if (str().is_extract(a, s, offset, length)) {
Expand Down Expand Up @@ -1640,6 +1664,13 @@ br_status seq_rewriter::mk_seq_nth_i(expr* a, expr* b, expr_ref& result) {
return BR_REWRITE1;
}

expr* f, *s;
if (str().is_map(a, f, s)) {
expr* args[2] = { f, str().mk_nth_i(s, b) };
result = array_util(m()).mk_select(2, args);
return BR_REWRITE1;
}

expr_ref_vector as(m());
str().get_concat_units(a, as);

Expand Down Expand Up @@ -2008,6 +2039,96 @@ br_status seq_rewriter::mk_seq_replace_all(expr* a, expr* b, expr* c, expr_ref&
return BR_FAILED;
}

/**
   Rewrites for map(f, s), where f is an array of sort Array[A,B] applied
   point-wise (f(x) is built as select(f, x)) and s has sort Seq[A]:

      map(f, [])    = []                       (empty sequence of sort Seq[B])
      map(f, [x])   = [f(x)]
      map(f, s + t) = map(f, s) + map(f, t)

   Companion rewrites that live in other functions of this file:
      len(map(f, s))       = len(s)            (mk_seq_length)
      nth_i(map(f, s), i)  = f(nth_i(s, i))    (mk_seq_nth_i)

   Returns BR_FAILED when seqA is neither empty, a unit, nor a concatenation.
*/
br_status seq_rewriter::mk_seq_map(expr* f, expr* seqA, expr_ref& result) {
    if (str().is_empty(seqA)) {
        // The result has sort Seq[B]. get_array_range yields the element sort B,
        // so it has to be wrapped into a sequence sort before building the
        // empty sequence; passing B directly produces an ill-sorted term.
        sort* elemB = get_array_range(f->get_sort());
        result = str().mk_empty(str().mk_seq(elemB));
        return BR_DONE;
    }
    expr* a = nullptr, *s1 = nullptr, *s2 = nullptr;
    if (str().is_unit(seqA, a)) {
        array_util array(m());
        expr* args[2] = { f, a };
        result = str().mk_unit(array.mk_select(2, args));
        return BR_REWRITE2;
    }
    if (str().is_concat(seqA, s1, s2)) {
        result = str().mk_concat(str().mk_map(f, s1), str().mk_map(f, s2));
        return BR_REWRITE2;
    }
    return BR_FAILED;
}

/**
   Rewrites for mapi(f, i, s), where f has sort Array[Int,A,B], i is the index
   passed along with the first element of s, and s has sort Seq[A]:

      mapi(f, i, [])    = []                   (empty sequence of sort Seq[B])
      mapi(f, i, [x])   = [f(i, x)]
      mapi(f, i, s + t) = mapi(f, i, s) + mapi(f, i + len(s), t)

   Returns BR_FAILED when seqA is neither empty, a unit, nor a concatenation.
*/
br_status seq_rewriter::mk_seq_mapi(expr* f, expr* i, expr* seqA, expr_ref& result) {
    if (str().is_empty(seqA)) {
        // The result has sort Seq[B]. get_array_range yields the element sort B,
        // so it has to be wrapped into a sequence sort before building the
        // empty sequence; passing B directly produces an ill-sorted term.
        sort* elemB = get_array_range(f->get_sort());
        result = str().mk_empty(str().mk_seq(elemB));
        return BR_DONE;
    }
    expr* a = nullptr, *s1 = nullptr, *s2 = nullptr;
    if (str().is_unit(seqA, a)) {
        array_util array(m());
        expr* args[3] = { f, i, a };
        result = str().mk_unit(array.mk_select(3, args));
        return BR_REWRITE2;
    }
    if (str().is_concat(seqA, s1, s2)) {
        // The second half starts at index i + len(s1).
        expr_ref j(m_autil.mk_add(i, str().mk_length(s1)), m());
        result = str().mk_concat(str().mk_mapi(f, i, s1), str().mk_mapi(f, j, s2));
        return BR_REWRITE2;
    }
    return BR_FAILED;
}

/**
   Rewrites for fold_left(f, b, s), where f(acc, x) is built as select(f, acc, x):

      fold_left(f, b, [])    = b
      fold_left(f, b, [x])   = f(b, x)
      fold_left(f, b, s + t) = fold_left(f, fold_left(f, b, s), t)

   Returns BR_FAILED when seqA is neither empty, a unit, nor a concatenation.
*/
br_status seq_rewriter::mk_seq_foldl(expr* f, expr* b, expr* seqA, expr_ref& result) {
    expr* elem, *front, *back;
    if (str().is_empty(seqA)) {
        // Folding over nothing leaves the initial accumulator untouched.
        result = b;
        return BR_DONE;
    }
    if (str().is_unit(seqA, elem)) {
        expr* sel_args[3] = { f, b, elem };
        result = array_util(m()).mk_select(3, sel_args);
        return BR_REWRITE1;
    }
    if (!str().is_concat(seqA, front, back))
        return BR_FAILED;
    // Fold the left part first; its value is the accumulator for the right part.
    expr_ref acc(str().mk_foldl(f, b, front), m());
    result = str().mk_foldl(f, acc, back);
    return BR_REWRITE3;
}

/**
   Rewrites for the indexed left fold foldli(f, i, b, s), where f(i, acc, x)
   is built as select(f, i, acc, x):

      foldli(f, i, b, [])    = b
      foldli(f, i, b, [x])   = f(i, b, x)
      foldli(f, i, b, s + t) = foldli(f, i + len(s), foldli(f, i, b, s), t)

   Returns BR_FAILED when seqA is neither empty, a unit, nor a concatenation.
*/
br_status seq_rewriter::mk_seq_foldli(expr* f, expr* i, expr* b, expr* seqA, expr_ref& result) {
    expr* elem, *front, *back;
    if (str().is_empty(seqA)) {
        // Folding over nothing leaves the initial accumulator untouched.
        result = b;
        return BR_DONE;
    }
    if (str().is_unit(seqA, elem)) {
        expr* sel_args[4] = { f, i, b, elem };
        result = array_util(m()).mk_select(4, sel_args);
        return BR_REWRITE1;
    }
    if (!str().is_concat(seqA, front, back))
        return BR_FAILED;
    // The right part starts at index i + len(front).
    expr_ref next_idx(m_autil.mk_add(i, str().mk_length(front)), m());
    // Fold the left part first; its value is the accumulator for the right part.
    expr_ref acc(str().mk_foldli(f, i, b, front), m());
    result = str().mk_foldli(f, next_idx, acc, back);
    return BR_REWRITE3;
}

/*
* Returns false if s is not a single unit value or concatenation of unit values.
* Else extracts the units from s into vals and returns true.
Expand Down
4 changes: 4 additions & 0 deletions src/ast/rewriter/seq_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ class seq_rewriter {
br_status mk_seq_replace_re(expr* a, expr* b, expr* c, expr_ref& result);
br_status mk_seq_prefix(expr* a, expr* b, expr_ref& result);
br_status mk_seq_suffix(expr* a, expr* b, expr_ref& result);
br_status mk_seq_map(expr* f, expr* s, expr_ref& result);
br_status mk_seq_mapi(expr* f, expr* i, expr* s, expr_ref& result);
br_status mk_seq_foldl(expr* f, expr* b, expr* s, expr_ref& result);
br_status mk_seq_foldli(expr* f, expr* i, expr* b, expr* s, expr_ref& result);
br_status mk_str_units(func_decl* f, expr_ref& result);
br_status mk_str_itos(expr* a, expr_ref& result);
br_status mk_str_stoi(expr* a, expr_ref& result);
Expand Down
25 changes: 24 additions & 1 deletion src/ast/seq_decl_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,26 @@ sort* seq_decl_plugin::apply_binding(ptr_vector<sort> const& binding, sort* s) {
void seq_decl_plugin::init() {
if (m_init) return;
ast_manager& m = *m_manager;
array_util autil(m);
m_init = true;
sort* A = m.mk_uninterpreted_sort(symbol(0u));
sort* B = m.mk_uninterpreted_sort(symbol(1u));
sort* strT = m_string;
parameter paramA(A);
parameter paramB(B);
parameter paramS(strT);
sort* seqA = m.mk_sort(m_family_id, SEQ_SORT, 1, &paramA);
sort* seqB = m.mk_sort(m_family_id, SEQ_SORT, 1, &paramB);
parameter paramSA(seqA);
sort* reA = m.mk_sort(m_family_id, RE_SORT, 1, &paramSA);
sort* reT = m.mk_sort(m_family_id, RE_SORT, 1, &paramS);
sort* boolT = m.mk_bool_sort();
sort* intT = arith_util(m).mk_int();
sort* predA = array_util(m).mk_array_sort(A, boolT);
sort* predA = autil.mk_array_sort(A, boolT);
sort* arrAB = autil.mk_array_sort(A, B);
sort* arrIAB = autil.mk_array_sort(intT, A, B);
sort* arrBAB = autil.mk_array_sort(B, A, B);
sort* arrIBAB = autil.mk_array_sort(intT, B, A, B);
sort* seqAseqAseqA[3] = { seqA, seqA, seqA };
sort* seqAreAseqA[3] = { seqA, reA, seqA };
sort* seqAseqA[2] = { seqA, seqA };
Expand All @@ -209,6 +217,11 @@ void seq_decl_plugin::init() {
sort* str2TintT[3] = { strT, strT, intT };
sort* seqAintT[2] = { seqA, intT };
sort* seq3A[3] = { seqA, seqA, seqA };
sort* arrABseqA[2] = { arrAB, seqA };
sort* arrIABintTseqA[3] = { arrIAB, intT, seqA };
sort* arrBAB_BseqA[3] = { arrBAB, B,seqA };
sort* arrIBABintTBseqA[4] = { arrIBAB, intT, B, seqA };

m_sigs.resize(LAST_SEQ_OP);
// TBD: have (par ..) construct and load parameterized signature from preamble.
m_sigs[OP_SEQ_UNIT] = alloc(psig, m, "seq.unit", 1, 1, &A, seqA);
Expand All @@ -226,6 +239,10 @@ void seq_decl_plugin::init() {
m_sigs[OP_SEQ_NTH_I] = alloc(psig, m, "seq.nth_i", 1, 2, seqAintT, A);
m_sigs[OP_SEQ_NTH_U] = alloc(psig, m, "seq.nth_u", 1, 2, seqAintT, A);
m_sigs[OP_SEQ_LENGTH] = alloc(psig, m, "seq.len", 1, 1, &seqA, intT);
m_sigs[OP_SEQ_MAP] = alloc(psig, m, "seq.map", 2, 2, arrABseqA, seqB);
m_sigs[OP_SEQ_MAPI] = alloc(psig, m, "seq.mapi", 2, 3, arrIABintTseqA, seqB);
m_sigs[OP_SEQ_FOLDL] = alloc(psig, m, "seq.fold_left", 2, 3, arrBAB_BseqA, B);
m_sigs[OP_SEQ_FOLDLI] = alloc(psig, m, "seq.fold_leftli", 2, 4, arrIBABintTBseqA, B);
m_sigs[OP_RE_PLUS] = alloc(psig, m, "re.+", 1, 1, &reA, reA);
m_sigs[OP_RE_STAR] = alloc(psig, m, "re.*", 1, 1, &reA, reA);
m_sigs[OP_RE_OPTION] = alloc(psig, m, "re.opt", 1, 1, &reA, reA);
Expand Down Expand Up @@ -582,6 +599,12 @@ func_decl* seq_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p
case _OP_STRING_STRCTN:
return mk_str_fun(k, arity, domain, range, OP_SEQ_CONTAINS);

case OP_SEQ_MAP:
case OP_SEQ_MAPI:
case OP_SEQ_FOLDL:
case OP_SEQ_FOLDLI:
return mk_str_fun(k, arity, domain, range, k);

case OP_SEQ_TO_RE:
m_has_re = true;
return mk_seq_fun(k, arity, domain, range, _OP_STRING_TO_REGEXP);
Expand Down
17 changes: 16 additions & 1 deletion src/ast/seq_decl_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ enum seq_op_kind {
OP_SEQ_REPLACE_RE_ALL, // Seq -> RegEx -> Seq -> Seq
OP_SEQ_REPLACE_RE, // Seq -> RegEx -> Seq -> Seq
OP_SEQ_REPLACE_ALL, // Seq -> Seq -> Seq -> Seq

OP_SEQ_MAP, // Array[A,B] -> Seq[A] -> Seq[B]
OP_SEQ_MAPI, // Array[Int,A,B] -> Int -> Seq[A] -> Seq[B]
OP_SEQ_FOLDL, // Array[B,A,B] -> B -> Seq[A] -> B
OP_SEQ_FOLDLI, // Array[Int,B,A,B] -> Int -> B -> Seq[A] -> B

OP_RE_PLUS,
OP_RE_STAR,
OP_RE_OPTION,
Expand Down Expand Up @@ -296,6 +300,10 @@ class seq_util {
app* mk_nth_i(expr* s, expr* i) const { expr* es[2] = { s, i }; return m.mk_app(m_fid, OP_SEQ_NTH_I, 2, es); }
app* mk_nth_u(expr* s, expr* i) const { expr* es[2] = { s, i }; return m.mk_app(m_fid, OP_SEQ_NTH_U, 2, es); }
app* mk_nth_c(expr* s, unsigned i) const;
// Term builders for the higher-order sequence operators. Each packs its arguments
// into a local array and creates the application of the corresponding OP_SEQ_* kind.
// Argument order: function array first, then (for the indexed variants) the start
// index, then (for folds) the initial accumulator, and finally the sequence.
app* mk_map(expr* f, expr* s) const { expr* es[2] = { f, s }; return m.mk_app(m_fid, OP_SEQ_MAP, 2, es); }
app* mk_mapi(expr* f, expr* i, expr* s) const { expr* es[3] = { f, i, s }; return m.mk_app(m_fid, OP_SEQ_MAPI, 3, es); }
app* mk_foldl(expr* f, expr* b, expr* s) const { expr* es[3] = { f, b, s }; return m.mk_app(m_fid, OP_SEQ_FOLDL, 3, es); }
app* mk_foldli(expr* f, expr* i, expr* b, expr* s) const { expr* es[4] = { f, i, b, s }; return m.mk_app(m_fid, OP_SEQ_FOLDLI, 4, es); }

app* mk_substr(expr* a, expr* b, expr* c) const { expr* es[3] = { a, b, c }; return m.mk_app(m_fid, OP_SEQ_EXTRACT, 3, es); }
app* mk_contains(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONTAINS, 2, es); }
Expand Down Expand Up @@ -333,6 +341,10 @@ class seq_util {
}
bool is_concat(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_CONCAT); }
bool is_length(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_LENGTH); }
// Recognizers for the higher-order sequence operators: true iff n is an application
// of the corresponding OP_SEQ_* kind in this family (checked via is_app_of).
bool is_map(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_MAP); }
bool is_mapi(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_MAPI); }
bool is_foldl(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_FOLDL); }
bool is_foldli(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_FOLDLI); }
bool is_extract(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_EXTRACT); }
bool is_contains(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_CONTAINS); }
bool is_at(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_AT); }
Expand Down Expand Up @@ -384,6 +396,9 @@ class seq_util {
MATCH_BINARY(is_nth_u);
MATCH_BINARY(is_index);
MATCH_TERNARY(is_index);
MATCH_BINARY(is_map);
MATCH_TERNARY(is_mapi);
MATCH_TERNARY(is_foldl);
MATCH_BINARY(is_last_index);
MATCH_TERNARY(is_replace);
MATCH_TERNARY(is_replace_re);
Expand Down

0 comments on commit 87d2a3b

Please sign in to comment.