Skip to content

Commit

Permalink
[MetaSchedule] postproc: rewrite_cooperative_fetch
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
7 people committed Jan 27, 2022
1 parent a40de47 commit d90b791
Show file tree
Hide file tree
Showing 4 changed files with 348 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .disallow_dynamic_loop import DisallowDynamicLoop
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_unbound_block import RewriteUnboundBlock
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites the cooperative fetch annotation to actual
vectorized cooperative fetching in loop bindings."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteCooperativeFetch")
class RewriteCooperativeFetch(Postproc):
"""A postprocessor that rewrites the cooperative fetch annotation to actual vectorized
cooperative fetching in loop bindings.
"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member
)
158 changes: 158 additions & 0 deletions src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Parse instruction: sch.bind(..., axis)
* \param sch The schedule
* \param inst The instruction to be parsed
* \param axis The axis name expected
* \return NullOpt if parsing fails; Otherwise, the extent of thread axis
*/
Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) {
static InstructionKind inst_kind_bind = InstructionKind::Get("Bind");
if (!inst->kind.same_as(inst_kind_bind)) {
return NullOpt;
}
ICHECK_EQ(inst->inputs.size(), 1);
ICHECK_EQ(inst->attrs.size(), 1);
String thread_axis = Downcast<String>(inst->attrs[0]);
if (thread_axis != axis) {
return NullOpt;
}
return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
}

/*!
* \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch)
* \param sch The schedule
* \param inst The instruction to be parsed
* \param vector_lane The length of vector lane in vectorized cooperative fetching
* \return NullOpt if parsing fails; Otherwise, the annotated block
*/
Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) {
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_kind_annotate)) {
return NullOpt;
}
ICHECK_EQ(inst->inputs.size(), 2);
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
if (ann_key != attr::meta_schedule_cooperative_fetch) {
return NullOpt;
}
*vector_lane = Downcast<Integer>(sch->Get(Downcast<ExprRV>(inst->inputs[1])))->value;
return Downcast<BlockRV>(inst->inputs[0]);
}

} // namespace tir

namespace meta_schedule {

/*!
* \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching
* in loop bindings.
*/
class RewriteCooperativeFetchNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;

void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);
};

bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
tir::Trace trace = sch->trace().value();
int thread_extent_x = -1;
int thread_extent_y = -1;
int vector_lane = -1;
std::vector<std::function<void()>> tasks;
for (const tir::Instruction& inst : trace->insts) {
if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
thread_extent_x = new_thread_extent.value()->value;
} else if (Optional<Integer> new_thread_extent =
tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
thread_extent_y = new_thread_extent.value()->value;
} else if (Optional<tir::BlockRV> block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) {
ICHECK_NE(thread_extent_x, -1);
if (vector_lane > 1) {
tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch,
block = block_rv.value()]() -> void {
tir::LoopRV fused = sch->GetLoops(block).back();
if (thread_extent_y == -1) {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_x), //
Integer(vector_lane)});
sch->Vectorize(split[2]);
sch->Bind(split[1], "threadIdx.x");
} else {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_y), //
Integer(thread_extent_x), //
Integer(vector_lane)});
sch->Vectorize(split[3]);
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
sch->StorageAlign(block, 0, -2, 32, 8);
}
});
} else {
tasks.push_back(
[thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void {
tir::LoopRV fused = sch->GetLoops(block).back();
if (thread_extent_y == -1) {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
sch->Bind(split[1], "threadIdx.x");
} else {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_y), //
Integer(thread_extent_x)});
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
sch->StorageAlign(block, 0, -2, 32, 8);
}
});
}
}
}
for (auto&& task : tasks) {
task();
}
return true;
}

Postproc Postproc::RewriteCooperativeFetch() {
ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>();
return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch")
.set_body_typed(Postproc::RewriteCooperativeFetch);

} // namespace meta_schedule
} // namespace tvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import tvm
from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import RewriteCooperativeFetch
from tvm.meta_schedule.testing import te_workload
from tvm.script import tir as T
from tvm.target import Target
from tvm.te import create_prim_func


def _target() -> Target:
return Target("cuda", host="llvm")


def _create_context(mod, target) -> TuneContext:
ctx = TuneContext(
mod=mod,
target=target,
postprocs=[
RewriteCooperativeFetch(),
],
task_name="test",
)
for rule in ctx.postprocs:
rule.initialize_with_tune_context(ctx)
return ctx


# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks

@tvm.script.ir_module
class AfterRewrite0:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(var_A, [512, 512], dtype="float32")
B = T.match_buffer(var_B, [512, 512], dtype="float32")
C = T.match_buffer(var_C, [512, 512], dtype="float32")
# body
# with T.block("root")
C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
for i2_0 in T.serial(0, 1):
for ax0_ax1_fused_0 in T.serial(0, 32768):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
with T.block("A_shared"):
v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
T.reads([A[v0, v1]])
T.writes([A_shared[v0, v1]])
T.block_attr({"meta_schedule.cooperative_fetch":1})
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0 in T.serial(0, 1024):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
for ax0_ax1_fused_2 in T.vectorized(0, 2):
with T.block("B_shared"):
v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
T.reads([B[v0, v1]])
T.writes([B_shared[v0, v1]])
T.block_attr({"meta_schedule.cooperative_fetch":2})
B_shared[v0, v1] = B[v0, v1]
for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
with T.block("C"):
i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
k = T.axis.reduce(512, i2_1 * 32 + i2_2)
T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]])
T.writes([C_local[i, j]])
with T.init():
C_local[i, j] = T.float32(0)
C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
for ax0, ax1 in T.grid(32, 4):
with T.block("C_local"):
v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
T.reads([C_local[v0, v1]])
T.writes([C[v0, v1]])
C[v0, v1] = C_local[v0, v1]


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


def test_rewrite_cooperative_fetch():
mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
target = _target()
ctx = _create_context(mod, target)

sch = tir.Schedule(mod, debug_mask="all")
# fmt: off
# pylint: disable=line-too-long,invalid-name
b0 = sch.get_block(name="C", func_name="main")
b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
l2, l3, l4 = sch.get_loops(block=b0)
v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
l31 = sch.fuse(l10, l20)
sch.bind(loop=l31, thread_axis="blockIdx.x")
l32 = sch.fuse(l11, l21)
sch.bind(loop=l32, thread_axis="vthread.x")
l33 = sch.fuse(l12, l22)
sch.bind(loop=l33, thread_axis="threadIdx.x")
b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
_, _, _, _, l39, l40 = sch.get_loops(block=b34)
l41 = sch.fuse(l39, l40)
_, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")
sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
_, _, _, _, l49, l50 = sch.get_loops(block=b44)
l51 = sch.fuse(l49, l50)
_, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
# pylint: enable=line-too-long,invalid-name
# fmt: on
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)


if __name__ == "__main__":
test_rewrite_cooperative_fetch()

0 comments on commit d90b791

Please sign in to comment.