Limit the scope of Looplet/Unfurl closures #608

Closed · 7 tasks done · willow-ahrens opened this issue Jun 18, 2024 · 2 comments · Fixed by #629

willow-ahrens (Collaborator) commented Jun 18, 2024

In order to support parallelism on more exciting architectures, we must redesign the unfurl function to avoid carrying closed-over variables from one dimension to another, so that we can re-label those symbols. As an example of the problem, consider the following Gustavson's matmul:

function spgemm_finch_gustavson_kernel_parallel(A, B)
    # @assert Threads.nthreads() >= 2
    z = default(A) * default(B) + false
    C = Tensor(Dense(Separate(SparseList(Element(z)))))
    w = Tensor(Dense(Element(z)))
    @finch_code begin # print the generated kernel for inspection
        C .= 0
        for j=parallel(_)
            w .= 0
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            for i=_; C[i, j] = w[i] end
        end
    end
    @finch begin # execute the same kernel
        C .= 0
        for j=parallel(_)
            w .= 0
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            for i=_; C[i, j] = w[i] end
        end
    end
    return C
end

This produces:

quote
    C_lvl = ((ex.bodies[1]).bodies[1]).tns.bind.lvl
    C_lvl_2 = C_lvl.lvl
    C_lvl_3 = C_lvl_2.lvl
    C_lvl_2_val = C_lvl_2.lvl.val
    w_lvl = (((ex.bodies[1]).bodies[2]).body.bodies[1]).tns.bind.lvl
    w_lvl_val = w_lvl.lvl.val
    A_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[1]).tns.bind.lvl
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_val = A_lvl_2.lvl.val
    B_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[2]).tns.bind.lvl
    B_lvl_2 = B_lvl.lvl
    B_lvl_2_val = B_lvl_2.lvl.val
    B_lvl_2.shape == A_lvl.shape || throw(DimensionMismatch("mismatched dimension limits ($(B_lvl_2.shape) != $(A_lvl.shape))"))
    result = nothing
    pos_stop = A_lvl_2.shape * B_lvl.shape
    Finch.resize_if_smaller!(C_lvl_2_val, pos_stop)
    Finch.fill_range!(C_lvl_2_val, 0.0, 1, pos_stop)
    B_lvl_2_val = (Finch).moveto(B_lvl_2_val, CPU(Threads.nthreads()))
    A_lvl_2_val = (Finch).moveto(A_lvl_2_val, CPU(Threads.nthreads()))
    val_3 = C_lvl_2_val
    C_lvl_2_val = (Finch).moveto(C_lvl_2_val, CPU(Threads.nthreads()))
    Threads.@threads for i_7 = 1:Threads.nthreads()
            val_4 = w_lvl_val
            w_lvl_val = (Finch).moveto(w_lvl_val, CPUThread(i_7, CPU(Threads.nthreads()), Serial()))
            phase_start_2 = max(1, 1 + fld(B_lvl.shape * (i_7 + -1), Threads.nthreads()))
            phase_stop_2 = min(B_lvl.shape, fld(B_lvl.shape * i_7, Threads.nthreads()))
            if phase_stop_2 >= phase_start_2
                for j_6 = phase_start_2:phase_stop_2
                    B_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    C_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    Finch.resize_if_smaller!(w_lvl_val, A_lvl_2.shape)
                    Finch.fill_range!(w_lvl_val, 0.0, 1, A_lvl_2.shape)
                    for k_4 = 1:B_lvl_2.shape
                        A_lvl_q = (1 - 1) * A_lvl.shape + k_4
                        B_lvl_2_q = (B_lvl_q - 1) * B_lvl_2.shape + k_4
                        B_lvl_3_val = B_lvl_2_val[B_lvl_2_q]
                        for i_8 = 1:A_lvl_2.shape
                            w_lvl_q = (1 - 1) * A_lvl_2.shape + i_8
                            A_lvl_2_q = (A_lvl_q - 1) * A_lvl_2.shape + i_8
                            A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
                            w_lvl_val[w_lvl_q] += B_lvl_3_val * A_lvl_3_val
                        end
                    end
                    resize!(w_lvl_val, A_lvl_2.shape)
                    for i_9 = 1:A_lvl_2.shape
                        C_lvl_2_q = (C_lvl_q - 1) * A_lvl_2.shape + i_9
                        w_lvl_q_2 = (1 - 1) * A_lvl_2.shape + i_9
                        w_lvl_2_val = w_lvl_val[w_lvl_q_2]
                        C_lvl_2_val[C_lvl_2_q] = w_lvl_2_val
                    end
                end
            end
            w_lvl_val = val_4
        end
    resize!(val_3, A_lvl_2.shape * B_lvl.shape)
    result = (C = Tensor((DenseLevel){Int64}((DenseLevel){Int64}(C_lvl_3, A_lvl_2.shape), B_lvl.shape)),)
    result
end

We try to moveto the workspace into task-local memory, but we end up rebinding the global variable (in effect, w_lvl_val = copy(w_lvl_val)). In general, some looplet structures currently capture state from previous outer loops, and that state is invisible to moveto because it is never accounted for explicitly. The solution is to represent the state of all subfibers (including COO and Masks) as explicit structs with named fields that we can call moveto on. Each call to unfurl would then occur in the scope of the loop being unrolled, rather than at the top-level call to instantiate. More generally, it would benefit the project to distinguish between instantiate and unfurl at every loop nest, so that we can insert side effects that expand eagerly as soon as a tensor is unfurled, as well as side effects that wait until the last possible moment to expand with instantiate. This might clean up some of the lowering for Scalars as well.
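
As a minimal sketch of the proposed direction (the struct and its fields below are hypothetical, not Finch's actual internals; only the moveto-on-a-buffer call mirrors the generated code above), each piece of subfiber state becomes a concrete struct, and moveto rebuilds it for the target device instead of rebinding a captured variable:

# Hypothetical sketch: subfiber state as an explicit struct.
struct WorkspaceState{Tv}
    val::Vector{Tv}   # the buffer that must be relocated per task
end

# moveto returns a new state object bound to the device, rather than
# mutating or shadowing a variable captured in a closure.
function moveto(state::WorkspaceState, device)
    return WorkspaceState(moveto(state.val, device))
end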

This may involve a fair amount of code change, but it would be for the better.
For example, the state carried by a COO level would be encapsulated in a more
explicit COOSubLevel struct that reflects the current COO index search variables.
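
A sketch of what that might look like (the name COOSubLevel comes from the paragraph above; the fields are illustrative assumptions about what COO search state would need to be captured):

# Hypothetical: the search state a COO level carries between loop nests,
# made explicit so that moveto can see and relabel every field.
struct COOSubLevel{Ti, Lvl}
    lvl::Lvl        # the underlying COO level being traversed
    qos::Ti         # current position in the coordinate list
    qos_stop::Ti    # end of the position range for this subfiber
end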

In the end, this would enable moveto to "re-virtualize" tensors upon entering
parallel regions, which is critical for local variables.
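
Concretely, entering a parallel region could then re-virtualize each state object per task, reusing the device types that already appear in the generated code above (w_state and its moveto method are hypothetical):

Threads.@threads for tid = 1:Threads.nthreads()
    # Each task receives its own relabeled copy of the workspace state,
    # instead of rebinding a variable shared across all tasks.
    w_local = moveto(w_state, CPUThread(tid, CPU(Threads.nthreads()), Serial()))
    # ... loop body lowered against w_local ...
end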

Tasks:

  • SparseArray State Objects
  • AbstractArray State Objects
  • Masks State Objects
  • COO State Objects
  • Redesign Unfurl to take one protocol at a time
  • Update definitions of unfurl and instantiate
  • Delete Furlable
wraith1995 (Collaborator) commented Jun 18, 2024

@willow-ahrens I want to add that I think this issue is also potentially related to the "closure" issue that we've had in generating parallel code. In particular, part of the problem there is that we don't know much about the local variables once something has been lowered, so we lose track and need to recover what needs to be passed into a closure/function for running parallel code. If everything had structs describing its local state with types, then at lowering time we could figure out the fields currently in use and just pass those in.
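
For example (a sketch under the assumption that state is held in plain typed structs; the helper name is hypothetical), the lowering pass could enumerate the fields a task body needs and pass exactly those into the task function:

# Hypothetical helper: list the (name, type) pairs a typed state struct
# exposes, so parallel lowering can pass exactly those into the closure.
closure_fields(state) =
    [(f, fieldtype(typeof(state), f)) for f in fieldnames(typeof(state))]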

willow-ahrens (Collaborator, Author) commented Jun 18, 2024

Right, although the closure issue is also resolved by https://github.com/willow-ahrens/Finch.jl/blob/a7f5b899d49051167aebd7568a65c2840ea6db69/src/util/shims.jl#L59-L208. We can certainly copy the logic from there into the compiler, so that, as long as we rename state variables correctly, we can generate explicit closures at the moment of parallel lowering based on the variables the code ends up using (i.e. lower it and see which variables it references). You are correct, though, that this change would enable the analysis you describe: it would become possible to write a function that lists all variables reachable from a block of Finch code, without lowering it.
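
A minimal sketch of that kind of analysis (a hypothetical helper over ordinary Expr trees, not the shims.jl implementation):

# Walk a lowered expression and collect symbols that are read before
# being bound, i.e. the variables an explicit closure must capture.
function free_variables(ex, bound = Set{Symbol}(), free = Set{Symbol}())
    if ex isa Symbol
        ex in bound || push!(free, ex)
    elseif ex isa Expr
        if ex.head === :(=) && ex.args[1] isa Symbol
            free_variables(ex.args[2], bound, free)
            push!(bound, ex.args[1])
        else
            foreach(arg -> free_variables(arg, bound, free), ex.args)
        end
    end
    return free
end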
