From a5def36f55bf45c6e1492769a45fae71958a5478 Mon Sep 17 00:00:00 2001 From: Andrew Tulloch Date: Thu, 29 Aug 2019 17:25:07 -0700 Subject: [PATCH] codegen_spirv support Call::reinterpret (#3795) --- src/codegen/spirv/codegen_spirv.cc | 3 + src/codegen/spirv/ir_builder.cc | 79 ++++++++++---------- tests/python/unittest/test_codegen_vulkan.py | 58 ++++++++++++++ 3 files changed, 99 insertions(+), 41 deletions(-) create mode 100644 tests/python/unittest/test_codegen_vulkan.py diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 7686250c5ce5..7caf3a258b6f 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) { } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } + } else if (op->is_intrinsic(Call::reinterpret)) { + return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type), + MakeValue(op->args[0])); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index d6ba9e40c123..6afd3112021d 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -522,17 +522,17 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } \ } -#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS ## _Op, a.stype, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU ## _Op, a.stype, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF ## _Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, a.stype, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, a.stype, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add); @@ -552,21 +552,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } - -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder:: _OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (t_bool_.id == 0) { \ - t_bool_ = DeclareType(UInt(1)); \ - } \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -574,18 +572,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder:: _OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (t_bool_.id == 0) { \ - t_bool_ = DeclareType(UInt(1)); \ - } \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); @@ -593,7 +590,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { CHECK_EQ(a.stype.id, b.stype.id); - CHECK_EQ(cond.stype.type, UInt(1)); + CHECK_EQ(cond.stype.type.element_of(), UInt(1)); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/tests/python/unittest/test_codegen_vulkan.py b/tests/python/unittest/test_codegen_vulkan.py new file mode 100644 index 000000000000..2d7edffae7bf --- /dev/null +++ b/tests/python/unittest/test_codegen_vulkan.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import re + + +def test_vector_comparison(): + if not tvm.module.enabled("vulkan"): + print("Skipping due to no Vulkan module") + return + + target = 'vulkan' + + def check_correct_assembly(dtype): + n = (1024,) + A = tvm.placeholder(n, dtype=dtype, name='A') + B = tvm.compute( + A.shape, + lambda i: tvm.expr.Select( + A[i] >= 0, A[i] + tvm.const(1, dtype), + tvm.const(0, dtype)), name='B') + s = tvm.create_schedule(B.op) + + (bx, tx) = s[B].split(s[B].op.axis[0], factor=128) + (tx, vx) = s[B].split(tx, factor=4) + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + s[B].vectorize(vx) + f = tvm.build(s, [A, B], target) + + # Verify we generate the boolx4 type declaration and the OpSelect + # v4{float,half,int} instruction + assembly = f.imported_modules[0].get_source() + matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly) + assert len(matches) == 1 + matches = re.findall("OpSelect %v4.*", assembly) + assert len(matches) == 1 + check_correct_assembly('float32') + check_correct_assembly('int32') + check_correct_assembly('float16') + + +if __name__ == "__main__": + test_vector_comparison()