diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 08d5e9379dc61..3ddffc5aba3b5 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -160,12 +160,23 @@ bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
 }
 
 bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
+  LOG(DEBUG) << "Attempting to prove whether " << lhs << " (dtype = " << lhs->dtype << ") and "
+             << rhs << " (dtype = " << rhs->dtype << ") are equal";
   const auto* clhs = lhs.as<IntImmNode>();
   const auto* crhs = rhs.as<IntImmNode>();
-  if (clhs && crhs) return clhs->value == crhs->value;
+  if (clhs && crhs) {
+    LOG(DEBUG) << "\t"
+               << "Both values are integers, comparing directly";
+    return clhs->value == crhs->value;
+  }
  if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) {
+    LOG(DEBUG) << "\t"
+               << "At least one value is a handle, comparing by reference equality";
     return lhs.same_as(rhs);
   }
+  LOG(DEBUG) << "\t"
+             << "Falling back to CanProve, using expression " << (lhs - rhs == 0)
+             << ", which simplifies to " << Simplify(lhs - rhs == 0);
   return CanProve(lhs - rhs == 0);
 }
 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index f4d4a9048ced7..39ee47b98a797 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1679,6 +1679,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
   PVar<IntImm> c1, c2;
   PVar<PrimExpr> lanes;
 
+  auto ctrue = PConst<PrimExpr>(make_const(ret->dtype, true));
+
+  TVM_TRY_REWRITE(x == x, ctrue);
+
   // vector rule
   if (ret->dtype.is_scalable_or_fixed_length_vector()) {
     TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 95d16c1abadfa..a123a3e458051 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -221,7 +221,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
       analyzer_.MarkGlobalNonNegValue(shape_var);
     } else {
       const PrimExpr& old_shape_expr = (*it).second;
-      CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
+      CHECK(old_shape_expr.same_as(shape_expr) ||
+            analyzer_.CanProveEqual(old_shape_expr, shape_expr))
           << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
           << shape_expr;
     }
@@ -261,6 +262,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
     cur_frame->bindings.push_back(match_cast);
     // NOTE match shape do not follow simple binding rule
     // as a result should not appear in binding table.
+
+    AddDefinitionToScope(var);
     return var;
   }
 
@@ -296,6 +299,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
       // NOTE match shape do not follow simple binding rule
       // as a result should not appear in binding table.
       cur_frame->bindings.push_back(binding);
+      AddDefinitionToScope(match_cast->var);
     } else {
       LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey();
     }
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc
index e0dca16b2e913..81d26cb4730c9 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -727,6 +727,9 @@ ExprRewriter ExprRewriter::FromModule(IRModule mod) {
   } else if (auto func = expr.as<ExternFuncNode>()) {
     return ExternFuncPattern(func->global_symbol);
 
+  } else if (auto prim = expr.as<PrimValueNode>()) {
+    return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value));
+
   } else {
     LOG(FATAL) << "TypeError: "
                << "Cannot convert Relax expression of type " << expr->GetTypeKey()
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index a14ba1d9aaa11..4850d52546b2a 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -21,6 +21,8 @@
 #include
 #include
+#include <unordered_set>
+
 namespace tvm {
 namespace relax {
 
@@ -589,6 +591,19 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
     ret_struct_info = body_sinfo;
   }
 
+  auto f_shape_var_map = [&] {
+    auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
+    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> lookup(tir_vars.begin(), tir_vars.end());
+    return [lookup = std::move(lookup)](const tir::Var& var) -> Optional<PrimExpr> {
+      if (lookup.count(var)) {
+        return var;
+      } else {
+        return NullOpt;
+      }
+    };
+  }();
+  ret_struct_info = EraseToWellDefined(ret_struct_info.value(), f_shape_var_map);
+
   FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);
 
   // set the fields
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index c2320de62a751..3ee403a25cda6 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -810,7 +810,8 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> param
   builder_->BeginInnerScope();
   // Inner scope also includes any TIR variables that are defined by
   // MatchCast nodes, and are internal to the scope.
-  Expr ret = ExprFunctor::VisitExpr(expr);
+  Expr ret = this->VisitExpr(expr);
+
   builder_->EndScope();
 
   // Normalization (and the resulting StructInfo inference) of the
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index f0239e424f300..77416dc92b1df 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -122,11 +122,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_sinfo) return;
 
     auto expr_sinfo = expr.as<PrimStructInfoNode>();
-    CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
-    CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype)
-        << "Cannot bind expression with struct type " << expr << " to variable with struct type "
-        << var << ", due to conflicting PrimExpr DataType";
+    if (!expr_sinfo) return;
 
     if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return;
 
@@ -139,15 +135,12 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_shape->values.defined()) return;
 
     auto expr_shape = expr.as<ShapeStructInfoNode>();
-    CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
+    if (!expr_shape) return;
     if (!expr_shape->values.defined()) return;
 
     auto var_shape_arr = var_shape->values.value();
     auto expr_shape_arr = expr_shape->values.value();
-    CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
-        << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
-        << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
+    if (var_shape_arr.size() != expr_shape_arr.size()) return;
     for (size_t i = 0; i < var_shape_arr.size(); i++) {
       bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
     }
@@ -159,8 +152,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_tensor->shape.defined()) return;
 
     auto expr_tensor = expr.as<TensorStructInfoNode>();
-    CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
+    if (!expr_tensor) return;
 
     if (!expr_tensor->shape.defined()) return;
 
     bind_from_shape(GetStructInfo(var_tensor->shape.value()),
diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py
index 05bbe429bbccd..1377bd0c14987 100644
--- a/tests/python/relax/test_dataflow_rewriter.py
+++ b/tests/python/relax/test_dataflow_rewriter.py
@@ -377,6 +377,148 @@ def expected(A: R.Tensor([16], "float32")):
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_rewrite_of_arbitrary_dtype():
+    """A pattern-match may apply to a tensor with unknown dtype
+
+    In this test case, a pattern identifies `R.strided_slice` usage
+    which returns the last slice of an array, and replaces it with a
+    view into the input array.
+
+    """
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
+            M = T.int64()
+            N = T.int64()
+            last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M])
+            last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0)
+            return last_slice_1d
+
+        @R.function
+        def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
+            M = T.int64()
+            N = T.int64()
+
+            # TODO(Lunderberg): Improve this syntax.  A Relax
+            # PrimValue (e.g. `A.dtype.bits`) should be usable in any
+            # Relax context that accepts a `PrimExpr`.  Currently,
+            # this requires `R.match_cast` to produce a TIR symbolic
+            # variable from the Relax PrimValue.
+            bits_per_element = T.uint8()
+            _ = R.match_cast(
+                A.dtype.bits,
+                R.Prim(value=bits_per_element),
+            )
+            lanes_per_element = T.uint16()
+            _ = R.match_cast(
+                A.dtype.lanes,
+                R.Prim(value=lanes_per_element),
+            )
+
+            last_slice = R.memory.view(
+                A,
+                [N],
+                relative_byte_offset=(M - 1)
+                * N
+                * T.ceildiv(
+                    bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8
+                ),
+            )
+            return last_slice
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([32, 16], "float16"),
+            B: R.Tensor(["P", "Q"], "int4x8"),
+            C: R.Tensor([16, 32]),
+        ):
+            P = T.int64()
+            Q = T.int64()
+
+            A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32])
+            A_slice_1d = R.squeeze(A_slice_2d, axis=0)
+
+            B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P])
+            B_slice_1d = R.squeeze(B_slice_2d, axis=0)
+
+            C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16])
+            C_slice_1d = R.squeeze(C_slice_2d, axis=0)
+
+            return (A_slice_1d, B_slice_1d, C_slice_1d)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            A: R.Tensor([32, 16], "float16"),
+            B: R.Tensor(["P", "Q"], "int4x8"),
+            C: R.Tensor([16, 32]),
+        ):
+            P = T.int64()
+            Q = T.int64()
+
+            # The pattern matches any 2-d tensor, with any data type.
+            # When the match's shape and dtype are both known,
+            # normalization and canonicalization produces a statically
+            # known value for `relative_byte_offset`.
+            #
+            # Relative offset is `(31 rows) *
+            #                      (16 elements/row) *
+            #                      (2 bytes/element)`
+            A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992)
+
+            # The pattern can also match a 2-d tensor with dynamic
+            # shape.  The `relative_byte_offset` uses the known
+            # datatype (4 bytes for each int4x8), but with dynamic
+            # shape variables substituted in where required.
+            #
+            # Relative offset is `((P-1) rows) *
+            #                      (Q elements/row) *
+            #                      (4 bytes/element)`
+            B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4)
+
+            # The pattern can also match a 2-d tensor with static
+            # shape, but unknown data type.  The
+            # `relative_byte_offset` is determined based on the known
+            # number of elements, and the dynamic size of each
+            # element.
+            #
+            # Relative offset is `(15 rows) *
+            #                      (32 elements/row) *
+            #                      (ceildiv(bits*lanes,8) bytes/element)`
+            C_bits_per_element = T.uint8()
+            C_bits_prim_value = C.dtype.bits
+            _ = R.match_cast(
+                C_bits_prim_value,
+                R.Prim(value=C_bits_per_element),
+            )
+            C_lanes_per_element = T.uint16()
+            C_lanes_prim_value = C.dtype.lanes
+            _ = R.match_cast(
+                C_lanes_prim_value,
+                R.Prim(value=C_lanes_per_element),
+            )
+
+            C_slice_1d = R.memory.view(
+                C,
+                shape=[32],
+                relative_byte_offset=(
+                    (C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7)
+                    // 8
+                )
+                * 480,
+            )
+
+            return (A_slice_1d, B_slice_1d, C_slice_1d)
+
+    after = Rewriter(Before)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 def test_rewrite_may_introduce_private_relax_subroutines():
     """The replacement may contain subroutines"""
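
# Illustrative sketch, not part of the patch above: a minimal check of the
# `x == x` rewrite rule that the rewrite_simplify.cc hunk adds for EQ nodes.
# With that rule in place, an equality whose two operands are the same
# expression simplifies to a constant true even though the value of the
# symbolic variable `n` is unknown.  The variable and expression below are
# chosen only for illustration.
from tvm import arith, tir

analyzer = arith.Analyzer()
n = tir.Var("n", "int64")
eq = tir.EQ((n + 1) * 2, (n + 1) * 2)
print(analyzer.simplify(eq))  # expected to print a constant true, e.g. T.bool(True)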