From 538347e49f89d345b48d68382a15fa5ffca698b2 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 30 Jan 2022 05:59:09 +0800 Subject: [PATCH] [MetaSchedule] postproc: rewrite_cooperative_fetch (#10081) Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- python/tvm/meta_schedule/postproc/__init__.py | 1 + .../postproc/rewrite_cooperative_fetch.py | 34 ++++ .../postproc/rewrite_cooperative_fetch.cc | 156 ++++++++++++++++++ ...dule_postproc_rewrite_cooperative_fetch.py | 155 +++++++++++++++++ 4 files changed, 346 insertions(+) create mode 100644 python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py create mode 100644 src/meta_schedule/postproc/rewrite_cooperative_fetch.cc create mode 100644 tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 0c914ac809f9..96361e739186 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -17,6 +17,7 @@ """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc from .disallow_dynamic_loop import DisallowDynamicLoop +from .rewrite_cooperative_fetch import RewriteCooperativeFetch from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py new file mode 100644 index 000000000000..e2d7c2212382 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites the cooperative fetch annotation to actual +vectorized cooperative fetching in loop bindings.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteCooperativeFetch") +class RewriteCooperativeFetch(Postproc): + """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized + cooperative fetching in loop bindings. + """ + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc new file mode 100644 index 000000000000..ad8ee9854265 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Parse instruction: sch.bind(..., axis) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param axis The axis name expected + * \return NullOpt if parsing fails; Otherwise, the extent of thread axis + */ +Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { + static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); + if (!inst->kind.same_as(inst_kind_bind)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 1); + ICHECK_EQ(inst->attrs.size(), 1); + String thread_axis = Downcast(inst->attrs[0]); + if (thread_axis != axis) { + return NullOpt; + } + return Downcast(sch->Get(Downcast(inst->inputs[0]))->extent); +} + +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param vector_lane The number of vector lane in vectorized cooperative fetching + * \return NullOpt if parsing fails; Otherwise, the annotated block + */ +Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_cooperative_fetch) { + return NullOpt; + } + *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; + return Downcast(inst->inputs[0]); +} + +} // namespace tir + +namespace meta_schedule { + +/*! + * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching + * in loop bindings. + */ +class RewriteCooperativeFetchNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); +}; + +bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { + tir::Trace trace = sch->trace().value(); + int thread_extent_x = -1; + int thread_extent_y = -1; + int vector_lane = -1; + std::vector> tasks; + for (const tir::Instruction& inst : trace->insts) { + if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { + thread_extent_x = new_thread_extent.value()->value; + } else if (Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + thread_extent_y = new_thread_extent.value()->value; + } else if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { + ICHECK_NE(thread_extent_x, -1); + if (vector_lane > 1) { + tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch, + block = block_rv.value()]() -> void { + tir::LoopRV fused = sch->GetLoops(block).back(); + if (thread_extent_y == -1) { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[2]); + sch->Bind(split[1], "threadIdx.x"); + } else { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[3]); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + } + }); + } else { + tasks.push_back( + [thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void { + tir::LoopRV fused = sch->GetLoops(block).back(); + if (thread_extent_y == -1) { + Array split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)}); + sch->Bind(split[1], "threadIdx.x"); + } else { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + } + }); + } + } + } + for (auto&& task : tasks) { + task(); + } + return true; +} + +Postproc Postproc::RewriteCooperativeFetch() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") + .set_body_typed(Postproc::RewriteCooperativeFetch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py new file mode 100644 index 000000000000..31e92e09e50e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteCooperativeFetch +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteCooperativeFetch(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class AfterRewrite0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + # with T.block("root") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_cooperative_fetch(): + mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16]) + l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9]) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2]) + l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19]) + v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32]) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27]) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22) + sch.bind(loop=l33, thread_axis="threadIdx.x") + b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True) + _, _, _, _, l39, l40 = sch.get_loops(block=b34) + l41 = sch.fuse(l39, l40) + _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1]) + sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True) + _, _, _, _, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50) + _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2]) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53) + sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0) + + +if __name__ == "__main__": + test_rewrite_cooperative_fetch()