diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
index ba2f7e49ac98..7e95000f1ee2 100644
--- a/include/tvm/topi/nn/rms_norm.h
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -67,6 +67,25 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int
   for (int i : real_axis) {
     reduce_extent *= data_fp32->shape[i];
   }
+  auto rsqrt_func = [&](const Array<Var>& indices) {
+    Array<Var> non_reduce_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    auto output =
+        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    return output;
+  };
+  auto rsqrt_shape = Array<PrimExpr>();
+  for (int i = 0, n = static_cast<int>(data_fp32->shape.size()); i < n; ++i) {
+    if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
+      rsqrt_shape.push_back(data_fp32->shape[i]);
+    }
+  }
+  auto rsqrt = tvm::te::compute(rsqrt_shape, rsqrt_func, "rsqrt", tag);
+
   auto rms_norm_func = [&](const Array<Var>& indices) {
     Array<Var> reduce_indices, non_reduce_indices;
     for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
@@ -76,12 +95,11 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int
         non_reduce_indices.push_back(indices[i]);
       }
     }
-    auto output =
-        data_fp32(indices) * weight_fp32(reduce_indices) *
-        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    auto output = rsqrt(non_reduce_indices) * data_fp32(indices) * weight_fp32(reduce_indices);
     return output;
   };
   auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
+
   return cast(rms_norm, data_type);
 }
 
diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index f48bdb2c8182..7db383a161cd 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -24,3 +24,4 @@
 from .reduction import Reduction
 from .transpose import Transpose
 from .general_reduction import GeneralReduction
+from .rmsnorm import RMSNorm
diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py
new file mode 100644
index 000000000000..f8b2bb4a172d
--- /dev/null
+++ b/python/tvm/dlight/gpu/rmsnorm.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""An RMS norm schedule rule for GPU operators."""
+
+import tvm
+from tvm import tir
+from tvm.tir import Block, BufferStore
+from tvm.tir.expr import Cast, BufferLoad, Call
+from tvm.target import Target
+
+from ..base import ScheduleRule
+
+
+def identify_cast_or_load_block(block: Block) -> bool:
+    if len(block.reads) != 1 or len(block.writes) != 1:
+        return False
+
+    if not isinstance(block.body, BufferStore):
+        return False
+    store = block.body
+
+    # check types
+    if isinstance(store.value, BufferLoad):
+        load = store.value
+    elif isinstance(store.value, Cast):
+        load = store.value.value
+        if not isinstance(load, BufferLoad):
+            return False
+    else:
+        return False
+
+    # check indices
+    if len(load.indices) != len(store.indices):
+        return False
+
+    for lhs, rhs in zip(load.indices, store.indices):
+        if not lhs.same_as(rhs):
+            return False
+
+    return True
+
+
+def identify_rsqrt_block(block: Block) -> bool:
+    if len(block.reads) != 1 or len(block.writes) != 1:
+        return False
+
+    if not isinstance(block.body, BufferStore):
+        return False
+    store = block.body
+
+    if not isinstance(store.value, Call):
+        return False
+    call = store.value
+    op = call.op
+
+    return op == tvm.ir.op.Op.get("tir.rsqrt")
+
+
+class RMSNorm(ScheduleRule):
+    """A rule for RMS norm."""
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> tir.Schedule:
+        if target.kind.name == "cuda":
+            num_tx = 512
+        else:
+            num_tx = 64
+
+        sch = tir.Schedule(func)
+        root = sch.get_block(name="root", func_name="main")
+
+        blocks = sch.get_child_blocks(root)
+
+        if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]):
+            return None
+
+        read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local")
+        write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local")
+
+        for block in blocks:
+            if identify_cast_or_load_block(sch.get(block)):
+                sch.compute_inline(block)
+
+        blocks = sch.get_child_blocks(root)
+
+        read, sqr, redsum, rsqrt, norm, write = blocks
+
+        if not identify_rsqrt_block(sch.get(rsqrt)):
+            return None
+
+        for name in [read, sqr, redsum, rsqrt, norm, write]:
+            loops = sch.get_loops(name)
+            sch.fuse(*loops[:-1])
+
+        block_loop, loops = sch.get_loops(block=read)
+        thread_loop, _, _ = sch.split(
+            loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True
+        )
+        sch.bind(block_loop, thread_axis="blockIdx.x")
+        sch.bind(thread_loop, thread_axis="threadIdx.x")
+        sch.vectorize(sch.get_loops(block=read)[-1])
+        sch.reverse_compute_at(block=sqr, loop=thread_loop)
+        sch.reverse_compute_at(block=redsum, loop=thread_loop)
+
+        sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1)
+        sch.reverse_compute_at(block=norm, loop=block_loop, index=-1)
+        block_loop, loops = sch.get_loops(block=norm)
+        thread_loop, _, _ = sch.split(
+            loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True
+        )
+        sch.bind(thread_loop, thread_axis="threadIdx.x")
+
+        sch.reverse_compute_at(block=write, loop=thread_loop, index=-1)
+        sch.vectorize(sch.get_loops(block=write)[-1])
+
+        sch.set_scope(block=sqr, buffer_index=0, storage_scope="local")
+        sch.set_scope(block=redsum, buffer_index=0, storage_scope="local")
+        sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared")
+        sch.set_scope(block=norm, buffer_index=0, storage_scope="local")
+
+        return sch
diff --git a/tests/python/dlight/test_gpu_rmsnorm.py b/tests/python/dlight/test_gpu_rmsnorm.py
new file mode 100644
index 000000000000..301dac5c66ac
--- /dev/null
+++ b/tests/python/dlight/test_gpu_rmsnorm.py
@@ -0,0 +1,287 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm.testing
+
+from tvm.ir import IRModule, assert_structural_equal
+from tvm import dlight as dl
+from tvm.script import ir as I
+from tvm.target import Target
+from tvm.script import tir as T
+
+
+def _check(mod_before: IRModule, mod_after: IRModule):
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.RMSNorm(),
+        )(mod_before)
+    assert_structural_equal(mod, mod_after)
+
+
+def test_rms_norm_with_casting():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int32()
+            data = T.match_buffer(var_data, (1, n, 4096), "float16")
+            T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16")
+            # with T.block("root"):
+            T_cast_1 = T.alloc_buffer((1, n, 4096))
+            T_multiply = T.alloc_buffer((1, n, 4096))
+            T_multiply_red = T.alloc_buffer((1, n))
+            rsqrt = T.alloc_buffer((1, n))
+            T_cast_2 = T.alloc_buffer((4096,))
+            T_rms_norm = T.alloc_buffer((1, n, 4096))
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_cast"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(data[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
+                    T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data[v_ax0, v_ax1, v_ax2])
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
+            for ax0, ax1, k2 in T.grid(1, n, 4096):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
+            for ax0, ax1 in T.grid(1, n):
+                with T.block("rsqrt"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_red[v_ax0, v_ax1])
+                    T.writes(rsqrt[v_ax0, v_ax1])
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+            for ax0 in range(4096):
+                with T.block("T_cast_1"):
+                    v_ax0 = T.axis.spatial(4096, ax0)
+                    T.reads(weight[v_ax0])
+                    T.writes(T_cast_2[v_ax0])
+                    T_cast_2[v_ax0] = T.Cast("float32", weight[v_ax0])
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+                    T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int32()
+            data = T.match_buffer(var_data, (1, n, 4096), "float16")
+            T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16")
+            # with T.block("root"):
+            T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local")
+            T_multiply_red_local = T.alloc_buffer((1, n), scope="local")
+            rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
+            T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
+            data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local")
+            for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+                    for ax2_1 in range(1):
+                        for ax2_2 in T.vectorized(8):
+                            with T.block("data_local"):
+                                v0 = T.axis.spatial(1, 0)
+                                v1 = T.axis.spatial(n, ax0_ax1_fused)
+                                v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2)
+                                T.reads(data[v0, v1, v2])
+                                T.writes(data_local[v0, v1, v2])
+                                data_local[v0, v1, v2] = data[v0, v1, v2]
+                    for ax0 in range(8):
+                        with T.block("T_multiply"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+                            T.reads(data_local[v_ax0, v_ax1, v_ax2])
+                            T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+                            T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2])
+                    for ax0 in range(8):
+                        with T.block("T_multiply_red"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+                            T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+                            T.writes(T_multiply_red_local[v_ax0, v_ax1])
+                            with T.init():
+                                T_multiply_red_local[v_ax0, v_ax1] = T.float32(0)
+                            T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
+                with T.block("rsqrt"):
+                    v_ax0 = T.axis.spatial(1, 0)
+                    v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                    T.reads(T_multiply_red_local[v_ax0, v_ax1])
+                    T.writes(rsqrt_shared[v_ax0, v_ax1])
+                    rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+                for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
+                    for ax0_1, ax0_2 in T.grid(1, 8):
+                        with T.block("T_rms_norm"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
+                            T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+                            T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+                            T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2])
+                    for ax0 in T.vectorized(8):
+                        with T.block("T_cast_local"):
+                            v0 = T.axis.spatial(1, 0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
+                            T.reads(T_rms_norm_local[v0, v1, v2])
+                            T.writes(T_cast[v0, v1, v2])
+                            T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2])
+    # fmt: on
+    _check(Before, After)
+
+
+def test_rms_norm_without_casting():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int32()
+            data = T.match_buffer(var_data, (1, n, 4096))
+            T_cast = T.match_buffer(var_T_cast, (1, n, 4096))
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer((1, n, 4096))
+            T_multiply_red = T.alloc_buffer((1, n))
+            rsqrt = T.alloc_buffer((1, n))
+            T_rms_norm = T.alloc_buffer((1, n, 4096))
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(data[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1, v_ax2] * data[v_ax0, v_ax1, v_ax2]
+            for ax0, ax1, k2 in T.grid(1, n, 4096):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
+            for ax0, ax1 in T.grid(1, n):
+                with T.block("rsqrt"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_red[v_ax0, v_ax1])
+                    T.writes(rsqrt[v_ax0, v_ax1])
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rsqrt[v_ax0, v_ax1], data[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * data[v_ax0, v_ax1, v_ax2] * weight[v_ax2]
+            for ax0, ax1, ax2 in T.grid(1, n, 4096):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+                    T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1, v_ax2]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int32()
+            data = T.match_buffer(var_data, (1, n, 4096))
+            T_cast = T.match_buffer(var_T_cast, (1, n, 4096))
+            # with T.block("root"):
+            T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local")
+            T_multiply_red_local = T.alloc_buffer((1, n), scope="local")
+            rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
+            T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
+            data_local = T.alloc_buffer((1, n, 4096), scope="local")
+            for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+                    for ax2_1 in range(1):
+                        for ax2_2 in T.vectorized(8):
+                            with T.block("data_local"):
+                                v0 = T.axis.spatial(1, 0)
+                                v1 = T.axis.spatial(n, ax0_ax1_fused)
+                                v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2)
+                                T.reads(data[v0, v1, v2])
+                                T.writes(data_local[v0, v1, v2])
+                                data_local[v0, v1, v2] = data[v0, v1, v2]
+                    for ax0 in range(8):
+                        with T.block("T_multiply"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+                            T.reads(data_local[v_ax0, v_ax1, v_ax2])
+                            T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+                            T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2]
+                    for ax0 in range(8):
+                        with T.block("T_multiply_red"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+                            T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+                            T.writes(T_multiply_red_local[v_ax0, v_ax1])
+                            with T.init():
+                                T_multiply_red_local[v_ax0, v_ax1] = T.float32(0)
+                            T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
+                with T.block("rsqrt"):
+                    v_ax0 = T.axis.spatial(1, 0)
+                    v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                    T.reads(T_multiply_red_local[v_ax0, v_ax1])
+                    T.writes(rsqrt_shared[v_ax0, v_ax1])
+                    rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+                for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
+                    for ax0_1, ax0_2 in T.grid(1, 8):
+                        with T.block("T_rms_norm"):
+                            v_ax0 = T.axis.spatial(1, 0)
+                            v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2)
+                            T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+                            T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+                            T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2]
+                    for ax0 in T.vectorized(8):
+                        with T.block("T_cast_local"):
+                            v0 = T.axis.spatial(1, 0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused)
+                            v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
+                            T.reads(T_rms_norm_local[v0, v1, v2])
+                            T.writes(T_cast[v0, v1, v2])
+                            T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2]
+    # fmt: on
+    _check(Before, After)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 74da77f7d8c5..07fbc3419b98 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2773,9 +2773,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
             T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
-            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast"):
@@ -2783,12 +2784,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
-            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
-                with T.block("T_cast_1"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(B[v_ax0, v_ax1])
-                    T.writes(T_cast_2[v_ax0, v_ax1])
-                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_multiply"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -2803,12 +2798,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     with T.init():
                         T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                     T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("rsqrt"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_red[v_ax0, v_ax1])
+                    T.writes(rsqrt[v_ax0, v_ax1])
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
-                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast_2"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -2842,9 +2849,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
             T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
-            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast"):
@@ -2852,12 +2860,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3])
-            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
-                with T.block("T_cast_1"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(B[v_ax0, v_ax1])
-                    T.writes(T_cast_2[v_ax0, v_ax1])
-                    T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_multiply"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -2872,12 +2874,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     with T.init():
                         T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                     T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("rsqrt"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_red[v_ax0, v_ax1])
+                    T.writes(rsqrt[v_ax0, v_ax1])
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
-                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast_2"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -2918,9 +2932,10 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
             T_cast = T.match_buffer(var_T_cast, (n, s, f))
             # with T.block("root"):
             T_cast_1 = T.alloc_buffer((n, s, f))
-            T_cast_2 = T.alloc_buffer((s, f))
             T_multiply = T.alloc_buffer((n, s, f))
             T_multiply_red = T.alloc_buffer((n,))
+            rsqrt = T.alloc_buffer((n,))
+            T_cast_2 = T.alloc_buffer((s, f))
             T_rms_norm = T.alloc_buffer((n, s, f))
             for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.block("T_cast"):
@@ -2928,12 +2943,6 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
                     T.reads(A[v_ax0, v_ax1, v_ax2])
                     T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                     T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
-            for ax0, ax1 in T.grid(s, f):
-                with T.block("T_cast_1"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(B[v_ax0, v_ax1])
-                    T.writes(T_cast_2[v_ax0, v_ax1])
-                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.block("T_multiply"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -2948,12 +2957,24 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
                     with T.init():
                         T_multiply_red[v_ax0] = T.float32(0)
                     T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
+            for ax0 in range(n):
+                with T.block("rsqrt"):
+                    v_ax0 = T.axis.spatial(n, ax0)
+                    T.reads(T_multiply_red[v_ax0])
+                    T.writes(rsqrt[v_ax0])
+                    rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1 in T.grid(s, f):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2], T_multiply_red[v_ax0])
+                    T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05))
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2]
             for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.block("T_cast_2"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -2990,9 +3011,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
             T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
-            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast"):
@@ -3000,12 +3022,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
-            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
-                with T.block("T_cast_1"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(B[v_ax0, v_ax1])
-                    T.writes(T_cast_2[v_ax0, v_ax1])
-                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_multiply"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -3020,12 +3036,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa
                     with T.init():
                         T_multiply_red[v_ax0, v_ax1] = T.float32(0)
                     T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("rsqrt"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_red[v_ax0, v_ax1])
+                    T.writes(rsqrt[v_ax0, v_ax1])
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
-                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+                    T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
                 with T.block("T_cast_2"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])