diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index ba2f7e49ac98..7e95000f1ee2 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -67,6 +67,25 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int for (int i : real_axis) { reduce_extent *= data_fp32->shape[i]; } + auto rsqrt_func = [&](const Array<Var>& indices) { + Array<Var> non_reduce_indices; + for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { + non_reduce_indices.push_back(indices[i]); + } + } + auto output = + tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon)); + return output; + }; + auto rsqrt_shape = Array<PrimExpr>(); + for (int i = 0, n = static_cast<int>(data_fp32->shape.size()); i < n; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { + rsqrt_shape.push_back(data_fp32->shape[i]); + } + } + auto rsqrt = tvm::te::compute(rsqrt_shape, rsqrt_func, "rsqrt", tag); + auto rms_norm_func = [&](const Array<Var>& indices) { Array<Var> reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) { @@ -76,12 +95,11 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int non_reduce_indices.push_back(indices[i]); } } - auto output = - data_fp32(indices) * weight_fp32(reduce_indices) * - tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon)); + auto output = rsqrt(non_reduce_indices) * data_fp32(indices) * weight_fp32(reduce_indices); return output; }; auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag); + return cast(rms_norm, data_type); } diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py index f48bdb2c8182..7db383a161cd 100644 --- a/python/tvm/dlight/gpu/__init__.py +++ b/python/tvm/dlight/gpu/__init__.py @@ -24,3 +24,4 @@ from .reduction import Reduction from .transpose import Transpose from .general_reduction import GeneralReduction +from .rmsnorm import RMSNorm diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py new file mode 100644 index 000000000000..f8b2bb4a172d --- /dev/null +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""A RMS norm schedule rule for GPU operators.""" + +import tvm +from tvm import tir +from tvm.tir import Block, BufferStore +from tvm.tir.expr import Cast, BufferLoad, Call +from tvm.target import Target + +from ..base import ScheduleRule + + +def identify_cast_or_load_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + # check types + if isinstance(store.value, BufferLoad): + load = store.value + elif isinstance(store.value, Cast): + load = store.value.value + if not isinstance(load, BufferLoad): + return False + else: + return False + + # check indices + if len(load.indices) != len(store.indices): + return False + + for lhs, rhs in zip(load.indices, store.indices): + if not lhs.same_as(rhs): + return False + + return True + + +def identify_rsqrt_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + if not isinstance(store.value, Call): + return False + call = store.value + op = call.op + + return op == tvm.ir.op.Op.get("tir.rsqrt") + + +class RMSNorm(ScheduleRule): + """A rule for RMS norm.""" + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if target.kind.name == "cuda": + num_tx = 512 + else: + num_tx = 64 + + sch = tir.Schedule(func) + root = sch.get_block(name="root", func_name="main") + + blocks = sch.get_child_blocks(root) + + if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]): + return None + + read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local") + write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local") + + for block in blocks: + if identify_cast_or_load_block(sch.get(block)): + sch.compute_inline(block) + + blocks = sch.get_child_blocks(root) + + read, sqr, redsum, rsqrt, norm, write = blocks + + if not identify_rsqrt_block(sch.get(rsqrt)): + return None + + for name in [read, sqr, redsum, rsqrt, norm, write]: + loops = sch.get_loops(name) + sch.fuse(*loops[:-1]) + + block_loop, loops = sch.get_loops(block=read) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(block_loop, thread_axis="blockIdx.x") + sch.bind(thread_loop, thread_axis="threadIdx.x") + sch.vectorize(sch.get_loops(block=read)[-1]) + sch.reverse_compute_at(block=sqr, loop=thread_loop) + sch.reverse_compute_at(block=redsum, loop=thread_loop) + + sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) + sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) + block_loop, loops = sch.get_loops(block=norm) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(thread_loop, thread_axis="threadIdx.x") + + sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) + sch.vectorize(sch.get_loops(block=write)[-1]) + + sch.set_scope(block=sqr, buffer_index=0, storage_scope="local") + sch.set_scope(block=redsum, buffer_index=0, storage_scope="local") + sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared") + sch.set_scope(block=norm, buffer_index=0, storage_scope="local") + + return sch diff --git a/tests/python/dlight/test_gpu_rmsnorm.py b/tests/python/dlight/test_gpu_rmsnorm.py new file mode 100644 index 000000000000..301dac5c66ac --- /dev/null +++ b/tests/python/dlight/test_gpu_rmsnorm.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import tvm.testing + +from tvm.ir import IRModule, assert_structural_equal +from tvm import dlight as dl +from tvm.script import ir as I +from tvm.target import Target +from tvm.script import tir as T + + +def _check(mod_before: IRModule, mod_after: IRModule): + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.RMSNorm(), + )(mod_before) + assert_structural_equal(mod, mod_after) + + +def test_rms_norm_with_casting(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096), "float16") + T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") + # with T.block("root"): + T_cast_1 = T.alloc_buffer((1, n, 4096)) + T_multiply = T.alloc_buffer((1, n, 4096)) + T_multiply_red = T.alloc_buffer((1, n)) + rsqrt = T.alloc_buffer((1, n)) + T_cast_2 = T.alloc_buffer((4096,)) + T_rms_norm = T.alloc_buffer((1, n, 4096)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(data[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) + T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data[v_ax0, v_ax1, v_ax2]) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_cast_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, k2 in T.grid(1, n, 4096): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] + for ax0, ax1 in T.grid(1, n): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0 in range(4096): + with T.block("T_cast_1"): + v_ax0 = T.axis.spatial(4096, ax0) + T.reads(weight[v_ax0]) + T.writes(T_cast_2[v_ax0]) + T_cast_2[v_ax0] = T.Cast("float32", weight[v_ax0]) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2]) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2] + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2]) + T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2]) + + @I.ir_module + class After: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096), "float16") + T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") + # with T.block("root"): + T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") + T_multiply_red_local = T.alloc_buffer((1, n), scope="local") + rsqrt_shared = T.alloc_buffer((1, n), scope="shared") + T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") + data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local") + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): + with T.block("data_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) + for ax0 in range(8): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) + with T.init(): + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] + with T.block("rsqrt"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax0_1, ax0_2 in T.grid(1, 8): + with T.block("T_rms_norm"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2]) + for ax0 in T.vectorized(8): + with T.block("T_cast_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) + T.reads(T_rms_norm_local[v0, v1, v2]) + T.writes(T_cast[v0, v1, v2]) + T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2]) + # fmt: on + _check(Before, After) + + +def test_rms_norm_without_casting(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096)) + T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) + # with T.block("root"): + T_multiply = T.alloc_buffer((1, n, 4096)) + T_multiply_red = T.alloc_buffer((1, n)) + rsqrt = T.alloc_buffer((1, n)) + T_rms_norm = T.alloc_buffer((1, n, 4096)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(data[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1, v_ax2] * data[v_ax0, v_ax1, v_ax2] + for ax0, ax1, k2 in T.grid(1, n, 4096): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] + for ax0, ax1 in T.grid(1, n): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rsqrt[v_ax0, v_ax1], data[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * data[v_ax0, v_ax1, v_ax2] * weight[v_ax2] + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2]) + T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1, v_ax2] + + @I.ir_module + class After: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096)) + T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) + # with T.block("root"): + T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") + T_multiply_red_local = T.alloc_buffer((1, n), scope="local") + rsqrt_shared = T.alloc_buffer((1, n), scope="shared") + T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") + data_local = T.alloc_buffer((1, n, 4096), scope="local") + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): + with T.block("data_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2] + for ax0 in range(8): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) + with T.init(): + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] + with T.block("rsqrt"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax0_1, ax0_2 in T.grid(1, 8): + with T.block("T_rms_norm"): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2] + for ax0 in T.vectorized(8): + with T.block("T_cast_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) + T.reads(T_rms_norm_local[v0, v1, v2]) + T.writes(T_cast[v0, v1, v2]) + T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2] + # fmt: on + _check(Before, After) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 74da77f7d8c5..07fbc3419b98 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -2773,9 +2773,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -2783,12 +2784,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2803,12 +2798,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2842,9 +2849,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -2852,12 +2860,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3]) - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2872,12 +2874,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2918,9 +2932,10 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T_cast = T.match_buffer(var_T_cast, (n, s, f)) # with T.block("root"): T_cast_1 = T.alloc_buffer((n, s, f)) - T_cast_2 = T.alloc_buffer((s, f)) T_multiply = T.alloc_buffer((n, s, f)) T_multiply_red = T.alloc_buffer((n,)) + rsqrt = T.alloc_buffer((n,)) + T_cast_2 = T.alloc_buffer((s, f)) T_rms_norm = T.alloc_buffer((n, s, f)) for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_cast"): @@ -2928,12 +2943,6 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T.reads(A[v_ax0, v_ax1, v_ax2]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] - for ax0, ax1 in T.grid(s, f): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) @@ -2948,12 +2957,24 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): with T.init(): T_multiply_red[v_ax0] = T.float32(0) T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2] + for ax0 in range(n): + with T.block("rsqrt"): + v_ax0 = T.axis.spatial(n, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(rsqrt[v_ax0]) + rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(s, f): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2], T_multiply_red[v_ax0]) + T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) - T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) @@ -2990,9 +3011,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -3000,12 +3022,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -3020,12 +3036,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])