Skip to content

Commit

Permalink
codegen_spirv support Call::reinterpret
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtulloch committed Aug 28, 2019
1 parent 062f8cc commit ec89864
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 41 deletions.
3 changes: 3 additions & 0 deletions src/codegen/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
} else if (op->is_intrinsic(Call::reinterpret)) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type),
MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
Expand Down
79 changes: 38 additions & 41 deletions src/codegen/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -522,17 +522,17 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
} \
}

#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
} \
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS##_Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU##_Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF##_Op, a.stype, a, b); \
} \
}

DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
Expand All @@ -552,48 +552,45 @@ Value IRBuilder::Mod(Value a, Value b) {
}
}


#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS##_Op, bool_type, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU##_Op, bool_type, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \
}

DEFINE_BUILDER_CMP_OP(LT, LessThan);
DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);

#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI##_Op, bool_type, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \
}

DEFINE_BUILDER_CMP_UOP(EQ, Equal);
DEFINE_BUILDER_CMP_UOP(NE, NotEqual);

Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
CHECK_EQ(cond.stype.type, UInt(1));
CHECK_EQ(cond.stype.type.element_of(), UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}

Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_codegen_vulkan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import re


def test_vector_comparison():
if not tvm.module.enabled("vulkan"):
print("Skipping due to no Vulkan module")
return

target = 'vulkan'

def check_correct_assembly(dtype):
n = (1024,)
A = tvm.placeholder(n, dtype=dtype, name='A')
B = tvm.compute(
A.shape,
lambda i: tvm.expr.Select(
A[i] >= 0, A[i] + tvm.const(1, dtype),
tvm.const(0, dtype)), name='B')
s = tvm.create_schedule(B.op)

(bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
(tx, vx) = s[B].split(tx, factor=4)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[B].vectorize(vx)
f = tvm.build(s, [A, B], target)

# Verify we generate the boolx4 type declaration and the OpSelect
# v4{float,half,int} instruction
assembly = f.imported_modules[0].get_source()
matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
assert len(matches) == 1
matches = re.findall("OpSelect %v4.*", assembly)
assert len(matches) == 1
check_correct_assembly('float32')
check_correct_assembly('int32')
check_correct_assembly('float16')


if __name__ == "__main__":
test_vector_comparison()

0 comments on commit ec89864

Please sign in to comment.