Skip to content

Commit

Permalink
Merge pull request #96 from JuliaSymbolics/fix-binarization
Browse files Browse the repository at this point in the history
Fix binarization type inconsistency bug issue #95
  • Loading branch information
0x0f0f0f authored Jan 24, 2022
2 parents dc38591 + 82915f6 commit f316739
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/ematch_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module EMatchCompiler

using AutoHashEquals
using TermInterface
using Metatheory: alwaystrue, binarize
using Metatheory: alwaystrue, binarize, binarize_rec
using Metatheory.Patterns

abstract type Instruction end
Expand Down Expand Up @@ -103,8 +103,6 @@ export Fail


function compile_ground!(reg, p::PatTerm, prog)
p = binarize(p)

if haskey(prog.ground_terms, p)
# push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p]))
return nothing
Expand Down Expand Up @@ -146,8 +144,6 @@ end
# =============================================

function compile_pat!(reg, p::PatTerm, prog)
p = binarize(p)

if haskey(prog.ground_terms, p)
push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p]))
return nothing
Expand Down Expand Up @@ -203,15 +199,15 @@ function compile_pat!(reg, p::Any, prog)
push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p]))
return nothing
end
@error "This shouldn't be printed. Report an issue for ematching literals"
@error "This shouldn't be printed. Report an issue for ematching literals" p
end


#= ====================================================================================== =#

# EXPECTS INDEXES OF PATTERN VARIABLES TO BE ALREADY POPULATED
function compile_pat(p)
p = binarize(p)
p = binarize_rec(p)
pvars = patvars(p)
nvars = length(pvars)

Expand Down
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@ function binarize(e::T) where {T}
return e
end

"""
Recursive version of binarize
"""
function binarize_rec(e::T) where {T}
!istree(e) && return e
head = exprhead(e)
op = operation(e)
args = map(binarize_rec, arguments(e))
meta = metadata(e)
if head == :call
if op binarize_ops && arity(e) > 2
return foldl((x,y) -> similarterm(T, op, [x,y], symtype(e); metadata=meta, exprhead=head), args)
end
end
return similarterm(T, op, args, symtype(e); metadata=meta, exprhead=head)
end



const binarize_ops = [:(+), :(*), (+), (*)]

function cleanast(e::Expr)
Expand Down

0 comments on commit f316739

Please sign in to comment.