Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TVMMXBridge int64 type mismatch #998

Closed
jwfromm opened this issue Mar 13, 2018 · 3 comments
Closed

TVMMXBridge int64 type mismatch #998

jwfromm opened this issue Mar 13, 2018 · 3 comments

Comments

@jwfromm
Copy link
Contributor

jwfromm commented Mar 13, 2018

It looks like there's a typing error when using to_mxnet_func with 64-bit data types. Although a little lengthy, the code below defines a TVM pipeline that converts floating-point tensors to integer tensors of either 32-bit or 64-bit int type. When called with bits=32, the code below works smoothly. However, when bits is set to 64, I get an error saying my argument needs to be int64, even though it definitely is. I suspect that for some reason the 64-bit type comparison in the bridge isn't quite working and was missed in testing.

import mxnet as mx
from mxnet import nd, gluon, autograd
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib.mxnet import to_mxnet_func
from topi.util import get_const_int, simplify
from topi import tag
import time

def binarize_pack(data, bits=32, axis=None, name="PackedInput"):
    """Binarize and bit-pack a tensor along one axis.

    Each group of ``bits`` consecutive elements along ``axis`` is reduced to a
    single integer. An element contributes a 1 bit when it is ``>= 0`` and a
    0 bit otherwise. The first element of each group ends up in the most
    significant bit position.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.
    bits : int, optional
        Packing width; must be 32 (int32 output) or 64 (int64 output).
    axis : None or int
        The axis along which to do binarization and bit-packing,
        default is the last axis. Negative values count from the end.
    name : str, optional
        The name prefix operators generate.

    Returns
    -------
    output : tvm.Tensor
        n-D, the same layout as input with ``axis`` shrunk by a factor of
        ``bits``; dtype is int32 when bits == 32 and int64 when bits == 64.

    Raises
    ------
    ValueError
        If ``bits`` is neither 32 nor 64.
    """
    if bits == 32:
        bintype = "int32"
    elif bits == 64:
        bintype = "int64"
    else:
        # Without this branch `bintype` stays unbound and the failure shows
        # up later as an unhelpful UnboundLocalError.
        raise ValueError("bits must be 32 or 64, got %r" % (bits,))
    ishape = data.shape
    n = len(ishape)
    if axis is None:
        axis = n - 1
    elif axis < 0:
        # The `i == axis` comparisons below only work with a non-negative axis.
        axis += n
    assert get_const_int(ishape[axis]) % bits == 0, \
        "size of the packing axis must be a multiple of bits"
    oshape = tuple(simplify(ishape[i] // bits) if i == axis else ishape[i]
                   for i in range(n))

    def _binarize_pack(*indices):
        # Index of the first unpacked element that feeds this output word.
        start_idx = [indices[i] * bits if i == axis else indices[i]
                     for i in range(n)]
        packed = tvm.const(0, bintype)
        for j in range(bits):
            if j:
                # Make room for the next bit; the first bit needs no shift.
                packed = packed << 1
            idx = [start_idx[i] + j if i == axis else start_idx[i]
                   for i in range(n)]
            sign = (data(*idx) >= 0).astype(bintype)
            packed = packed | sign
        return packed

    return tvm.compute(oshape, _binarize_pack, name=name, tag='binarize_pack')


def binary_dense(data, weight, bits = 32):
    """Binary matrix multiplication using xor and bit-count.

    For packed words, ``result = bits * in_dim - 2 * sum_k popcount(a_k ^ b_k)``.
    Here ``bits * in_dim`` is the unpacked vector length and the popcount sum
    counts differing bit positions. For vectors packed from +/-1 values, this
    equals their dot product.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim], dtype int32 (bits=32) or int64 (bits=64).
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim], same dtype as ``data``.
    bits : int, optional
        Packing width of each word; must be 32 or 64.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]. Its dtype is whatever TVM infers for
        ``bits * in_dim - 2. * popcount_sum`` (presumably float32 — TODO confirm).

    Raises
    ------
    ValueError
        If ``bits`` is neither 32 nor 64.
    """
    if bits == 32:
        bintype = 'int32'
    elif bits == 64:
        bintype = 'int64'
    else:
        # Without this branch `bintype` is unbound and the assert below
        # fails with UnboundLocalError instead of a clear message.
        raise ValueError("bits must be 32 or 64, got %r" % (bits,))
    assert data.dtype == bintype and weight.dtype == bintype, \
        "dtype of data and weight should be match bitwidth"
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim binary dense"
    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    k = tvm.reduce_axis((0, in_dim), name='k')
    # Per output element: number of differing bits across all packed words.
    matmul = tvm.compute(
        (batch, out_dim),
        lambda i, j: tvm.sum(tvm.popcount(data[i, k] ^ weight[j, k]), axis=k),
        tag='binary_dense')

    return tvm.compute(
        (batch, out_dim),
        lambda i, j: bits * in_dim - 2. * matmul(i, j),
        tag=tag.ELEMWISE)

# Reproduction script: pack random +/-1 matrices with TVM, run the packed
# binary dense kernel, then call the same compiled kernels through the
# TVM -> MXNet bridge (to_mxnet_func).
BITS = 64  # set to 32 to compare against the int32 path reported as working

batch = 1
in_dim = 1024
out_dim = 1000

A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
bnn_A = binarize_pack(A, bits=BITS)
bnn_B = binarize_pack(B, bits=BITS)
# binary dense consumes the already-packed tensors as separate inputs
bnn_A1 = tvm.placeholder(bnn_A.shape, dtype=bnn_A.dtype)
bnn_B1 = tvm.placeholder(bnn_B.shape, dtype=bnn_B.dtype)
bnn_C = binary_dense(bnn_A1, bnn_B1, bits=BITS)
# schedule
with tvm.target.create('llvm'):
    s1 = topi.generic.schedule_binarize_pack(bnn_A)
    s2 = topi.generic.schedule_binarize_pack(bnn_B)
    s3 = topi.generic.schedule_binary_dense(bnn_C)

dtype = A.dtype
# generate random matrix of +1 or -1 value
a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype)
b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype)
c_np = np.dot(a_np, b_np.T)  # reference result for the binary dense kernel

tvm_ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, tvm_ctx)
b = tvm.nd.array(b_np, tvm_ctx)
bnn_a = tvm.nd.array(np.zeros(get_const_tuple(bnn_A.shape), dtype=bnn_A.dtype), tvm_ctx)
bnn_b = tvm.nd.array(np.zeros(get_const_tuple(bnn_B.shape), dtype=bnn_B.dtype), tvm_ctx)
bnn_c = tvm.nd.array(np.zeros(get_const_tuple(bnn_C.shape), dtype=bnn_C.dtype), tvm_ctx)
f1 = tvm.build(s1, [A, bnn_A], 'llvm')
f2 = tvm.build(s2, [B, bnn_B], 'llvm')
f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], 'llvm')
f1(a, bnn_a)
f2(b, bnn_b)
f3(bnn_a, bnn_b, bnn_c)
# Non-fatal sanity check of the pure-TVM pipeline before involving MXNet.
print("TVM binary dense matches numpy:", np.allclose(bnn_c.asnumpy(), c_np))

mxf1 = to_mxnet_func(f1)
mxf2 = to_mxnet_func(f2)
mxf3 = to_mxnet_func(f3)

mx_ctx = mx.cpu()
xa = mx.nd.array(a_np, mx_ctx)
xb = mx.nd.array(b_np, mx_ctx)
xbnn_a = nd.array(bnn_a.asnumpy()).astype(bnn_a.dtype)
xbnn_b = nd.array(bnn_b.asnumpy()).astype(bnn_b.dtype)
xbnn_c = nd.array(bnn_c.asnumpy()).astype(bnn_c.dtype)

mxf1(xa, xbnn_a)
# MXNet runs operators asynchronously, so a failure inside the bridged call
# only surfaces when the result is read. A bare `xbnn_a` expression does that
# only in IPython, via repr; print() makes the read explicit in a plain script.
print(xbnn_a)

This yields the following error:

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/IPython/core/formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

/opt/conda/lib/python3.6/site-packages/IPython/lib/pretty.py in pretty(self, obj)
    393                             if callable(meth):
    394                                 return meth(obj, self, cycle)
--> 395             return _default_pprint(obj, self, cycle)
    396         finally:
    397             self.end_group()

/opt/conda/lib/python3.6/site-packages/IPython/lib/pretty.py in _default_pprint(obj, p, cycle)
    508     if _safe_getattr(klass, '__repr__', None) is not object.__repr__:
    509         # A user-provided repr. Find newlines and replace them with p.break_()
--> 510         _repr_pprint(obj, p, cycle)
    511         return
    512     p.begin_group(1, '<')

/opt/conda/lib/python3.6/site-packages/IPython/lib/pretty.py in _repr_pprint(obj, p, cycle)
    699     """A pprint that just redirects to the normal repr function."""
    700     # Find newlines and replace them with p.break_()
--> 701     output = repr(obj)
    702     for idx,output_line in enumerate(output.splitlines()):
    703         if idx:

/incubator-mxnet/python/mxnet/ndarray/ndarray.py in __repr__(self)
    187         """Returns a string representation of the array."""
    188         shape_info = 'x'.join(['%d' % x for x in self.shape])
--> 189         return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
    190                                       self.__class__.__name__,
    191                                       shape_info, self.context)

/incubator-mxnet/python/mxnet/ndarray/ndarray.py in asnumpy(self)
   1824             self.handle,
   1825             data.ctypes.data_as(ctypes.c_void_p),
-> 1826             ctypes.c_size_t(data.size)))
   1827         return data
   1828 

/incubator-mxnet/python/mxnet/base.py in check_call(ret)
    147     """
    148     if ret != 0:
--> 149         raise MXNetError(py_str(_LIB.MXGetLastError()))
    150 
    151 

MXNetError: [02:59:47] src/codegen/llvm/llvm_module.cc:59: Check failed: ret == 0 (-1 vs. 0) Assert fail: (((tvm_struct_get(arg1, 0, 5) == (uint8)0) && (tvm_struct_get(arg1, 0, 6) == (uint8)64)) && (tvm_struct_get(arg1, 0, 7) == (uint16)1)), arg1.dtype is expected to be int64

Stack trace returned 10 entries:
[bt] (0) /opt/conda/lib/python3.6/site-packages/tvm-0.2.0-py3.6-linux-x86_64.egg/tvm/libtvm.so(dmlc::StackTrace[abi:cxx11]()+0x5a) [0x7f75246278fa]
[bt] (1) /opt/conda/lib/python3.6/site-packages/tvm-0.2.0-py3.6-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f75246284e8]
[bt] (2) /opt/conda/lib/python3.6/site-packages/tvm-0.2.0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::codegen::LLVMModuleNode::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::shared_ptr<tvm::runtime::ModuleNode> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x18f) [0x7f7524897d3f]
[bt] (3) /opt/conda/lib/python3.6/site-packages/tvm-0.2.0-py3.6-linux-x86_64.egg/tvm/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::codegen::LLVMModuleNode::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::shared_ptr<tvm::runtime::ModuleNode> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x17) [0x7f7524897e27]
[bt] (4) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::TVMFunctor::Run(mxnet::RunContext const&)+0x123) [0x7f74fa452c93]
[bt] (5) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(+0x32986a2) [0x7f74fa4496a2]
[bt] (6) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(+0x378a75b) [0x7f74fa93b75b]
[bt] (7) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::engine::ThreadedEngine::ExecuteOprBlock(mxnet::RunContext, mxnet::engine::OprBlock*)+0x2ba) [0x7f74fa934b8a]
[bt] (8) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(std::_Function_handler<void (std::shared_ptr<dmlc::ManualEvent>), mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#1}::operator()() const::{lambda(std::shared_ptr<dmlc::ManualEvent>)#1}>::_M_invoke(std::_Any_data const&, std::shared_ptr<dmlc::ManualEvent>&&)+0xe2) [0x7f74fa94d092]
[bt] (9) /incubator-mxnet/python/mxnet/../../lib/libmxnet.so(std::thread::_Impl<std::_Bind_simple<std::function<void (std::shared_ptr<dmlc::ManualEvent>)> (std::shared_ptr<dmlc::ManualEvent>)> >::_M_run()+0x4a) [0x7f74fa9472ca]
@tqchen
Copy link
Member

tqchen commented Mar 13, 2018

This might be due to a problem in MXNet's DLPack layer creating a wrong DLTensor; let me take a look.

@tqchen
Copy link
Member

tqchen commented Mar 13, 2018

Confirmed the upstream problem in MXNet; will send a fix.

@tqchen
Copy link
Member

tqchen commented Mar 13, 2018

Fixed in apache/mxnet#10083.

@tqchen tqchen closed this as completed Mar 13, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants