
Commit

[TOPI] Bitserial low-precision convolution (apache#1332)
Meghan Cowan authored and sergei-mironov committed Aug 8, 2018
1 parent 240d6a3 commit 38c01ba
Showing 12 changed files with 1,273 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/tvm/intrin.py
@@ -154,6 +154,31 @@ def call_extern(dtype, func_name, *args):
dtype, func_name, convert(args), _Call.Extern, None, 0)


def call_llvm_intrin(dtype, name, *args):
"""Build expression by calling an llvm intrinsic function
Parameters
----------
dtype : str
The data type of the result.
name : str
The name of the llvm intrinsic function.
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)


def exp(x):
"""Take exponetial of input x.
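As context for the new call_llvm_intrin helper above, here is a minimal usage sketch (assumptions: an LLVM-enabled TVM build of this vintage, and illustrative shapes). It expresses a per-element population count, the same primitive the bitserial convolution kernels in this PR rely on; the tvm.const(1, 'uint32') operand follows the calling convention used by test_llvm_lookup_intrin further down.

import tvm

n = 64
A = tvm.placeholder((n,), dtype="uint32", name="A")
# Each output element is the popcount of the matching input word, expressed
# through the new call_llvm_intrin helper instead of a hand-written loop.
B = tvm.compute(
    (n,),
    lambda i: tvm.call_llvm_intrin("uint32", "llvm.ctpop.i32",
                                   tvm.const(1, 'uint32'), A[i]),
    name="B")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], target="llvm")  # only exercises codegen; nothing is executed here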
9 changes: 9 additions & 0 deletions src/codegen/llvm/llvm_module.cc
@@ -282,6 +282,15 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::shared_ptr<llvm::LLVMContext> ctx_;
};

unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name);
}

TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
});

TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
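On the Python side this registration surfaces as tvm.codegen.llvm_lookup_intrinsic_id, which is exactly what the new call_llvm_intrin helper calls. A small sketch of the lookup behaviour (assuming an LLVM-enabled build; a return value of 0 means the name is not a recognised LLVM intrinsic, which is what the helper's assert guards against):

import tvm

# A known intrinsic resolves to a non-zero numeric id.
assert tvm.codegen.llvm_lookup_intrinsic_id("llvm.ctpop.i32") != 0

# An unknown name comes back as 0, tripping the assert inside call_llvm_intrin.
assert tvm.codegen.llvm_lookup_intrinsic_id("llvm.no.such.intrinsic") == 0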
11 changes: 11 additions & 0 deletions tests/python/unittest/test_codegen_llvm.py
@@ -17,6 +17,16 @@ def test_llvm_intrin():
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")

def test_llvm_lookup_intrin():
ib = tvm.ir_builder.create()
m = tvm.var("m")
A = ib.pointer("uint8x8", name="A")
x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")

def test_llvm_add_pipeline():
nn = 1024
n = tvm.convert(nn)
@@ -324,3 +334,4 @@ def test_alignment():
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
test_llvm_lookup_intrin()
35 changes: 35 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -143,6 +143,41 @@ def schedule_depthwise_conv2d_nhwc(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_bitserial_conv2d_nhwc(outs):
"""Schedule for bitserial_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
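These two entries are only the generic fallbacks; a concrete backend plugs in its own schedule through the register decorator that tvm.target.generic_func provides. A hypothetical override sketch, with the target key and schedule body chosen for illustration rather than taken from this PR:

import tvm
from topi.generic import schedule_bitserial_conv2d_nhwc

@schedule_bitserial_conv2d_nhwc.register(["cpu"])
def _schedule_bitserial_conv2d_nhwc(outs):
    # A real backend schedule would tile, vectorize and parallelize here;
    # this stub just returns a plain schedule over the output ops.
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    return tvm.create_schedule([x.op for x in outs])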
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
@@ -16,4 +16,5 @@
from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .l2_normalize import *
