Skip to content

Commit

Permalink
[SPARSE] adjust implement regarding to suggestions;
Browse files Browse the repository at this point in the history
  • Loading branch information
liangfu committed Jun 20, 2018
1 parent 0a6c056 commit 3773516
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 0 deletions.
254 changes: 254 additions & 0 deletions python/tvm/contrib/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
from .. import _api_internal
from .. import make as _make
from .. import expr as _expr
from .. import api as _api

class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices):
if not isinstance(indices, tuple):
indices = (indices,)
self.tensor = tensor
self.indices = indices

def __getitem__(self, indices):
if not isinstance(indices, tuple):
indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices)

def asnode(self):
"""Convert slice to node."""
return self.tensor(*self.indices)

@property
def dtype(self):
"""Data content of the tensor."""
return self.tensor.dtype

itervar_cls = None

@register_node
class CSRTensor(NodeBase, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert_to_node(indices)
args = []
for x in indices:
if isinstance(x, _expr.Expr):
args.append(x)
elif isinstance(x, iter_var_cls):
args.append(x.var)
else:
raise ValueError("The indices must be expression")

return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)

def __getitem__(self, indices):
return TensorSlice(self, indices)

def __hash__(self):
return _api_internal._TensorHash(self)

def __eq__(self, other):
if not isinstance(other, Tensor):
if isinstance(other, _expr.ExprOp):
return _expr.EqualOp(self, other)
return False
if self.ndim == 0 and other.ndim == 0:
raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison")
return _api_internal._TensorEqual(self, other)

@property
def ndim(self):
"""Dimension of the tensor."""
return len(self.shape)

@property
def axis(self):
"""Axis of the tensor."""
return self.__getattr__("axis")

@property
def op(self):
"""The corressponding :any:`Operation`."""
return self.__getattr__("op")

@property
def value_index(self):
"""The output value index the tensor corressponds to."""
return self.__getattr__("value_index")

@property
def shape(self):
"""The output shape of the tensor."""
return self.__getattr__("shape")

@property
def name(self):
op = self.op
if op.num_outputs == 1:
return op.name
return "%s.v%d" % (op.name, self.value_index)


class Operation(NodeBase):
"""Represent an operation that generate a tensor"""
def output(self, index):
"""Get the index-th output of the operation
Parameters
----------
index : int
The index size.
Returns
-------
out : Tensor
The i-th output.
"""
return _api_internal._OpGetOutput(self, index)

@property
def num_outputs(self):
"""Number of outputs of this op."""
return _api_internal._OpNumOutputs(self)

@property
def input_tensors(self):
"""List of input tensors to this op."""
return _api_internal._OpInputTensors(self)


@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass


@register_node
class ComputeOp(Operation):
"""Compute operation."""
@property
def axis(self):
"""Represent axis of IterVar, only defined when it is a ComputeOp"""
return self.__getattr__("axis")

@property
def reduce_axis(self):
"""Represent axis of reductions, only defined when it is a ComputeOp"""
return self.__getattr__("reduce_axis")


@register_node
class ScanOp(Operation):
"""Scan operation."""
@property
def scan_axis(self):
"""Represent axis of scan, only defined when it is a ScanOp"""
return self.__getattr__("scan_axis")


@register_node
class ExternOp(Operation):
"""Extern operation."""
pass

float32 = "float32"
csr = "csr"

class Placeholder(object):
def __init__(self, shape, dtype, name, stype):
super(Placeholder, self).__init__()
self.shape = shape
self.dtype = dtype
self.name = name
self.stype = stype
shape = (0,)
self.data = _api.placeholder(shape, dtype, name+'_data')
self.indices = _api.placeholder(shape, dtype, name+'_indices')
self.indptr = _api.placeholder(shape, dtype, name+'_indptr')

def placeholder(shape, dtype=None, name="placeholder", stype=None):
"""Construct an empty tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
stype = csr if stype is None else stype
return Placeholder(shape, dtype, name, stype)

def compute(shape, fcompute, name="compute", tag=""):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of indices-> value
Specifies the input source expression
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
if _tag.TagScope.current is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
ndim = len(shape)
code = fcompute.__code__

if fcompute.__code__.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]

if ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs


26 changes: 26 additions & 0 deletions tests/python/contrib/test_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os, sys
thisdir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(thisdir, '../../../python'))

import tvm
import tvm.contrib.sparse as tvmsp

def test_tensor():
dtype = 'float32'
stype = 'csr'
m = tvm.var('m')
n = tvm.var('n')
l = tvm.var('l')
A = tvmsp.placeholder((m, ), name='A', stype=stype, dtype=dtype)
B = tvmsp.placeholder((n, ), name='B', stype=stype, dtype=dtype)
print(A)
assert(A.stype == 'csr')
assert(B.stype == 'csr')
assert(A.data.shape == (0,))
assert(A.indices.shape == (0,))
assert(A.indptr.shape == (0,))
# T = tvmsp.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])

if __name__ == "__main__":
test_tensor()

0 comments on commit 3773516

Please sign in to comment.