Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kbd small finch logic changes #636

Merged
merged 8 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/FinchLogic/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ function LogicNode(kind::LogicNodeKind, args::Vector)
return LogicNode(kind, args[1], Any, LogicNode[])
elseif kind === deferred && length(args) == 2
return LogicNode(kind, args[1], args[2], LogicNode[])
elseif kind === deferred && length(args) == 3
return LogicNode(kind, (args[1], args[3]), args[2], LogicNode[])
else
args = LogicNode_concatenate_args(args)
if (kind === table && length(args) >= 1) ||
Expand All @@ -231,7 +233,8 @@ end
function Base.getproperty(node::LogicNode, sym::Symbol)
if sym === :kind || sym === :val || sym === :type || sym === :children
return Base.getfield(node, sym)
elseif node.kind === deferred && sym === :ex node.val
elseif node.kind === deferred && sym === :ex node.val isa Tuple ? node.val[1] : node.val
elseif node.kind === deferred && sym === :imm node.val[2]
elseif node.kind === field && sym === :name node.val::Symbol
elseif node.kind === alias && sym === :name node.val::Symbol
elseif node.kind === table && sym === :tns node.children[1]
Expand Down
2 changes: 1 addition & 1 deletion src/FinchNotation/instances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ Base.:(==)(a::VariableInstance, b::VariableInstance) = false
Base.:(==)(a::VariableInstance{tag}, b::VariableInstance{tag}) where {tag} = true
function Base.:(==)(a::FinchNodeInstance, b::FinchNodeInstance)
return operation(a) == operation(b) && arguments(a) == arguments(b)
end
end
73 changes: 73 additions & 0 deletions src/FinchNotation/syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,76 @@ function display_statement(io, mime, node::Union{FinchNode, FinchNodeInstance},
error("unimplemented")
end
end

finch_unparse_program(ctx, node) = finch_unparse_program(ctx, finch_leaf(node))
function finch_unparse_program(ctx, node::Union{FinchNode, FinchNodeInstance})
if operation(node) === value
node.val
elseif operation(node) === literal
node.val
elseif operation(node) === index
node.name
elseif operation(node) === variable
node.name
elseif operation(node) === cached
finch_unparse_program(ctx, node.arg)
elseif operation(node) === tag
@assert operation(node.var) === variable
node.var.name
elseif operation(node) === virtual
if node.val == dimless
:_
else
ctx(node)
end
elseif operation(node) === access
tns = finch_unparse_program(ctx, node.tns)
idxs = map(x -> finch_unparse_program(ctx, x), node.idxs)
:($tns[$(idxs...)])
elseif operation(node) === call
op = finch_unparse_program(ctx, node.op)
args = map(x -> finch_unparse_program(ctx, x), node.args)
:($op($(args...)))
elseif operation(node) === loop
idx = finch_unparse_program(ctx, node.idx)
ext = finch_unparse_program(ctx, node.ext)
body = finch_unparse_program(ctx, node.body)
:(for $idx = $ext; $body end)
elseif operation(node) === define
lhs = finch_unparse_program(ctx, node.lhs)
rhs = finch_unparse_program(ctx, node.rhs)
body = finch_unparse_program(ctx, node.body)
:(let $lhs = $rhs; $body end)
elseif operation(node) === sieve
cond = finch_unparse_program(ctx, node.cond)
body = finch_unparse_program(ctx, node.body)
:(if $cond; $body end)
elseif operation(node) === assign
lhs = finch_unparse_program(ctx, node.lhs)
op = finch_unparse_program(ctx, node.op)
rhs = finch_unparse_program(ctx, node.rhs)
if haskey(incs, op)
Expr(incs[op], lhs, rhs)
else
:($lhs <<$op>>= $rhs)
end
elseif operation(node) === declare
tns = finch_unparse_program(ctx, node.tns)
init = finch_unparse_program(ctx, node.init)
:($tns .= $init)
elseif operation(node) === freeze
tns = finch_unparse_program(ctx, node.tns)
:(@freeze($tns))
elseif operation(node) === thaw
tns = finch_unparse_program(ctx, node.tns)
:(@thaw($tns))
elseif operation(node) === yieldbind
args = map(x -> finch_unparse_program(ctx, x), node.args)
:(return($(args...)))
elseif operation(node) === block
bodies = map(x -> finch_unparse_program(ctx, x), node.bodies)
Expr(:block, bodies...)
else
error("unimplemented")
end
end
6 changes: 3 additions & 3 deletions src/scheduler/LogicExecutor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ is given as input to the program.
"""
function defer_tables(ex, node::LogicNode)
if @capture node table(~tns::isimmediate, ~idxs...)
table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx)
table(deferred(:($ex.tns.val), typeof(tns.val), tns.val), map(enumerate(node.idxs)) do (i, idx)
defer_tables(:($ex.idxs[$i]), idx)
end)
elseif istree(node)
Expand All @@ -29,7 +29,7 @@ function cache_deferred!(ctx, root::LogicNode)
get!(seen, node.val) do
var = freshen(ctx, :V)
push_preamble!(ctx, :($var = $(node.ex)::$(node.type)))
deferred(var, node.type)
deferred(var, node.type, node.imm)
end
end))(root)
end
Expand Down Expand Up @@ -90,4 +90,4 @@ end

function (ctx::LogicExecutorCode)(prgm)
return logic_executor_code(ctx.ctx, prgm)
end
end
15 changes: 14 additions & 1 deletion test/test_interface.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
using Finch: AsArray
using Finch: AsArray, JuliaContext
using Finch.FinchNotation: finch_unparse_program, @finch_program_instance

@testset "interface" begin

@info "Testing Finch Interface"

@testset "finch_unparse" begin
prgm = @finch_program quote
A .= 0
for i = _
A[i] += 1
end
end
@test prgm.val == @finch_program $(finch_unparse_program(JuliaContext(), prgm))
end

#https://github.com/finch-tensor/Finch.jl/issues/383
let
A = [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]
Expand Down Expand Up @@ -814,4 +825,6 @@ using Finch: AsArray
B = dropfills!(swizzle(A, 2, 1), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0])
@test B == swizzle(Tensor(Dense{Int64}(SparseList{Int64}(Element{0.0, Float64, Int64}([4.4, 1.1, 2.2, 5.5, 3.3]), 3, [1, 2, 3, 5, 6], [3, 1, 1, 3, 1]), 4)), 2, 1)
end


end
Loading