From f6c52fa9df97c4aead6774f8104ab6827e2a90b5 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Tue, 7 May 2019 20:52:24 -0700 Subject: [PATCH] Handle vectorize for LE statement (#3137) * Handle vectorize for LE statement Fix a new cases introduced by commit 7afbca5691fdb599cd90b043d5a5036e55cae2d6 * Add test --- src/pass/vectorize_loop.cc | 3 +++ tests/python/unittest/test_pass_vectorize.py | 24 ++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index bd0a91ce4a99..f87e80c2d030 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -166,6 +166,9 @@ class Vectorizer : public IRMutator { Expr Mutate_(const LT* op, const Expr &e) final { return BinaryVec(op, e); } + Expr Mutate_(const LE* op, const Expr &e) final { + return BinaryVec(op, e); + } Expr Mutate_(const GT* op, const Expr &e) final { return BinaryVec(op, e); } diff --git a/tests/python/unittest/test_pass_vectorize.py b/tests/python/unittest/test_pass_vectorize.py index 03516872e835..fca22a1eca30 100644 --- a/tests/python/unittest/test_pass_vectorize.py +++ b/tests/python/unittest/test_pass_vectorize.py @@ -69,6 +69,28 @@ def test_vectorize_with_if(): assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.stmt.For) +def test_vectorize_with_le_cond(): + n = tvm.var('n') + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.if_scope(i <= n): + A[i] = A[i] + 1 + stmt = ib.get() + stmt = tvm.ir_pass.VectorizeLoop(stmt) + assert isinstance(stmt, tvm.stmt.For) + +def test_vectorize_with_ge_cond(): + n = tvm.var('n') + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.if_scope(i >= n): + A[i] = A[i] + 1 + stmt = ib.get() + stmt = tvm.ir_pass.VectorizeLoop(stmt) + assert isinstance(stmt, tvm.stmt.For) + def test_vectorize_if_then_else(): n = tvm.var('n') x = tvm.var('x') @@ -102,3 +124,5 @@ def test_vectorize_if_then_else(): test_vectorize_with_if() test_vectorize_loop() test_vectorize_if_then_else() + test_vectorize_with_le_cond() + test_vectorize_with_ge_cond()