
Commit

Add test case for matching against arbitrary dtype
Lunderberg committed Jul 15, 2024
1 parent 826d270 commit 0f8b4fe
Showing 8 changed files with 187 additions and 15 deletions.
13 changes: 12 additions & 1 deletion src/arith/analyzer.cc
@@ -160,12 +160,23 @@ bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
}

bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
LOG(DEBUG) << "Attempting to prove whether " << lhs << " (dtype = " << lhs->dtype << ") and "
<< rhs << " (dtype = " << rhs->dtype << ") are equal";
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (clhs && crhs) return clhs->value == crhs->value;
if (clhs && crhs) {
LOG(DEBUG) << "\t"
<< "Both values are integers, comparing directly";
return clhs->value == crhs->value;
}
if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) {
LOG(DEBUG) << "\t"
<< "Both values are handles, comparing reference equality";
return lhs.same_as(rhs);
}
LOG(DEBUG) << "\t"
<< "Falling back to CanProve, using expression " << (lhs - rhs == 0)
<< ", which simplifies to " << Simplify(lhs - rhs == 0);
return CanProve(lhs - rhs == 0);
}

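For reference, the fall-back path in the updated CanProveEqual can be exercised from Python. This is a minimal sketch, assuming only the standard `tvm.arith.Analyzer` binding and its `can_prove` method (`CanProveEqual` itself is the C++ entry point; nothing below is added by this commit):

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
m = tir.Var("m", "int64")

# Constant operands are compared directly; anything else falls back to
# proving `(lhs - rhs) == 0`, which succeeds for these structurally
# different but equivalent expressions.
assert ana.can_prove(tir.EQ(m + 1, 1 + m))
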
4 changes: 4 additions & 0 deletions src/arith/rewrite_simplify.cc
@@ -1679,6 +1679,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
PVar<IntImm> c1, c2;
PVar<PrimExpr> lanes;

auto ctrue = PConst<PrimExpr>(make_const(ret->dtype, true));

TVM_TRY_REWRITE(x == x, ctrue);

// vector rule
if (ret->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
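The new `x == x` rule folds reflexive comparisons to a constant true for any dtype, instead of relying on the integer-oriented `lhs - rhs == 0` machinery. A minimal sketch of the intended effect, assuming the standard `tvm.arith.Analyzer.simplify` Python binding (the printed form is only an expectation):

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
x = tir.Var("x", "float32")

# A reflexive comparison should now fold to a constant boolean even for a
# floating-point operand, where the subtraction-based fallback is
# conservative.
print(ana.simplify(tir.EQ(x, x)))  # expected: a constant true, e.g. T.bool(True)
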
6 changes: 5 additions & 1 deletion src/relax/ir/block_builder.cc
@@ -221,7 +221,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
analyzer_.MarkGlobalNonNegValue(shape_var);
} else {
const PrimExpr& old_shape_expr = (*it).second;
CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
CHECK(old_shape_expr.same_as(shape_expr) ||
analyzer_.CanProveEqual(old_shape_expr, shape_expr))
<< "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
<< shape_expr;
}
@@ -261,6 +262,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
cur_frame->bindings.push_back(match_cast);
// NOTE match shape do not follow simple binding rule
// as a result should not appear in binding table.

AddDefinitionToScope(var);
return var;
}

@@ -296,6 +299,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
// NOTE match shape do not follow simple binding rule
// as a result should not appear in binding table.
cur_frame->bindings.push_back(binding);
AddDefinitionToScope(match_cast->var);
} else {
LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey();
}
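`R.match_cast` is the binding that introduces TIR symbolic variables mid-scope, and registering its definitions lets later bindings in the same scope refer to them. A hedged TVMScript sketch of the kind of program this supports (illustrative only; not taken from this commit):

from tvm.script import relax as R
from tvm.script import tir as T


@R.function
def func(x: R.Tensor(ndim=1, dtype="float32")) -> R.Tensor(ndim=1, dtype="float32"):
    n = T.int64()
    # The match_cast defines the symbolic variable `n`; with the scope
    # bookkeeping above, that definition is visible to the bindings that
    # follow it in the same block.
    y = R.match_cast(x, R.Tensor([n], "float32"))
    z: R.Tensor([n], "float32") = R.add(y, y)
    return z
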
3 changes: 3 additions & 0 deletions src/relax/ir/dataflow_expr_rewriter.cc
@@ -727,6 +727,9 @@ ExprRewriter ExprRewriter::FromModule(IRModule mod) {
} else if (auto func = expr.as<ExternFuncNode>()) {
return ExternFuncPattern(func->global_symbol);

} else if (auto prim = expr.as<PrimValueNode>()) {
return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value));

} else {
LOG(FATAL) << "TypeError: "
<< "Cannot convert Relax expression of type " << expr->GetTypeKey()
15 changes: 15 additions & 0 deletions src/relax/ir/expr.cc
@@ -21,6 +21,8 @@
#include <tvm/relax/struct_info.h>
#include <tvm/relax/type.h>

#include <unordered_set>

namespace tvm {
namespace relax {

@@ -589,6 +591,19 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
ret_struct_info = body_sinfo;
}

auto f_shape_var_map = [&] {
auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
std::unordered_set<tir::Var> lookup(tir_vars.begin(), tir_vars.end());
return [lookup = std::move(lookup)](const tir::Var& var) -> Optional<PrimExpr> {
if (lookup.count(var)) {
return var;
} else {
return NullOpt;
}
};
}();
ret_struct_info = EraseToWellDefined(ret_struct_info.value(), f_shape_var_map);

FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);

// set the fields
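The new `f_shape_var_map` restricts erasure to TIR variables definable from the parameters, so a deduced return StructInfo cannot leak symbolic variables that exist only inside the body. A hedged sketch of the situation this handles, assuming the usual TVMScript front end (the exact erased result depends on `EraseToWellDefined`):

from tvm.script import relax as R
from tvm.script import tir as T


@R.function
def func(x: R.Tensor(ndim=1, dtype="float32")):
    n = T.int64()
    # `n` is introduced by the match_cast inside the body and is not
    # derivable from the parameters.  The deduced struct info of `y`
    # mentions `n`, so the constructed function's return struct info is
    # expected to be erased to something caller-visible, e.g.
    # R.Tensor(ndim=1, dtype="float32").
    y = R.match_cast(x, R.Tensor([n], "float32"))
    return y
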
3 changes: 2 additions & 1 deletion src/relax/ir/expr_functor.cc
@@ -810,7 +810,8 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> param
builder_->BeginInnerScope();
// Inner scope also includes any TIR variables that are defined by
// MatchCast nodes, and are internal to the scope.
Expr ret = ExprFunctor::VisitExpr(expr);
Expr ret = this->VisitExpr(expr);

builder_->EndScope();

// Normalization (and the resulting StructInfo inference) of the
16 changes: 4 additions & 12 deletions src/relax/utils.cc
@@ -122,11 +122,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
if (!var_sinfo) return;

auto expr_sinfo = expr.as<PrimStructInfoNode>();
CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype)
<< "Cannot bind expression with struct type " << expr << " to variable with struct type "
<< var << ", due to conflicting PrimExpr DataType";
if (!expr_sinfo) return;

if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return;

@@ -139,15 +135,12 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
if (!var_shape->values.defined()) return;

auto expr_shape = expr.as<ShapeStructInfoNode>();
CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_shape) return;
if (!expr_shape->values.defined()) return;

auto var_shape_arr = var_shape->values.value();
auto expr_shape_arr = expr_shape->values.value();
CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
<< "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
<< " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
if (var_shape_arr.size() != expr_shape_arr.size()) return;
for (size_t i = 0; i < var_shape_arr.size(); i++) {
bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
}
@@ -159,8 +152,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
if (!var_tensor->shape.defined()) return;

auto expr_tensor = expr.as<TensorStructInfoNode>();
CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_tensor) return;
if (!expr_tensor->shape.defined()) return;

bind_from_shape(GetStructInfo(var_tensor->shape.value()),
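With the hard CHECKs replaced by early returns, `InferSymbolicVarMap` is best-effort: any variable/expression pair whose struct info does not line up is skipped instead of aborting. A small self-contained sketch of that best-effort pattern in plain Python (this mirrors the shape-binding loop only; it is not a TVM API):

from typing import Dict, List, Optional, Union

def bind_shapes(
    var_shape: Optional[List[Union[str, int]]],
    expr_shape: Optional[List[int]],
) -> Dict[str, int]:
    """Best-effort binding of symbolic dims (strings) to concrete dims (ints)."""
    bindings: Dict[str, int] = {}
    # Early returns instead of hard failures: mismatches simply yield no bindings.
    if var_shape is None or expr_shape is None:
        return bindings
    if len(var_shape) != len(expr_shape):
        return bindings
    for var_dim, expr_dim in zip(var_shape, expr_shape):
        if isinstance(var_dim, str):
            bindings.setdefault(var_dim, expr_dim)
    return bindings

print(bind_shapes(["M", "N"], [32, 16]))  # {'M': 32, 'N': 16}
print(bind_shapes(["M", "N"], [32]))      # {} -- mismatch is skipped, not an error
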
142 changes: 142 additions & 0 deletions tests/python/relax/test_dataflow_rewriter.py
@@ -377,6 +377,148 @@ def expected(A: R.Tensor([16], "float32")):
    tvm.ir.assert_structural_equal(expected, after)


def test_rewrite_of_arbitrary_dtype():
    """A pattern-match may apply to a tensor with unknown dtype

    In this test case, a pattern identifies `R.strided_slice` usage
    which returns the last slice of an array, and replaces it with a
    view into the input array.

    """

    @R.rewriter
    class Rewriter:
        @R.function
        def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
            M = T.int64()
            N = T.int64()
            last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M])
            last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0)
            return last_slice_1d

        @R.function
        def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]):
            M = T.int64()
            N = T.int64()

            # TODO(Lunderberg): Improve this syntax.  A Relax
            # PrimValue (e.g. `A.dtype.bits`) should be usable in any
            # Relax context that accepts a `PrimExpr`.  Currently,
            # this requires `R.match_cast` to produce a TIR symbolic
            # variable from the Relax PrimValue.
            bits_per_element = T.uint8()
            _ = R.match_cast(
                A.dtype.bits,
                R.Prim(value=bits_per_element),
            )
            lanes_per_element = T.uint16()
            _ = R.match_cast(
                A.dtype.lanes,
                R.Prim(value=lanes_per_element),
            )

            last_slice = R.memory.view(
                A,
                [N],
                relative_byte_offset=(M - 1)
                * N
                * T.ceildiv(
                    bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8
                ),
            )
            return last_slice

    @I.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([32, 16], "float16"),
            B: R.Tensor(["P", "Q"], "int4x8"),
            C: R.Tensor([16, 32]),
        ):
            P = T.int64()
            Q = T.int64()

            A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32])
            A_slice_1d = R.squeeze(A_slice_2d, axis=0)

            B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P])
            B_slice_1d = R.squeeze(B_slice_2d, axis=0)

            C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16])
            C_slice_1d = R.squeeze(C_slice_2d, axis=0)

            return (A_slice_1d, B_slice_1d, C_slice_1d)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([32, 16], "float16"),
            B: R.Tensor(["P", "Q"], "int4x8"),
            C: R.Tensor([16, 32]),
        ):
            P = T.int64()
            Q = T.int64()

            # The pattern matches any 2-d tensor, with any data type.
            # When the match's shape and dtype are both known,
            # normalization and canonicalization produces a statically
            # known value for `relative_byte_offset`.
            #
            # Relative offset is `(31 rows) *
            #                     (16 elements/row) *
            #                     (2 bytes/element)`
            A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992)

            # The pattern can also match a 2-d tensor with dynamic
            # shape.  The `relative_byte_offset` uses the known
            # datatype (4 bytes for each int4x8), but with dynamic
            # shape variables substituted in where required.
            #
            # Relative offset is `((P-1) rows) *
            #                     (Q elements/row) *
            #                     (4 bytes/element)`
            B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4)

            # The pattern can also match a 2-d tensor with static
            # shape, but unknown data type.  The
            # `relative_byte_offset` is determined based on the known
            # number of elements, and the dynamic size of each
            # element.
            #
            # Relative offset is `(15 rows) *
            #                     (32 elements/row) *
            #                     (ceildiv(bits*lanes,8) bytes/element)`
            C_bits_per_element = T.uint8()
            C_bits_prim_value = C.dtype.bits
            _ = R.match_cast(
                C_bits_prim_value,
                R.Prim(value=C_bits_per_element),
            )
            C_lanes_per_element = T.uint16()
            C_lanes_prim_value = C.dtype.lanes
            _ = R.match_cast(
                C_lanes_prim_value,
                R.Prim(value=C_lanes_per_element),
            )

            C_slice_1d = R.memory.view(
                C,
                shape=[32],
                relative_byte_offset=(
                    (C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7)
                    // 8
                )
                * 480,
            )

            return (A_slice_1d, B_slice_1d, C_slice_1d)

    after = Rewriter(Before)
    tvm.ir.assert_structural_equal(Expected, after)
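
As a quick sanity check on the constants in `Expected` (plain arithmetic, not part of the test): the float16 case skips 31 rows of 16 two-byte elements, and the unknown-dtype case skips 15 rows of 32 elements whose byte size stays symbolic.

assert 31 * 16 * 2 == 992   # A: relative_byte_offset for the last float16 row
assert 15 * 32 == 480       # C: element count multiplying ceildiv(bits * lanes, 8)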


def test_rewrite_may_introduce_private_relax_subroutines():
"""The replacement may contain subroutines"""

