From 18742a23287b280b9555f0812b7ed26ea8aec31e Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 09:56:20 -0400 Subject: [PATCH 01/10] add test --- test/test_interface.jl | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/test/test_interface.jl b/test/test_interface.jl index 9b63c23c5..a0af20399 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -3,14 +3,38 @@ using Finch: AsArray @testset "interface" begin @info "Testing Finch Interface" + #https://github.com/willow-ahrens/Finch.jl/issues/535 + let + LEN = 10; + a_raw = rand(LEN, LEN - 5) * 10; + b_raw = rand(LEN, LEN - 5) * 10; + c_raw = rand(LEN, LEN) * 10; + + a = lazy(swizzle(Tensor(a_raw), 1, 2)); + b = lazy(swizzle(Tensor(b_raw), 1, 2)); + c = lazy(swizzle(Tensor(c_raw), 1, 2)); + + ref = reshape(c_raw, 10, 10, 1) .* reshape(a_raw, 10, 1, 5) .* reshape(b_raw, 1, 10, 5); + + plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :]; + @test compute(plan) == ref + + plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]); + @test compute(plan) == ref + end + #https://github.com/willow-ahrens/Finch.jl/issues/536 - A = [1 2; 3 4] - swizzle(lazy(A), 2, 1) == permutedims(A) + let + A = [1 2; 3 4] + swizzle(lazy(A), 2, 1) == permutedims(A) + end #https://github.com/willow-ahrens/Finch.jl/issues/530 - A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), zeros(3, 3, 3)) - A_sw = swizzle(A_tns, 2, 3, 1) - A_tns == A_sw #fails + let + A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), zeros(3, 3, 3)) + A_sw = swizzle(A_tns, 2, 3, 1) + A_tns == A_sw #fails + end #https://github.com/willow-ahrens/Finch.jl/issues/524 let From cb0e9f821b7f9ed5d6642d01559f86ca81de30bb Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 10:06:04 -0400 Subject: [PATCH 02/10] add resolves --- benchmark/runbenchmarks.jl | 3 ++- benchmark/runjudge.jl | 3 ++- test/Project.toml | 1 + test/runtests.jl | 2 ++ 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/benchmark/runbenchmarks.jl b/benchmark/runbenchmarks.jl index 3105a179b..21facd8e4 100755 --- a/benchmark/runbenchmarks.jl +++ b/benchmark/runbenchmarks.jl @@ -1,8 +1,9 @@ #!/usr/bin/env julia if abspath(PROGRAM_FILE) == @__FILE__ using Pkg - Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) Pkg.activate(@__DIR__) + Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) + Pkg.resolve() Pkg.instantiate() end diff --git a/benchmark/runjudge.jl b/benchmark/runjudge.jl index 21dd77482..7cb7b82d0 100755 --- a/benchmark/runjudge.jl +++ b/benchmark/runjudge.jl @@ -1,8 +1,9 @@ #!/usr/bin/env julia if abspath(PROGRAM_FILE) == @__FILE__ using Pkg - Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) Pkg.activate(@__DIR__) + Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) + Pkg.resolve() Pkg.instantiate() end diff --git a/test/Project.toml b/test/Project.toml index 94dfae417..164deb9b4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ CIndices = "5a98b6c4-18fa-405d-92b3-8277d93fed36" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Finch = "9177782c-1635-4eb9-9bfb-d9dfa25e6bce" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" diff --git a/test/runtests.jl b/test/runtests.jl index 44218ffb7..866ed3f69 100755 --- a/test/runtests.jl +++ 
b/test/runtests.jl @@ -2,6 +2,8 @@ if abspath(PROGRAM_FILE) == @__FILE__ using Pkg Pkg.activate(@__DIR__) + Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) + Pkg.resolve() Pkg.instantiate() end From 5f6535ad2cd93f61cda02165c714b243110c89f3 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 10:18:05 -0400 Subject: [PATCH 03/10] one-liner --- src/interface/lazy.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index d9df9a170..d603e223b 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -183,7 +183,7 @@ function broadcast_to_query(tns::LazyTensor{T, N}, idxs) where {T, N} end function broadcast_to_extrude(bc::Broadcast.Broadcasted, n) - any(map(arg -> broadcast_to_extrude(arg, n), bc.args)) + all(map(arg -> broadcast_to_extrude(arg, n), bc.args)) end function broadcast_to_extrude(tns::LazyTensor, n) From efcfbc336c10affd4081f2ba44d0d7b1097c0476 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 10:19:18 -0400 Subject: [PATCH 04/10] I hope all this works --- docs/fix.jl | 2 ++ docs/make.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/fix.jl b/docs/fix.jl index 95ec22903..12ffbb0b0 100755 --- a/docs/fix.jl +++ b/docs/fix.jl @@ -2,6 +2,8 @@ if abspath(PROGRAM_FILE) == @__FILE__ using Pkg Pkg.activate(@__DIR__) + Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) + Pkg.resolve() Pkg.instantiate() end diff --git a/docs/make.jl b/docs/make.jl index 8415645d9..af9a5f574 100755 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,6 +2,8 @@ if abspath(PROGRAM_FILE) == @__FILE__ using Pkg Pkg.activate(@__DIR__) + Pkg.develop(PackageSpec(path = joinpath(@__DIR__, ".."))) + Pkg.resolve() Pkg.instantiate() end From 4a4e78d291c6ef0e3f64ec26a59f4da7bddd5dac Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 10:55:59 -0400 Subject: [PATCH 05/10] cleanup --- src/Finch.jl | 2 +- src/interface/eager.jl | 8 ++++++++ src/interface/lazy.jl | 20 ++++++++++++++------ src/interface/traits.jl | 26 +++++++++++++------------- src/scheduler/optimize.jl | 2 +- 5 files changed, 37 insertions(+), 21 deletions(-) diff --git a/src/Finch.jl b/src/Finch.jl index 1cd0d11d6..63571e800 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -197,7 +197,7 @@ end )) if @load_preference("precompile", true) - @info "Running enhanced precompilation... (to disable, run `using Preferences; Preferences.set_preference(\"Finch\", \"precompile\"=>false)`" + @info "Running enhanced precompilation... (to disable, run `using Preferences; Preferences.set_preferences!(\"Finch\", \"precompile\"=>false)`" include("../test/precompile.jl") end end diff --git a/src/interface/eager.jl b/src/interface/eager.jl index a590d51ae..133929257 100644 --- a/src/interface/eager.jl +++ b/src/interface/eager.jl @@ -115,3 +115,11 @@ function LinearAlgebra.norm(arr::AbstractTensorOrBroadcast, p::Real = 2) return root(sum(broadcasted(power, broadcasted(norm, arr, p), p))) end end + +""" + expanddims(arr::AbstractTensor, dims) + +Expand the dimensions of an array by inserting a new singleton axis or axes that +will appear at the `dims` position in the expanded array shape. 
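+
+For example (a usage sketch): if `A` is a 3×4 tensor, `expanddims(A, 2)` returns a tensor of
+size `(3, 1, 4)`, with the new singleton axis in dimension 2.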
+""" +expanddims(arr::AbstractTensor, dims) = compute(expanddims(lazy(arr), dims)) \ No newline at end of file diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index d603e223b..0b2fbee52 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -26,12 +26,20 @@ function Base.getindex(arr::LazyTensor{T, N}, idxs::Vararg{Union{Nothing, Colon} if length(idxs) - count(isnothing, idxs) != N throw(ArgumentError("Cannot index a lazy tensor with more or fewer `:` dims than it had original dims.")) end - fields = [field(gensym(:i)) for _ in 1:length(idxs)] - original_fields = fields[findall(!isnothing, idxs)] - data = reorder(relabel(arr.data, original_fields...), fields...) - extrude = [true for _ in 1:length(idxs)] - extrude[findall(!isnothing, idxs)] .= arr.extrude - return LazyTensor{T}(data, (extrude...,), arr.default) + return expanddims(arr, findall(isnothing, idxs)) +end + +function expanddims(arr::LazyTensor{T}, dims) where {T} + @assert allunique(dims) + antidims = setdiff(1:ndims(arr) + length(dims), dims) + @assert length(antidims) == ndims(arr) + idxs_1 = [field(gensym(:i)) for _ in 1:ndims(arr)] + idxs_2 = [field(gensym(:i)) for _ in 1:ndims(arr) + length(dims)] + idxs_2[antidims] .= idxs_1 + data_2 = reorder(relabel(arr.data, idxs_1...), idxs_2...) + extrude_2 = [false for _ in 1:ndims(arr) + length(dims)] + extrude_2[antidims] .= arr.extrude + return LazyTensor{T, ndims(arr) + length(dims)}(data_2, tuple(extrude_2...), default(arr)) end function identify(data) diff --git a/src/interface/traits.jl b/src/interface/traits.jl index f3a468a72..eb5412573 100644 --- a/src/interface/traits.jl +++ b/src/interface/traits.jl @@ -130,22 +130,22 @@ Return a storage trait object representing the result of mapping `f` over storage traits `args`. Assumes representation is collapsed. """ function map_rep(f, args...) - map_rep_def(f, map(arg -> pad_data_rep(arg, maximum(ndims, args)), args)) + map_rep_def(f, map(arg -> paddims_rep(arg, maximum(ndims, args)), args)) end -pad_data_rep(rep, n) = ndims(rep) < n ? pad_data_rep(ExtrudeData(rep), n) : rep +paddims_rep(rep, n) = ndims(rep) < n ? paddims_rep(ExtrudeData(rep), n) : rep """ - extrude_rep(tns, dims) + expanddims_rep(tns, dims) Expand the representation of `tns` to the dimensions `dims`, which must have length(ndims(tns)) and be in ascending order. """ -extrude_rep(tns, dims) = extrude_rep_def(tns, reverse(dims)...) -extrude_rep_def(tns) = tns -extrude_rep_def(tns::HollowData, dims...) = HollowData(extrude_rep_def(tns.lvl, dims...)) -extrude_rep_def(tns::SparseData, dim, dims...) = SparseData(pad_data_rep(extrude_rep_def(tns.lvl, dims...), dim - 1)) -extrude_rep_def(tns::RepeatData, dim, dims...) = RepeatData(pad_data_rep(extrude_rep_def(tns.lvl, dims...), dim - 1)) -extrude_rep_def(tns::DenseData, dim, dims...) = DenseData(pad_data_rep(extrude_rep_def(tns.lvl, dims...), dim - 1)) -extrude_rep_def(tns::ExtrudeData, dim, dims...) = ExtrudeData(pad_data_rep(extrude_rep_def(tns.lvl, dims...), dim - 1)) +expanddims_rep(tns, dims) = expanddims_rep_def(tns, reverse(dims)...) +expanddims_rep_def(tns) = tns +expanddims_rep_def(tns::HollowData, dims...) = HollowData(expanddims_rep_def(tns.lvl, dims...)) +expanddims_rep_def(tns::SparseData, dim, dims...) = SparseData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) +expanddims_rep_def(tns::RepeatData, dim, dims...) = RepeatData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) +expanddims_rep_def(tns::DenseData, dim, dims...) 
= DenseData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) +expanddims_rep_def(tns::ExtrudeData, dim, dims...) = ExtrudeData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) struct MapRepExtrudeStyle end struct MapRepSparseStyle end @@ -309,10 +309,10 @@ function permutedims_rep(tns, perm) end n += 1 end - src = extrude_rep(tns, src_dims) + src = expanddims_rep(tns, src_dims) for mask_dims in diags - mask = extrude_rep(DenseData(SparseData(ElementData(false, Bool))), mask_dims) - src = map_rep(filterop(default(src)), pad_data_rep(mask, ndims(src)), src) + mask = expanddims_rep(DenseData(SparseData(ElementData(false, Bool))), mask_dims) + src = map_rep(filterop(default(src)), paddims_rep(mask, ndims(src)), src) end aggregate_rep(initwrite(default(tns)), default(tns), src, setdiff(src_dims, dst_dims)) end diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index 4b5f5cf35..28d34c5df 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -350,7 +350,7 @@ function (ctx::SuitableRep)(ex) rep = permutedims_rep(rep, perm) dims = findall(idx -> idx in idxs, ex.idxs) #then add new dimensions - return pad_data_rep(extrude_rep(rep, dims), length(ex.idxs)) + return paddims_rep(expanddims_rep(rep, dims), length(ex.idxs)) elseif ex.kind === relabel return ctx(ex.arg) elseif ex.kind === reformat From fa7ae560dcd0c92f18da8cc60023703e3dad4245 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 12:10:11 -0400 Subject: [PATCH 06/10] fix --- src/interface/lazy.jl | 56 +- src/scheduler/LogicCompiler.jl | 48 +- src/scheduler/LogicExecutor.jl | 7 +- src/scheduler/LogicInterpreter.jl | 50 +- test/test_interface.jl | 828 +++++++++++++++--------------- 5 files changed, 547 insertions(+), 442 deletions(-) diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 0b2fbee52..424ffc424 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -431,19 +431,57 @@ is fused with the execution of `z + 1`. """ lazy(arg) = LazyTensor(arg) +default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCompiler()), verbose=verbose) + """ - fused(f, args...; [optimizer=DefaultOptimizer()]) + fused(f, args...; kwargs...) This function decorator modifies `f` to fuse the contained array operations and optimize the resulting program. The function must return a single -array or tuple of arrays. The `optimizer` keyword argument specifies the -optimizer to use. +array or tuple of arrays. `kwargs` are passed to [`compute`](@ref) """ -function fused(f, args...; optimizer=DefaultOptimizer()) - compute(f(map(LazyTensor, args...)), optimizer) +function fused(f, args...; kwargs...) + compute(f(map(LazyTensor, args...)), kwargs...) end -default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCompiler()), verbose=verbose) +current_scheduler = Ref{Any}(default_scheduler()) + +""" + set_scheduler!(scheduler) + +Set the current scheduler to `scheduler`. The scheduler is used by `compute` to +execute lazy tensor programs. +""" +set_scheduler!(scheduler) = current_scheduler[] = scheduler + +""" + get_scheduler() + +Get the current Finch scheduler used by `compute` to execute lazy tensor programs. +""" +get_scheduler() = current_scheduler[] + +""" + with_scheduler(f, scheduler) + +Execute `f` with the current scheduler set to `scheduler`. 
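+The previous scheduler is restored when `f` returns, even if `f` throws.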
For example, +```jldoctest +with_scheduler(LogicExecutor(DefaultLogicOptimizer(LogicCompiler()), verbose=true)) do + x = lazy([1, 2]) + y = lazy([3, 4]) + compute(x + y) +end +``` +""" +function with_scheduler(f, scheduler) + old_scheduler = get_scheduler() + set_scheduler!(scheduler) + try + return f() + finally + set_scheduler!(old_scheduler) + end +end """ compute(args..., ctx=default_scheduler()) -> Any @@ -451,9 +489,9 @@ default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCom Compute the value of a lazy tensor. The result is the argument itself, or a tuple of arguments if multiple arguments are passed. """ -compute(args...; ctx=default_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) -compute(arg; ctx=default_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1] -compute(args::Tuple; ctx=default_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) +compute(args...; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) +compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1] +compute(args::Tuple; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) function compute_parse(ctx, args::Tuple) args = collect(args) vars = map(arg -> alias(gensym(:A)), args) diff --git a/src/scheduler/LogicCompiler.jl b/src/scheduler/LogicCompiler.jl index 04625794c..58f6da0c6 100644 --- a/src/scheduler/LogicCompiler.jl +++ b/src/scheduler/LogicCompiler.jl @@ -42,10 +42,22 @@ end mode = :fast end -function finch_pointwise_logic_to_code(ex) +@kwdef struct PointwiseLowerer + bound_idxs = [] +end + +function compile_pointwise_logic(ex) + ctx = PointwiseLowerer() + code = ctx(ex) + bound_idxs = ctx.bound_idxs + (code, bound_idxs) +end + +function (ctx::PointwiseLowerer)(ex) if @capture ex mapjoin(~op, ~args...) - :($(op.val)($(map(arg -> finch_pointwise_logic_to_code(arg), args)...))) + :($(op.val)($(map(ctx, args)...))) elseif (@capture ex reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...)) + append!(ctx.bound_idxs, idxs_1) :($(arg.name)[$(map(idx -> idx in idxs_2 ? 
idx.name : 1, idxs_1)...)]) #TODO need a trait for the first index elseif (@capture ex reorder(~arg::isimmediate, ~idxs...)) arg.val @@ -82,12 +94,18 @@ function (ctx::LogicLowerer)(ex) elseif @capture ex query(~lhs::isalias, reformat(~tns, reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...))) loop_idxs = map(idx -> idx.name, withsubsequence(intersect(idxs_1, idxs_2), idxs_2)) lhs_idxs = map(idx -> idx.name, idxs_2) - rhs = finch_pointwise_logic_to_code(reorder(relabel(arg, idxs_1...), idxs_2...)) + (rhs, rhs_idxs) = compile_pointwise_logic(reorder(relabel(arg, idxs_1...), idxs_2...)) body = :($(lhs.name)[$(lhs_idxs...)] = $rhs) for idx in loop_idxs - body = :(for $idx = _ - $body - end) + if field(idx) in rhs_idxs + body = :(for $idx = _ + $body + end) + elseif idx in lhs_idxs + body = :(for $idx = 1:1 + $body + end) + end end quote $(lhs.name) = $(compile_logic_constant(tns)) @@ -102,13 +120,19 @@ function (ctx::LogicLowerer)(ex) ctx(query(lhs, reformat(tns, aggregate(initwrite(z), immediate(z), mapjoin(args...))))) elseif @capture ex query(~lhs, reformat(~tns, aggregate(~op, ~init, ~arg, ~idxs_1...))) idxs_2 = map(idx -> idx.name, getfields(arg)) - idxs_3 = map(idx -> idx.name, setdiff(getfields(arg), idxs_1)) - rhs = finch_pointwise_logic_to_code(arg) - body = :($(lhs.name)[$(idxs_3...)] <<$(compile_logic_constant(op))>>= $rhs) + lhs_idxs = map(idx -> idx.name, setdiff(getfields(arg), idxs_1)) + (rhs, rhs_idxs) = compile_pointwise_logic(arg) + body = :($(lhs.name)[$(lhs_idxs...)] <<$(compile_logic_constant(op))>>= $rhs) for idx in idxs_2 - body = :(for $idx = _ - $body - end) + if field(idx) in rhs_idxs + body = :(for $idx = _ + $body + end) + elseif idx in lhs_idxs + body = :(for $idx = 1:1 + $body + end) + end end quote $(lhs.name) = $(compile_logic_constant(tns)) diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index d16854074..9fd6fe95a 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -68,12 +68,13 @@ end codes = Dict() function (ctx::LogicExecutor)(prgm) - f = get!(codes, get_structure(prgm)) do - eval(logic_executor_code(ctx.ctx, prgm)) + (f, code) = get!(codes, get_structure(prgm)) do + thunk = logic_executor_code(ctx.ctx, prgm) + (eval(thunk), thunk) end if ctx.verbose println("Executing:") - display(logic_executor_code(ctx.ctx, prgm)) + display(code) end return Base.invokelatest(f, prgm) end diff --git a/src/scheduler/LogicInterpreter.jl b/src/scheduler/LogicInterpreter.jl index 0adf78387..ef25c24da 100644 --- a/src/scheduler/LogicInterpreter.jl +++ b/src/scheduler/LogicInterpreter.jl @@ -1,13 +1,25 @@ using Finch.FinchNotation: block_instance, declare_instance, call_instance, loop_instance, index_instance, variable_instance, tag_instance, access_instance, assign_instance, literal_instance, yieldbind_instance -function finch_pointwise_logic_to_program(scope, ex) +@kwdef struct PointwiseMachineLowerer + ctx + bound_idxs = [] +end + +function lower_pointwise_logic(ctx, ex) + ctx = PointwiseMachineLowerer(ctx=ctx) + code = ctx(ex) + return (code, ctx.bound_idxs) +end + +function (ctx::PointwiseMachineLowerer)(ex) if @capture ex mapjoin(~op, ~args...) - call_instance(literal_instance(op.val), map(arg -> finch_pointwise_logic_to_program(scope, arg), args)...) + call_instance(literal_instance(op.val), map(ctx, args)...) elseif (@capture ex reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...)) + append!(ctx.bound_idxs, idxs_1) idxs_3 = map(enumerate(idxs_1)) do (n, idx) - idx in idxs_2 ? 
index_instance(idx.name) : first(axes(arg)[n]) + idx in idxs_2 ? index_instance(idx.name) : first(axes(ctx.ctx.scope[arg])[n]) end - access_instance(tag_instance(variable_instance(arg.name), scope[arg]), literal_instance(reader), idxs_3...) + access_instance(tag_instance(variable_instance(arg.name), ctx.ctx.scope[arg]), literal_instance(reader), idxs_3...) elseif (@capture ex reorder(~arg::isimmediate, ~idxs...)) literal_instance(arg.val) elseif ex.kind === immediate @@ -32,14 +44,18 @@ function (ctx::LogicMachine)(ex) elseif @capture ex table(~tns, ~idxs...) return tns.val elseif @capture ex reformat(~tns, reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...)) - loop_idxs = map(idx -> index_instance(idx.name), withsubsequence(intersect(idxs_1, idxs_2), idxs_2)) - lhs_idxs = map(idx -> index_instance(idx.name), idxs_2) + loop_idxs = withsubsequence(intersect(idxs_1, idxs_2), idxs_2) + lhs_idxs = idxs_2 res = tag_instance(variable_instance(:res), tns.val) - lhs = access_instance(res, literal_instance(updater), lhs_idxs...) - rhs = finch_pointwise_logic_to_program(ctx.scope, reorder(relabel(arg, idxs_1...), idxs_2...)) + lhs = access_instance(res, literal_instance(updater), map(idx -> index_instance(idx.name), lhs_idxs)...) + (rhs, rhs_idxs) = lower_pointwise_logic(ctx, reorder(relabel(arg, idxs_1...), idxs_2...)) body = assign_instance(lhs, literal_instance(initwrite(default(tns.val))), rhs) for idx in loop_idxs - body = loop_instance(idx, dimless, body) + if idx in rhs_idxs + body = loop_instance(index_instance(idx.name), dimless, body) + elseif idx in lhs_idxs + body = loop_instance(index_instance(idx.name), call_instance(literal_instance(extent), literal_instance(1), literal_instance(1)), body) + end end body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res)) if ctx.verbose @@ -51,14 +67,18 @@ function (ctx::LogicMachine)(ex) z = default(tns.val) ctx(reformat(tns, aggregate(initwrite(z), immediate(z), mapjoin(args...)))) elseif @capture ex reformat(~tns, aggregate(~op, ~init, ~arg, ~idxs_1...)) - idxs_2 = map(idx -> index_instance(idx.name), getfields(arg)) - idxs_3 = map(idx -> index_instance(idx.name), setdiff(getfields(arg), idxs_1)) + loop_idxs = getfields(arg) + lhs_idxs = setdiff(getfields(arg), idxs_1) res = tag_instance(variable_instance(:res), tns.val) - lhs = access_instance(res, literal_instance(updater), idxs_3...) - rhs = finch_pointwise_logic_to_program(ctx.scope, arg) + lhs = access_instance(res, literal_instance(updater), map(idx -> index_instance(idx.name), lhs_idxs)...) 
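+        # lower the pointwise right-hand side, recording which loop indices it actually binds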
+ (rhs, rhs_idxs) = lower_pointwise_logic(ctx, arg) body = assign_instance(lhs, literal_instance(op.val), rhs) - for idx in idxs_2 - body = loop_instance(idx, dimless, body) + for idx in loop_idxs + if idx in rhs_idxs + body = loop_instance(index_instance(idx.name), dimless, body) + elseif idx in lhs_idxs + body = loop_instance(index_instance(idx.name), call_instance(literal_instance(extent), literal_instance(1), literal_instance(1)), body) + end end body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res)) if ctx.verbose diff --git a/test/test_interface.jl b/test/test_interface.jl index a0af20399..f1113f108 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -1,236 +1,482 @@ using Finch: AsArray @testset "interface" begin + @info "Testing Finch Interface" - #https://github.com/willow-ahrens/Finch.jl/issues/535 - let - LEN = 10; - a_raw = rand(LEN, LEN - 5) * 10; - b_raw = rand(LEN, LEN - 5) * 10; - c_raw = rand(LEN, LEN) * 10; - - a = lazy(swizzle(Tensor(a_raw), 1, 2)); - b = lazy(swizzle(Tensor(b_raw), 1, 2)); - c = lazy(swizzle(Tensor(c_raw), 1, 2)); + for scheduler in [Finch.default_scheduler(), Finch.DefaultLogicOptimizer(Finch.LogicInterpreter())] + Finch.with_scheduler(scheduler) do + @info "Testing $scheduler" - ref = reshape(c_raw, 10, 10, 1) .* reshape(a_raw, 10, 1, 5) .* reshape(b_raw, 1, 10, 5); + #https://github.com/willow-ahrens/Finch.jl/issues/554 + let + @test broadcast(trunc, swizzle(Tensor(ones(1)), 1)) == Tensor(ones(1)) - plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :]; - @test compute(plan) == ref + @test broadcast(trunc, swizzle(Tensor(ones(2)), 1)) == Tensor(ones(2)) + end - plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]); - @test compute(plan) == ref - end + #https://github.com/willow-ahrens/Finch.jl/issues/533 + let + A = lazy(fsprand(1, 1,0.5)) + compute(sum(A .+ A)) #fails + end - #https://github.com/willow-ahrens/Finch.jl/issues/536 - let - A = [1 2; 3 4] - swizzle(lazy(A), 2, 1) == permutedims(A) - end + #https://github.com/willow-ahrens/Finch.jl/issues/535 + let + LEN = 10; + a_raw = rand(LEN, LEN - 5) * 10; + b_raw = rand(LEN, LEN - 5) * 10; + c_raw = rand(LEN, LEN) * 10; + + a = lazy(swizzle(Tensor(a_raw), 1, 2)); + b = lazy(swizzle(Tensor(b_raw), 1, 2)); + c = lazy(swizzle(Tensor(c_raw), 1, 2)); - #https://github.com/willow-ahrens/Finch.jl/issues/530 - let - A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), zeros(3, 3, 3)) - A_sw = swizzle(A_tns, 2, 3, 1) - A_tns == A_sw #fails - end + ref = reshape(c_raw, 10, 10, 1) .* reshape(a_raw, 10, 1, 5) .* reshape(b_raw, 1, 10, 5); - #https://github.com/willow-ahrens/Finch.jl/issues/524 - let - arr3d = rand(Int, 3, 2, 3) .% 10 - tns = Tensor(Dense(Dense(Dense(Element(0)))), arr3d) - - tns_l = lazy(tns) - reduced = sum(tns_l, dims=(1, 2)) - - plan = broadcast(+, tns_l, reduced) - result = compute(plan) - end + plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :]; + @test compute(plan) == ref - #https://github.com/willow-ahrens/Finch.jl/issues/527 - let - tns_1 = swizzle(Tensor(ones(10, 10)), 1, 2) - tns_1[:, :] # == tns_1 https://github.com/willow-ahrens/Finch.jl/issues/530 + plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]); + @test compute(plan) == ref + end - tns_2 = swizzle(Tensor(ones(10)), 1) - tns_2[:]# == tns_2 https://github.com/willow-ahrens/Finch.jl/issues/530 - end + #https://github.com/willow-ahrens/Finch.jl/issues/536 + let + A = 
[1 2; 3 4] + swizzle(lazy(A), 2, 1) == permutedims(A) + end - #https://github.com/willow-ahrens/Finch.jl/issues/528 - let - tns = swizzle(Tensor(ones(10, 10)), 1, 2) - @test tns[:, :] == ones(10, 10) - @test tns[nothing, :, :] == ones(1, 10, 10) - @test tns[:, nothing, :] == ones(10, 1, 10) - @test tns[:, :, nothing] == ones(10, 10, 1) - end + #https://github.com/willow-ahrens/Finch.jl/issues/530 + let + A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), zeros(3, 3, 3)) + A_sw = swizzle(A_tns, 2, 3, 1) + A_tns == A_sw #fails + end - #https://github.com/willow-ahrens/Finch.jl/issues/428 - let - @testset "Verbose" begin - a = [1 2; 3 4] - b = [5 6; 7 8] - a_l = lazy(a) - b_l = lazy(b) + #https://github.com/willow-ahrens/Finch.jl/issues/524 + let + arr3d = rand(Int, 3, 2, 3) .% 10 + tns = Tensor(Dense(Dense(Dense(Element(0)))), arr3d) + + tns_l = lazy(tns) + reduced = sum(tns_l, dims=(1, 2)) + + plan = broadcast(+, tns_l, reduced) + result = compute(plan) + end - c = permutedims(broadcast(.+, permutedims(a_l, (2, 1)), permutedims(b_l, (2, 1))), (2, 1)) - compute(c, verbose=true) - end - end + #https://github.com/willow-ahrens/Finch.jl/issues/527 + let + tns_1 = swizzle(Tensor(ones(10, 10)), 1, 2) + tns_1[:, :] # == tns_1 https://github.com/willow-ahrens/Finch.jl/issues/530 - let - - @testset "Einsum Tests" begin - # Test 0 - A = [1 2; 3 4] - B = [5 6; 7 8] - s = Scalar(0) - @einsum s[] += abs(A[i, k] * B[k, j]) - @test s[] == 134 - - # Test 1 - A = [1 2; 3 4] - B = [5 6; 7 8] - @einsum C[i, j] += A[i, k] * B[k, j] - @test C == [19 22; 43 50] - - # Test 2 - A = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 3, 5, 0.5)) - B = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 3, 0.5)) - @einsum C[i, j, k] += A[i, j] * B[j, k] - - C_ref = zeros(Int, 3, 5, 3) - for i = 1:3, j = 1:5, k = 1:3 - C_ref[i, j, k] += A[i, j] * B[j, k] + tns_2 = swizzle(Tensor(ones(10)), 1) + tns_2[:]# == tns_2 https://github.com/willow-ahrens/Finch.jl/issues/530 end - @test C == C_ref - # Test 3 - X = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 6, 0.5)) - Y = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 6, 4, 0.5)) - @einsum D[i, k] += X[i, j] * Y[j, k] + #https://github.com/willow-ahrens/Finch.jl/issues/528 + let + tns = swizzle(Tensor(ones(10, 10)), 1, 2) + @test tns[:, :] == ones(10, 10) + @test tns[nothing, :, :] == ones(1, 10, 10) + @test tns[:, nothing, :] == ones(10, 1, 10) + @test tns[:, :, nothing] == ones(10, 10, 1) + end + + #https://github.com/willow-ahrens/Finch.jl/issues/428 + let + @testset "Verbose" begin + a = [1 2; 3 4] + b = [5 6; 7 8] + a_l = lazy(a) + b_l = lazy(b) + + c = permutedims(broadcast(.+, permutedims(a_l, (2, 1)), permutedims(b_l, (2, 1))), (2, 1)) + compute(c, verbose=true) + end + end - D_ref = zeros(Int, 4, 4) - for i = 1:4, j = 1:6, k = 1:4 - D_ref[i, k] += X[i, j] * Y[j, k] + let + + @testset "Einsum Tests" begin + # Test 0 + A = [1 2; 3 4] + B = [5 6; 7 8] + s = Scalar(0) + @einsum s[] += abs(A[i, k] * B[k, j]) + @test s[] == 134 + + # Test 1 + A = [1 2; 3 4] + B = [5 6; 7 8] + @einsum C[i, j] += A[i, k] * B[k, j] + @test C == [19 22; 43 50] + + # Test 2 + A = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 3, 5, 0.5)) + B = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 3, 0.5)) + @einsum C[i, j, k] += A[i, j] * B[j, k] + + C_ref = zeros(Int, 3, 5, 3) + for i = 1:3, j = 1:5, k = 1:3 + C_ref[i, j, k] += A[i, j] * B[j, k] + end + @test C == C_ref + + # Test 3 + X = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 6, 0.5)) + Y = 
Tensor(Dense(SparseList(Element(0))), fsprand(Int, 6, 4, 0.5)) + @einsum D[i, k] += X[i, j] * Y[j, k] + + D_ref = zeros(Int, 4, 4) + for i = 1:4, j = 1:6, k = 1:4 + D_ref[i, k] += X[i, j] * Y[j, k] + end + @test D == D_ref + + # Test 4 + H = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 5, 0.6)) + I = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 5, 0.6)) + @einsum J[i, j] = H[i, j] * I[i, j] + + J_ref = zeros(Int, 5, 5) + for i = 1:5, j = 1:5 + J_ref[i, j] = H[i, j] * I[i, j] + end + @test J == J_ref + + # Test 5 + K = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) + L = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) + M = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) + @einsum N[i, j] += K[i, k] * L[k, j] - M[i, j] + + N_ref = zeros(Int, 4, 4) + for i = 1:4, k = 1:4, j = 1:4 + N_ref[i, j] += K[i, k] * L[k, j] - M[i, j] + end + @test N == N_ref + + # Test 6 + P = Tensor(Dense(SparseList(Element(-Inf))), fsprand(Int, 3, 3, 0.7)) # Adjacency matrix with probabilities + Q = Tensor(Dense(SparseList(Element(-Inf))), fsprand(Int, 3, 3, 0.7)) + @einsum init=-Inf R[i, j] <>= P[i, k] + Q[k, j] # Max-plus product + + R_ref = fill(-Inf, 3, 3) + for i = 1:3, j = 1:3 + for k = 1:3 + R_ref[i, j] = max(R_ref[i, j], P[i, k] + Q[k, j]) + end + end + @test R == R_ref + + # Test for Sparse Matrix-Vector Multiplication (SpMV) + # Define a sparse matrix `S` and a dense vector `v` + S = Tensor(Dense(SparseList(Element(0))), sprand(Int, 10, 10, 0.3)) # 10x10 sparse matrix with 30% density + v = Tensor(Dense(Element(0)), rand(Int, 10)) # Dense vector of size 10 + + # Perform matrix-vector multiplication using the @einsum macro + @einsum w[i] += S[i, k] * v[k] # Compute the product + + # Reference calculation using explicit loop for validation + w_ref = zeros(Int, 10) + for i = 1:10 + for k = 1:10 + w_ref[i] += S[i, k] * v[k] + end + end + + # Test to ensure the results match + @test w == w_ref + + # Test for Transposed Sparse Matrix-Vector Multiplication (SpMV) + # Define a sparse matrix `T` and a dense vector `u` + T = Tensor(Dense(SparseList(Element(0))), sprand(Int, 10, 10, 0.3)) # 10x10 sparse matrix with 30% density + u = Tensor(Dense(Element(0)), rand(Int, 10)) # Dense vector of size 10 + + # Perform transposed matrix-vector multiplication using the @einsum macro + @einsum x[k] += T[j, k] * u[j] # Compute the product using the transpose of T + + # Reference calculation using explicit loop for validation + x_ref = zeros(Int, 10) + for k = 1:10 + for j = 1:10 + x_ref[k] += T[j, k] * u[j] + end + end + + # Test to ensure the results match + @test x == x_ref + + # Test for Outer Product with Output Named A + # Define two vectors for outer product + v1 = Tensor(Dense(Element(0)), rand(Int, 5)) # Dense vector of size 5 + v2 = Tensor(Dense(Element(0)), rand(Int, 7)) # Dense vector of size 7 + + # Perform outer product using the @einsum macro + @einsum A[i, j] = v1[i] * v2[j] # Compute the outer product + + # Reference calculation using explicit loop for validation + A_ref = zeros(Int, 5, 7) + for i = 1:5 + for j = 1:7 + A_ref[i, j] = v1[i] * v2[j] + end + end + + # Test to ensure the results match + @test A == A_ref + + + # Test for multiplying a vector by a Scalar + v = Tensor(Dense(Element(0)), rand(Int, 5)) + n = 7 + + #Perform scalar multiplcation + @einsum A[i] = n*v[i] + + # Reference Calculation using explicit loop for validation + A_ref = Tensor(Dense(Element(0)), rand(Int, 5)) + for i = 1:5 + A_ref[i] = v[i]*n + end + + #Test to ensure the 
results match + @test A == A_ref + + + end end - @test D == D_ref - # Test 4 - H = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 5, 0.6)) - I = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 5, 5, 0.6)) - @einsum J[i, j] = H[i, j] * I[i, j] + A = Tensor(SparseList(Element(0.0)), fsparse([1, 3, 5, 7, 9], [2.0, 3.0, 4.0, 5.0, 6.0], (10,))) + B = Tensor(SparseList(Element(0.0)), A) + @test A == B + + A = [0.0 0.0 0.0 0.0; 1.0 0.0 0.0 1.0] + B = Tensor(Dense(SparseList(Element(0.0))), A) + C = Tensor(Dense(Dense(Element(0.0))), A) + @test A == B + + A = [0 0; 0 0] + B = Tensor(Dense(Dense(Element(0.0))), A) + @test A == B + + A = Tensor(Dense(Element(0.0)), [0, 0, 0, 0]) + B = Tensor(Dense(Element(0.0)), [0, 0, 0, 0, 0]) + @test size(A) != size(B) && A != B + + A = [0 0 0 0 1 0 0 1] + B = Tensor(Dense(SparseList(Element(0))), [0 0 0 0; 1 0 0 1]) + @test size(A) != size(B) && A != B + + A = Tensor(Dense(SparseList(Element(0.0))), [1 0 0 0; 1 1 0 0; 1 1 1 0]) + B = [0 0 0 0; 1 1 0 0; 1 1 1 0] + @test size(A) == size(B) && A != B + C = Tensor(Dense(SparseList(Element(0.0))), [0 0 0 0; 1 1 0 0; 1 1 1 0]) + @test B == C + + A = [NaN, 0.0, 3.14, 0.0] + B = Tensor(SparseList(Element(0.0)), [NaN, 0.0, 3.14, 0.0]) + C = Tensor(SparseList(Element(0.0)), [NaN, 0.0, 3.14, 0.0]) + D = [1.0, 2.0, 4.0, 8.0] + @test isequal(A, B) + @test isequal(A, C) + @test isequal(B, C) + @test isequal(B, A) + @test !isequal(A, D) + @test A != B + + let + io = IOBuffer() + println(io, "getindex tests") + + A = Tensor(SparseList(Dense(SparseList(Element{0.0}(collect(1:30).* 1.01), 5, [1, 3, 6, 8, 12, 14, 17, 20, 24, 27, 27, 28, 31], [2, 3, 3, 4, 5, 2, 3, 1, 3, 4, 5, 2, 4, 2, 4, 5, 2, 3, 5, 1, 3, 4, 5, 2, 3, 4, 2, 1, 2, 3]), 3), 4, [1, 5], [1, 2, 3, 4])) + + print(io, "A = ") + show(io, MIME("text/plain"), A) + println(io) + + for inds in [(1, 2, 3), (1, 1, 1), (1, :, 3), (:, 1, 3), (:, :, 3), (:, :, :)] + print(io, "A["); join(io, inds, ","); print(io, "] = ") + show(io, MIME("text/plain"), A[inds...]) + println(io) + end - J_ref = zeros(Int, 5, 5) - for i = 1:5, j = 1:5 - J_ref[i, j] = H[i, j] * I[i, j] + @test check_output("interface/getindex.txt", String(take!(io))) end - @test J == J_ref - # Test 5 - K = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) - L = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) - M = Tensor(Dense(SparseList(Element(0))), fsprand(Int, 4, 4, 0.7)) - @einsum N[i, j] += K[i, k] * L[k, j] - M[i, j] + let + io = IOBuffer() + println(io, "setindex! 
tests") + + @repl io A = Tensor(Dense(Dense(Element(0.0))), 10, 12) + @repl io A[1, 4] = 3 + @repl io AsArray(A) + @repl io A[4:6, 6] = 5:7 + @repl io AsArray(A) + @repl io A[9, :] = 1:12 + @repl io AsArray(A) - N_ref = zeros(Int, 4, 4) - for i = 1:4, k = 1:4, j = 1:4 - N_ref[i, j] += K[i, k] * L[k, j] - M[i, j] + @test check_output("interface/setindex.txt", String(take!(io))) end - @test N == N_ref - # Test 6 - P = Tensor(Dense(SparseList(Element(-Inf))), fsprand(Int, 3, 3, 0.7)) # Adjacency matrix with probabilities - Q = Tensor(Dense(SparseList(Element(-Inf))), fsprand(Int, 3, 3, 0.7)) - @einsum init=-Inf R[i, j] <>= P[i, k] + Q[k, j] # Max-plus product + let + io = IOBuffer() + println(io, "broadcast tests") - R_ref = fill(-Inf, 3, 3) - for i = 1:3, j = 1:3 - for k = 1:3 - R_ref[i, j] = max(R_ref[i, j], P[i, k] + Q[k, j]) - end + @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io B = [1, 2, 3, 4] + @repl io C = A .+ B true + @repl io AsArray(C) + @repl io D = A .* B true + @repl io AsArray(D) + @repl io E = ifelse.(A .== 0, 1, 2) + @repl io AsArray(E) + + @test check_output("interface/broadcast.txt", String(take!(io))) end - @test R == R_ref - # Test for Sparse Matrix-Vector Multiplication (SpMV) - # Define a sparse matrix `S` and a dense vector `v` - S = Tensor(Dense(SparseList(Element(0))), sprand(Int, 10, 10, 0.3)) # 10x10 sparse matrix with 30% density - v = Tensor(Dense(Element(0)), rand(Int, 10)) # Dense vector of size 10 + let + io = IOBuffer() + println(io, "reduce tests") - # Perform matrix-vector multiplication using the @einsum macro - @einsum w[i] += S[i, k] * v[k] # Compute the product + @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io reduce(+, A, dims=(1,)) + @repl io reduce(+, A, dims=1) + @repl io reduce(+, A, dims=(2,)) + @repl io reduce(+, A, dims=2) + @repl io reduce(+, A, dims=(1,2)) + @repl io reduce(+, A, dims=:) - # Reference calculation using explicit loop for validation - w_ref = zeros(Int, 10) - for i = 1:10 - for k = 1:10 - w_ref[i] += S[i, k] * v[k] - end + @test check_output("interface/reduce.txt", String(take!(io))) end - # Test to ensure the results match - @test w == w_ref + let + io = IOBuffer() + println(io, "countstored tests") - # Test for Transposed Sparse Matrix-Vector Multiplication (SpMV) - # Define a sparse matrix `T` and a dense vector `u` - T = Tensor(Dense(SparseList(Element(0))), sprand(Int, 10, 10, 0.3)) # 10x10 sparse matrix with 30% density - u = Tensor(Dense(Element(0)), rand(Int, 10)) # Dense vector of size 10 + @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io countstored(A) + @repl io A = Tensor(SparseCOO{2}(Element(0.0)), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io countstored(A) + @repl io A = Tensor(Dense(Dense(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io countstored(A) + @repl io A = Tensor(SparseList(Dense(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io countstored(A) - # Perform transposed matrix-vector multiplication using the @einsum macro - @einsum x[k] += T[j, k] * u[j] # Compute the product using the transpose of T + @test check_output("interface/countstored.txt", String(take!(io))) + end - # Reference calculation using explicit loop for validation - x_ref = zeros(Int, 10) - for k = 1:10 - for j = 1:10 - x_ref[k] += T[j, k] * u[j] - 
end + let + io = IOBuffer() + println(io, "+,-, *, / tests") + + @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + @repl io A + 1 + @repl io 1 + A + @repl io A + A + @repl io 2 * A + @repl io A * 3 + @repl io A / 3 + @repl io 3 / A + + @test check_output("interface/asmd.txt", String(take!(io))) end - # Test to ensure the results match - @test x == x_ref + let + A_ref = [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0] + A_ref = A_ref * floatmax()/sum(A_ref) + A= Tensor(Dense(SparseList(Element(0.0))), A_ref) + @test sum(A) == sum(A_ref) + @test minimum(A) == minimum(A_ref) + @test maximum(A) == maximum(A_ref) + @test extrema(A) == extrema(A_ref) + @test norm(A) == norm(A_ref) + @test norm(A, -Inf) == norm(A_ref, -Inf) + @test norm(A, 0) == norm(A_ref, 0) + @test norm(A, 1) == norm(A_ref, 1) + @test norm(A, 1.5) == norm(A_ref, 1.5) + @test norm(A, Inf) == norm(A_ref, Inf) + end - # Test for Outer Product with Output Named A - # Define two vectors for outer product - v1 = Tensor(Dense(Element(0)), rand(Int, 5)) # Dense vector of size 5 - v2 = Tensor(Dense(Element(0)), rand(Int, 7)) # Dense vector of size 7 + let + A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + B = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) + C = lazy(A) + D = lazy(B) + E = (C + D) * 0.5 + F = compute(E) + @test F == A + end - # Perform outer product using the @einsum macro - @einsum A[i, j] = v1[i] * v2[j] # Compute the outer product + let + A = Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]) + B = Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]) + c_correct = Tensor(Dense(Dense(Element(0))), [1936 0 2420 0; 0 121 242 363; 2420 242 3509 726; 0 363 726 1089]) + c = compute(tensordot(lazy(A), lazy(B), ((2, ), (2,)), init=0)) + @test c == c_correct + end - # Reference calculation using explicit loop for validation - A_ref = zeros(Int, 5, 7) - for i = 1:5 - for j = 1:7 - A_ref[i, j] = v1[i] * v2[j] - end + let + A = lazy(Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0])) + B = lazy(Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]')) + c_correct = Tensor(Dense(Dense(Element(0))), [1936 0 2420 0; 0 121 242 363; 2420 242 3509 726; 0 363 726 1089]) + c = compute(sum(A[:, :, nothing] .* B[nothing, :, :], dims=[2])) + @test c == c_correct end - # Test to ensure the results match - @test A == A_ref + #https://github.com/willow-ahrens/Finch.jl/issues/457 + let + A = zeros(2, 3, 3) + A[1, :, :] = [1 2 3; 4 5 6; 7 8 9] + A[2, :, :] = [1 1 1; 2 2 2; 3 3 3] + perm = (2, 3, 1) + A_t = permutedims(A, perm) + + A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), A) + A_sw = swizzle(A_tns, perm...) 
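+                # the lazy swizzled tensor should compute to the same result as permutedims on the raw array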
+ A_lazy = lazy(A_sw) + A_result = compute(A_lazy) - # Test for multiplying a vector by a Scalar - v = Tensor(Dense(Element(0)), rand(Int, 5)) - n = 7 + @test Array(A_result) == A_t + @test permutedims(A_tns, perm) == A_t + end + + #https://github.com/willow-ahrens/Finch.jl/pull/477 + let + A = zeros(2, 3, 3) + A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), A) - #Perform scalar multiplcation - @einsum A[i] = n*v[i] + @test compute(A) == A #If the scheduler improves, we can change this to === + @test compute(A_tns) == A_tns #If the scheduler improves, we can change this to === + end - # Reference Calculation using explicit loop for validation - A_ref = Tensor(Dense(Element(0)), rand(Int, 5)) - for i = 1:5 - A_ref[i] = v[i]*n + #https://github.com/willow-ahrens/Finch.jl/issues/481 + let + r = fsprand(1, 10, 10, 0.01) + r_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), r) + @test r_tns + r_tns == 2 * r_tns end - #Test to ensure the results match - @test A == A_ref + #https://github.com/willow-ahrens/Finch.jl/issues/487 + let + a = fsprand(10, 1, 0.8) + b = fsprand(10, 1, 0.8) - + permutedims(broadcast(+, permutedims(a, (2, 1)), permutedims(b, (2, 1))), (2, 1)) # passes + + a_l = lazy(a) + b_l = lazy(b) + + plan = permutedims(broadcast(+, permutedims(a_l, (2, 1)), permutedims(b_l, (2, 1))), (2, 1)) + compute(plan) # fails + end end end + @info "Testing Finch Interface (FinchLogic)" @testset "concordize" begin using Finch.FinchLogic A = alias(:A) @@ -420,7 +666,7 @@ using Finch: AsArray mapjoin(*, reorder(relabel(table(A, i, j, k), j, i, k), k, j, i))) expr_out = mapjoin(+, - mapjoin(*, + mapjoin(*, reorder(table(A, j, k, i), j, i, k), reorder(table(A, k, i, j), j, i, k)), mapjoin(*, @@ -439,228 +685,4 @@ using Finch: AsArray =# end - A = Tensor(SparseList(Element(0.0)), fsparse([1, 3, 5, 7, 9], [2.0, 3.0, 4.0, 5.0, 6.0], (10,))) - B = Tensor(SparseList(Element(0.0)), A) - @test A == B - - A = [0.0 0.0 0.0 0.0; 1.0 0.0 0.0 1.0] - B = Tensor(Dense(SparseList(Element(0.0))), A) - C = Tensor(Dense(Dense(Element(0.0))), A) - @test A == B - - A = [0 0; 0 0] - B = Tensor(Dense(Dense(Element(0.0))), A) - @test A == B - - A = Tensor(Dense(Element(0.0)), [0, 0, 0, 0]) - B = Tensor(Dense(Element(0.0)), [0, 0, 0, 0, 0]) - @test size(A) != size(B) && A != B - - A = [0 0 0 0 1 0 0 1] - B = Tensor(Dense(SparseList(Element(0))), [0 0 0 0; 1 0 0 1]) - @test size(A) != size(B) && A != B - - A = Tensor(Dense(SparseList(Element(0.0))), [1 0 0 0; 1 1 0 0; 1 1 1 0]) - B = [0 0 0 0; 1 1 0 0; 1 1 1 0] - @test size(A) == size(B) && A != B - C = Tensor(Dense(SparseList(Element(0.0))), [0 0 0 0; 1 1 0 0; 1 1 1 0]) - @test B == C - - A = [NaN, 0.0, 3.14, 0.0] - B = Tensor(SparseList(Element(0.0)), [NaN, 0.0, 3.14, 0.0]) - C = Tensor(SparseList(Element(0.0)), [NaN, 0.0, 3.14, 0.0]) - D = [1.0, 2.0, 4.0, 8.0] - @test isequal(A, B) - @test isequal(A, C) - @test isequal(B, C) - @test isequal(B, A) - @test !isequal(A, D) - @test A != B - - let - io = IOBuffer() - println(io, "getindex tests") - - A = Tensor(SparseList(Dense(SparseList(Element{0.0}(collect(1:30).* 1.01), 5, [1, 3, 6, 8, 12, 14, 17, 20, 24, 27, 27, 28, 31], [2, 3, 3, 4, 5, 2, 3, 1, 3, 4, 5, 2, 4, 2, 4, 5, 2, 3, 5, 1, 3, 4, 5, 2, 3, 4, 2, 1, 2, 3]), 3), 4, [1, 5], [1, 2, 3, 4])) - - print(io, "A = ") - show(io, MIME("text/plain"), A) - println(io) - - for inds in [(1, 2, 3), (1, 1, 1), (1, :, 3), (:, 1, 3), (:, :, 3), (:, :, :)] - print(io, "A["); join(io, inds, ","); print(io, "] = ") - show(io, MIME("text/plain"), A[inds...]) - println(io) - end - - @test 
check_output("interface/getindex.txt", String(take!(io))) - end - - let - io = IOBuffer() - println(io, "setindex! tests") - - @repl io A = Tensor(Dense(Dense(Element(0.0))), 10, 12) - @repl io A[1, 4] = 3 - @repl io AsArray(A) - @repl io A[4:6, 6] = 5:7 - @repl io AsArray(A) - @repl io A[9, :] = 1:12 - @repl io AsArray(A) - - @test check_output("interface/setindex.txt", String(take!(io))) - end - - let - io = IOBuffer() - println(io, "broadcast tests") - - @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io B = [1, 2, 3, 4] - @repl io C = A .+ B true - @repl io AsArray(C) - @repl io D = A .* B true - @repl io AsArray(D) - @repl io E = ifelse.(A .== 0, 1, 2) - @repl io AsArray(E) - - @test check_output("interface/broadcast.txt", String(take!(io))) - end - - let - io = IOBuffer() - println(io, "reduce tests") - - @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io reduce(+, A, dims=(1,)) - @repl io reduce(+, A, dims=1) - @repl io reduce(+, A, dims=(2,)) - @repl io reduce(+, A, dims=2) - @repl io reduce(+, A, dims=(1,2)) - @repl io reduce(+, A, dims=:) - - @test check_output("interface/reduce.txt", String(take!(io))) - end - - let - io = IOBuffer() - println(io, "countstored tests") - - @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io countstored(A) - @repl io A = Tensor(SparseCOO{2}(Element(0.0)), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io countstored(A) - @repl io A = Tensor(Dense(Dense(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io countstored(A) - @repl io A = Tensor(SparseList(Dense(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io countstored(A) - - @test check_output("interface/countstored.txt", String(take!(io))) - end - - let - io = IOBuffer() - println(io, "+,-, *, / tests") - - @repl io A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - @repl io A + 1 - @repl io 1 + A - @repl io A + A - @repl io 2 * A - @repl io A * 3 - @repl io A / 3 - @repl io 3 / A - - @test check_output("interface/asmd.txt", String(take!(io))) - end - - let - A_ref = [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0] - A_ref = A_ref * floatmax()/sum(A_ref) - A= Tensor(Dense(SparseList(Element(0.0))), A_ref) - @test sum(A) == sum(A_ref) - @test minimum(A) == minimum(A_ref) - @test maximum(A) == maximum(A_ref) - @test extrema(A) == extrema(A_ref) - @test norm(A) == norm(A_ref) - @test norm(A, -Inf) == norm(A_ref, -Inf) - @test norm(A, 0) == norm(A_ref, 0) - @test norm(A, 1) == norm(A_ref, 1) - @test norm(A, 1.5) == norm(A_ref, 1.5) - @test norm(A, Inf) == norm(A_ref, Inf) - end - - let - A = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - B = Tensor(Dense(SparseList(Element(0.0))), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]) - C = lazy(A) - D = lazy(B) - E = (C + D) * 0.5 - F = compute(E) - @test F == A - end - - let - A = Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]) - B = Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]) - c_correct = Tensor(Dense(Dense(Element(0))), [1936 0 2420 0; 0 121 242 363; 2420 242 3509 726; 0 363 726 1089]) - c = compute(tensordot(lazy(A), lazy(B), ((2, ), (2,)), init=0)) - @test c == c_correct - end - - let - A = 
lazy(Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0])) - B = lazy(Tensor(Dense(SparseList(Element(0))), [0 0 44; 11 0 0; 22 00 55; 33 0 0]')) - c_correct = Tensor(Dense(Dense(Element(0))), [1936 0 2420 0; 0 121 242 363; 2420 242 3509 726; 0 363 726 1089]) - c = compute(sum(A[:, :, nothing] .* B[nothing, :, :], dims=[2])) - @test c == c_correct - end - - #https://github.com/willow-ahrens/Finch.jl/issues/457 - let - A = zeros(2, 3, 3) - A[1, :, :] = [1 2 3; 4 5 6; 7 8 9] - A[2, :, :] = [1 1 1; 2 2 2; 3 3 3] - perm = (2, 3, 1) - A_t = permutedims(A, perm) - - A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), A) - A_sw = swizzle(A_tns, perm...) - A_lazy = lazy(A_sw) - - A_result = compute(A_lazy) - - @test Array(A_result) == A_t - @test permutedims(A_tns, perm) == A_t - end - - #https://github.com/willow-ahrens/Finch.jl/pull/477 - let - A = zeros(2, 3, 3) - A_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), A) - - @test compute(A) == A #If the scheduler improves, we can change this to === - @test compute(A_tns) == A_tns #If the scheduler improves, we can change this to === - end - - #https://github.com/willow-ahrens/Finch.jl/issues/481 - let - r = fsprand(1, 10, 10, 0.01) - r_tns = Tensor(Dense(Dense(Dense(Element(0.0)))), r) - @test r_tns + r_tns == 2 * r_tns - end - - #https://github.com/willow-ahrens/Finch.jl/issues/487 - let - a = fsprand(10, 1, 0.8) - b = fsprand(10, 1, 0.8) - - permutedims(broadcast(+, permutedims(a, (2, 1)), permutedims(b, (2, 1))), (2, 1)) # passes - - a_l = lazy(a) - b_l = lazy(b) - - plan = permutedims(broadcast(+, permutedims(a_l, (2, 1)), permutedims(b_l, (2, 1))), (2, 1)) - compute(plan) # fails - end end From 7437202966662d9d5908ba9804d45dc54f2b82e1 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 12:49:10 -0400 Subject: [PATCH 07/10] fixing more stuff --- src/interface/lazy.jl | 2 +- src/interface/traits.jl | 33 +++++++++++++++++++++------------ src/scheduler/optimize.jl | 5 ++--- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 424ffc424..e9a9d3945 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -31,8 +31,8 @@ end function expanddims(arr::LazyTensor{T}, dims) where {T} @assert allunique(dims) + @assert issubset(dims,1:ndims(arr) + length(dims)) antidims = setdiff(1:ndims(arr) + length(dims), dims) - @assert length(antidims) == ndims(arr) idxs_1 = [field(gensym(:i)) for _ in 1:ndims(arr)] idxs_2 = [field(gensym(:i)) for _ in 1:ndims(arr) + length(dims)] idxs_2[antidims] .= idxs_1 diff --git a/src/interface/traits.jl b/src/interface/traits.jl index eb5412573..725b0ece4 100644 --- a/src/interface/traits.jl +++ b/src/interface/traits.jl @@ -137,15 +137,24 @@ paddims_rep(rep, n) = ndims(rep) < n ? paddims_rep(ExtrudeData(rep), n) : rep """ expanddims_rep(tns, dims) -Expand the representation of `tns` to the dimensions `dims`, which must have length(ndims(tns)) and be in ascending order. +Expand the representation of `tns` by inserting singleton dimensions `dims`. """ -expanddims_rep(tns, dims) = expanddims_rep_def(tns, reverse(dims)...) -expanddims_rep_def(tns) = tns -expanddims_rep_def(tns::HollowData, dims...) = HollowData(expanddims_rep_def(tns.lvl, dims...)) -expanddims_rep_def(tns::SparseData, dim, dims...) = SparseData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) -expanddims_rep_def(tns::RepeatData, dim, dims...) 
= RepeatData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) -expanddims_rep_def(tns::DenseData, dim, dims...) = DenseData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) -expanddims_rep_def(tns::ExtrudeData, dim, dims...) = ExtrudeData(paddims_rep(expanddims_rep_def(tns.lvl, dims...), dim - 1)) +function expanddims_rep(tns, dims) + @assert allunique(dims) + @assert issubset(dims,1:ndims(tns) + length(dims)) + expanddims_rep_def(tns, ndims(tns) + length(dims), dims) +end +expanddims_rep_def(tns::HollowData, dim, dims) = HollowData(expanddims_rep_def(tns.lvl, dim, dims)) +expanddims_rep_def(tns::ElementData, dim, dims) = + dim in dims ? ExtrudeData(expanddims_rep_def(tns, dim-1, dims)) : tns +expanddims_rep_def(tns::SparseData, dim, dims) = + dim in dims ? ExtrudeData(expanddims_rep_def(tns, dim-1, dims)) : SparseData(expanddims_rep_def(tns.lvl, dim-1, dims)) +expanddims_rep_def(tns::RepeatData, dim, dims) = + dim in dims ? ExtrudeData(expanddims_rep_def(tns, dim-1, dims)) : RepeatData(expanddims_rep_def(tns.lvl, dim-1, dims)) +expanddims_rep_def(tns::DenseData, dim, dims) = + dim in dims ? ExtrudeData(expanddims_rep_def(tns, dim-1, dims)) : DenseData(expanddims_rep_def(tns.lvl, dim-1, dims)) +expanddims_rep_def(tns::ExtrudeData, dim, dims) = + dim in dims ? ExtrudeData(expanddims_rep_def(tns, dim-1, dims)) : ExtrudeData(expanddims_rep_def(tns.lvl, dim-1, dims)) struct MapRepExtrudeStyle end struct MapRepSparseStyle end @@ -309,12 +318,12 @@ function permutedims_rep(tns, perm) end n += 1 end - src = expanddims_rep(tns, src_dims) + src = expanddims_rep(tns, setdiff(1:maximum(src_dims, init=0), src_dims)) for mask_dims in diags - mask = expanddims_rep(DenseData(SparseData(ElementData(false, Bool))), mask_dims) - src = map_rep(filterop(default(src)), paddims_rep(mask, ndims(src)), src) + mask = expanddims_rep(DenseData(SparseData(ElementData(false, Bool))), setdiff(1:ndims(src), mask_dims)) + src = map_rep(filterop(default(src)), mask, src) end - aggregate_rep(initwrite(default(tns)), default(tns), src, setdiff(src_dims, dst_dims)) + res = aggregate_rep(initwrite(default(tns)), default(tns), src, setdiff(src_dims, dst_dims)) end """ diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index 28d34c5df..8aa466c1d 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -344,13 +344,12 @@ function (ctx::SuitableRep)(ex) rep = ctx(ex.arg) idxs = getfields(ex.arg) #first reduce dropped dimensions - rep = aggregate_rep(initwrite(default(rep)), default(rep), rep, setdiff(idxs, ex.idxs)) + rep = aggregate_rep(initwrite(default(rep)), default(rep), rep, findall(idx -> idx in setdiff(idxs, ex.idxs), idxs)) #then permute remaining dimensions to match perm = sortperm(intersect(idxs, ex.idxs), by=idx->findfirst(isequal(idx), ex.idxs)) rep = permutedims_rep(rep, perm) - dims = findall(idx -> idx in idxs, ex.idxs) #then add new dimensions - return paddims_rep(expanddims_rep(rep, dims), length(ex.idxs)) + return expanddims_rep(rep, findall(idx -> !(idx in idxs), ex.idxs)) elseif ex.kind === relabel return ctx(ex.arg) elseif ex.kind === reformat From ebc6a8afe990867624f4e2332fb23a907d448504 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 12:52:51 -0400 Subject: [PATCH 08/10] idk --- src/interface/lazy.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index e9a9d3945..cb57ee8fa 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -464,14 +464,7 @@ 
get_scheduler() = current_scheduler[] """ with_scheduler(f, scheduler) -Execute `f` with the current scheduler set to `scheduler`. For example, -```jldoctest -with_scheduler(LogicExecutor(DefaultLogicOptimizer(LogicCompiler()), verbose=true)) do - x = lazy([1, 2]) - y = lazy([3, 4]) - compute(x + y) -end -``` +Execute `f` with the current scheduler set to `scheduler`. """ function with_scheduler(f, scheduler) old_scheduler = get_scheduler() From 8fed3d9372c739ba653d852674a4c7dc9b227234 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 12:58:49 -0400 Subject: [PATCH 09/10] idk --- src/interface/einsum.jl | 4 ++-- test/test_interface.jl | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/interface/einsum.jl b/src/interface/einsum.jl index 54aa29fa9..12a61c442 100644 --- a/src/interface/einsum.jl +++ b/src/interface/einsum.jl @@ -125,10 +125,10 @@ function (ctx::EinsumParserVisitor)(ex) $(esc(tns)) = $einsum($(esc(op)), $arg, $(map(QuoteNode, idxs)...);$(map(esc, ctx.opts)...),) end else - throw(FinchSyntaxError("Invalid einsum expression: $ex")) + throw(FinchNotation.FinchSyntaxError("Invalid einsum expression: $ex")) end else - throw(FinchSyntaxError("Invalid einsum expression type: $ex")) + throw(FinchNotation.FinchSyntaxError("Invalid einsum expression type: $ex")) end end diff --git a/test/test_interface.jl b/test/test_interface.jl index f1113f108..d112f95e7 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -8,6 +8,13 @@ using Finch: AsArray Finch.with_scheduler(scheduler) do @info "Testing $scheduler" + #https://github.com/willow-ahrens/Finch.jl/issues/520 + let + A = rand(2, 2) + x = lazy(rand(2)) + @test @einsum y[i] = A[i, j] * x[j] + end + #https://github.com/willow-ahrens/Finch.jl/issues/554 let @test broadcast(trunc, swizzle(Tensor(ones(1)), 1)) == Tensor(ones(1)) From e6a175b7365d6c1a4ca9aceef495669b4ce460d8 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Mon, 13 May 2024 13:01:08 -0400 Subject: [PATCH 10/10] oops --- test/test_interface.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_interface.jl b/test/test_interface.jl index d112f95e7..c4659716a 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -11,8 +11,10 @@ using Finch: AsArray #https://github.com/willow-ahrens/Finch.jl/issues/520 let A = rand(2, 2) - x = lazy(rand(2)) - @test @einsum y[i] = A[i, j] * x[j] + x = rand(2) + lx = lazy(x) + y = compute(@einsum y[i] += A[i, j] * lx[j]) + @test y == A * x end #https://github.com/willow-ahrens/Finch.jl/issues/554