From 0761018225eebdf110b84ab0c99e3a133dece958 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Fri, 8 Nov 2024 16:31:26 -0500 Subject: [PATCH 1/7] quick stopping point --- src/FinchNotation/instances.jl | 5 ++- src/FinchNotation/nodes.jl | 1 + src/FinchNotation/syntax.jl | 73 ++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/FinchNotation/instances.jl b/src/FinchNotation/instances.jl index 108b06e1d..96cb93837 100644 --- a/src/FinchNotation/instances.jl +++ b/src/FinchNotation/instances.jl @@ -111,12 +111,15 @@ function Base.show(io::IO, node::FinchNodeInstance) end function Base.show(io::IO, mime::MIME"text/plain", node::FinchNodeInstance) - print(io, "Finch program instance: ") + print(io, "Finch program instance") + show(io, Finch.striplines(finch_unparse_program(Finch.FinchCompiler(), node))) + #= if isstateful(node) display_statement(io, mime, node, 0) else display_expression(io, mime, node) end + =# end Base.:(==)(a::VariableInstance, b::VariableInstance) = false diff --git a/src/FinchNotation/nodes.jl b/src/FinchNotation/nodes.jl index 71bf4f1fb..ce1e34d39 100644 --- a/src/FinchNotation/nodes.jl +++ b/src/FinchNotation/nodes.jl @@ -329,6 +329,7 @@ end function Base.show(io::IO, mime::MIME"text/plain", node::FinchNode) print(io, "Finch program: ") + show(io, finch_unparse_program(JuliaContext(), node)) if isstateful(node) display_statement(io, mime, node, 0) else diff --git a/src/FinchNotation/syntax.jl b/src/FinchNotation/syntax.jl index f529f66b8..b3c757b7c 100644 --- a/src/FinchNotation/syntax.jl +++ b/src/FinchNotation/syntax.jl @@ -438,3 +438,76 @@ function display_statement(io, mime, node::Union{FinchNode, FinchNodeInstance}, error("unimplemented") end end + +finch_unparse_program(ctx, node) = finch_unparse_program(ctx, finch_leaf(node)) +function finch_unparse_program(ctx, node::Union{FinchNode, FinchNodeInstance}) + if operation(node) === value + node.val + elseif operation(node) === literal + node.val + elseif operation(node) === index + node.name + elseif operation(node) === variable + node.name + elseif operation(node) === cached + finch_unparse_program(ctx, node.arg) + elseif operation(node) === tag + @assert operation(node.var) === variable + node.var.name + elseif operation(node) === virtual + if node.val == dimless + :_ + else + ctx(node) + end + elseif operation(node) === access + tns = finch_unparse_program(ctx, node.tns) + idxs = map(x -> finch_unparse_program(ctx, x), node.idxs) + :($tns[$(idxs...)]) + elseif operation(node) === call + op = finch_unparse_program(ctx, node.op) + args = map(x -> finch_unparse_program(ctx, x), node.args) + :($op($(args...))) + elseif operation(node) === loop + idx = finch_unparse_program(ctx, node.idx) + ext = finch_unparse_program(ctx, node.ext) + body = finch_unparse_program(ctx, node.body) + :(for $idx = $ext; $body end) + elseif operation(node) === define + lhs = finch_unparse_program(ctx, node.lhs) + rhs = finch_unparse_program(ctx, node.rhs) + body = finch_unparse_program(ctx, node.body) + :(let $lhs = $rhs; $body end) + elseif operation(node) === sieve + cond = finch_unparse_program(ctx, node.cond) + body = finch_unparse_program(ctx, node.body) + :(if $cond; $body end) + elseif operation(node) === assign + lhs = finch_unparse_program(ctx, node.lhs) + op = finch_unparse_program(ctx, node.op) + rhs = finch_unparse_program(ctx, node.rhs) + if haskey(incs, op) + Expr(incs[op], lhs, rhs) + else + :($lhs <<$op>>= $rhs) + end + elseif operation(node) === declare + tns = finch_unparse_program(ctx, node.tns) + init = finch_unparse_program(ctx, node.init) + :($tns .= $init) + elseif operation(node) === freeze + tns = finch_unparse_program(ctx, node.tns) + :(@freeze($tns)) + elseif operation(node) === thaw + tns = finch_unparse_program(ctx, node.tns) + :(@thaw($tns)) + elseif operation(node) === yieldbind + args = map(x -> finch_unparse_program(ctx, x), node.args) + :(return($(args...))) + elseif operation(node) === block + bodies = map(x -> finch_unparse_program(ctx, x), node.bodies) + Expr(:block, bodies...) + else + error("unimplemented") + end +end \ No newline at end of file From 2da118f316c9a92b9b0be66b24cdf9b9d39db8df Mon Sep 17 00:00:00 2001 From: kylebd99 Date: Tue, 12 Nov 2024 11:50:01 -0800 Subject: [PATCH 2/7] Add tensor instance to deferred --- src/FinchLogic/nodes.jl | 5 ++++- src/scheduler/LogicExecutor.jl | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/FinchLogic/nodes.jl b/src/FinchLogic/nodes.jl index 28529198d..279ead8f8 100644 --- a/src/FinchLogic/nodes.jl +++ b/src/FinchLogic/nodes.jl @@ -205,6 +205,8 @@ function LogicNode(kind::LogicNodeKind, args::Vector) return LogicNode(kind, args[1], Any, LogicNode[]) elseif kind === deferred && length(args) == 2 return LogicNode(kind, args[1], args[2], LogicNode[]) + elseif kind === deferred && length(args) == 3 + return LogicNode(kind, (args[1], args[3]), args[2], LogicNode[]) else args = LogicNode_concatenate_args(args) if (kind === table && length(args) >= 1) || @@ -231,7 +233,8 @@ end function Base.getproperty(node::LogicNode, sym::Symbol) if sym === :kind || sym === :val || sym === :type || sym === :children return Base.getfield(node, sym) - elseif node.kind === deferred && sym === :ex node.val + elseif node.kind === deferred && sym === :ex node.val isa Tuple ? node.val[1] : node.val[2] + elseif node.kind === deferred && sym === :imm node.val[2] elseif node.kind === field && sym === :name node.val::Symbol elseif node.kind === alias && sym === :name node.val::Symbol elseif node.kind === table && sym === :tns node.children[1] diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 9ba8f064c..434c59aaf 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -6,7 +6,7 @@ is given as input to the program. """ function defer_tables(ex, node::LogicNode) if @capture node table(~tns::isimmediate, ~idxs...) - table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx) + table(deferred(:($ex.tns.val), typeof(tns.val), tns.val), map(enumerate(node.idxs)) do (i, idx) defer_tables(:($ex.idxs[$i]), idx) end) elseif istree(node) @@ -29,7 +29,7 @@ function cache_deferred!(ctx, root::LogicNode) get!(seen, node.val) do var = freshen(ctx, :V) push_preamble!(ctx, :($var = $(node.ex)::$(node.type))) - deferred(var, node.type) + deferred(var, node.type, node.imm) end end))(root) end @@ -90,4 +90,4 @@ end function (ctx::LogicExecutorCode)(prgm) return logic_executor_code(ctx.ctx, prgm) -end \ No newline at end of file +end From c6d4f8c9d79b688bebb6eceb44350e5a3d90d01c Mon Sep 17 00:00:00 2001 From: kylebd99 Date: Tue, 12 Nov 2024 12:27:17 -0800 Subject: [PATCH 3/7] fix .ex property --- src/FinchLogic/nodes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FinchLogic/nodes.jl b/src/FinchLogic/nodes.jl index 279ead8f8..88ed5246e 100644 --- a/src/FinchLogic/nodes.jl +++ b/src/FinchLogic/nodes.jl @@ -233,7 +233,7 @@ end function Base.getproperty(node::LogicNode, sym::Symbol) if sym === :kind || sym === :val || sym === :type || sym === :children return Base.getfield(node, sym) - elseif node.kind === deferred && sym === :ex node.val isa Tuple ? node.val[1] : node.val[2] + elseif node.kind === deferred && sym === :ex node.val isa Tuple ? node.val[1] : node.val elseif node.kind === deferred && sym === :imm node.val[2] elseif node.kind === field && sym === :name node.val::Symbol elseif node.kind === alias && sym === :name node.val::Symbol From 522da3f6fe33cde072cd2937a3b57ff89107cee4 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 13 Nov 2024 11:38:30 -0500 Subject: [PATCH 4/7] Update instances.jl --- src/FinchNotation/instances.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/FinchNotation/instances.jl b/src/FinchNotation/instances.jl index 96cb93837..6dabc00e6 100644 --- a/src/FinchNotation/instances.jl +++ b/src/FinchNotation/instances.jl @@ -112,18 +112,15 @@ end function Base.show(io::IO, mime::MIME"text/plain", node::FinchNodeInstance) print(io, "Finch program instance") - show(io, Finch.striplines(finch_unparse_program(Finch.FinchCompiler(), node))) - #= if isstateful(node) display_statement(io, mime, node, 0) else display_expression(io, mime, node) end - =# end Base.:(==)(a::VariableInstance, b::VariableInstance) = false Base.:(==)(a::VariableInstance{tag}, b::VariableInstance{tag}) where {tag} = true function Base.:(==)(a::FinchNodeInstance, b::FinchNodeInstance) return operation(a) == operation(b) && arguments(a) == arguments(b) -end \ No newline at end of file +end From d74b033569b18e7c795b2b5d8fc868b375606245 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 13 Nov 2024 11:38:49 -0500 Subject: [PATCH 5/7] Update instances.jl --- src/FinchNotation/instances.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FinchNotation/instances.jl b/src/FinchNotation/instances.jl index 6dabc00e6..fbc4b17de 100644 --- a/src/FinchNotation/instances.jl +++ b/src/FinchNotation/instances.jl @@ -111,7 +111,7 @@ function Base.show(io::IO, node::FinchNodeInstance) end function Base.show(io::IO, mime::MIME"text/plain", node::FinchNodeInstance) - print(io, "Finch program instance") + print(io, "Finch program instance: ") if isstateful(node) display_statement(io, mime, node, 0) else From 9df9cfbc49f3a690097d4280a405088e76827a92 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 13 Nov 2024 11:39:11 -0500 Subject: [PATCH 6/7] Update nodes.jl --- src/FinchNotation/nodes.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/FinchNotation/nodes.jl b/src/FinchNotation/nodes.jl index ce1e34d39..71bf4f1fb 100644 --- a/src/FinchNotation/nodes.jl +++ b/src/FinchNotation/nodes.jl @@ -329,7 +329,6 @@ end function Base.show(io::IO, mime::MIME"text/plain", node::FinchNode) print(io, "Finch program: ") - show(io, finch_unparse_program(JuliaContext(), node)) if isstateful(node) display_statement(io, mime, node, 0) else From a5d6a6a8c255815d24fb77683d0a406289f8179b Mon Sep 17 00:00:00 2001 From: = <=> Date: Thu, 14 Nov 2024 10:43:17 -0800 Subject: [PATCH 7/7] add finch_unparse test --- test/test_interface.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_interface.jl b/test/test_interface.jl index 4c46ce97f..3d9a0ba21 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -1,9 +1,20 @@ -using Finch: AsArray +using Finch: AsArray, JuliaContext +using Finch.FinchNotation: finch_unparse_program, @finch_program_instance @testset "interface" begin @info "Testing Finch Interface" + @testset "finch_unparse" begin + prgm = @finch_program quote + A .= 0 + for i = _ + A[i] += 1 + end + end + @test prgm.val == @finch_program $(finch_unparse_program(JuliaContext(), prgm)) + end + #https://github.com/finch-tensor/Finch.jl/issues/383 let A = [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0] @@ -814,4 +825,6 @@ using Finch: AsArray B = dropfills!(swizzle(A, 2, 1), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) @test B == swizzle(Tensor(Dense{Int64}(SparseList{Int64}(Element{0.0, Float64, Int64}([4.4, 1.1, 2.2, 5.5, 3.3]), 3, [1, 2, 3, 5, 6], [3, 1, 1, 3, 1]), 4)), 2, 1) end + + end