diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 0cb0d4624731..142faf51607b 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -454,7 +454,7 @@ def main( R.output(lv0) gv_x = R.astype(x, dtype="float16") - gv_w = R.astype(x, dtype="float16") + gv_w = R.astype(w, dtype="float16") with R.dataflow(): lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims( @@ -481,7 +481,7 @@ def main( w: R.Tensor((4, 3, 3, 3), dtype="float32"), ): gv_x = R.astype(x, dtype="float16") - gv_w = R.astype(x, dtype="float16") + gv_w = R.astype(w, dtype="float16") with R.dataflow(): lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims(