Skip to content

Commit

Permalink
[TOPI] FIFO buffer op, to accelerate sequence modeling with dilated convolutions (apache#4039)
Browse files Browse the repository at this point in the history

* Add FIFO buffer op to enable explicit computation re-use in convolution

* Add a test

* Add end-to-end test with 1D convolution

* Add a stub in MXNet frontend

* Address reviewer comments

* Add back stub for MXNet frontend
  • Loading branch information
hcho3 authored and Animesh Jain committed Oct 17, 2019
1 parent 20e30dd commit 83eb7e9
Show file tree
Hide file tree
Showing 9 changed files with 462 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
};

/*! \brief Attributes for FIFO buffer operator */
struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
  /*! \brief Axis along which incoming data is buffered; defaults to 0. */
  int axis;

  TVM_DECLARE_ATTRS(FIFOBufferAttrs, "relay.attrs.FIFOBufferAttrs") {
    TVM_ATTR_FIELD(axis).set_default(0);
  }
};

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,12 @@ def _mx_one_hot(inputs, attrs):
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)


def _mx_contrib_fifo_buffer(inputs, attrs):
    """Convert MXNet's contrib.fifo_buffer op to Relay nn.fifo_buffer."""
    axis = attrs.get_int('axis')
    return _op.nn.fifo_buffer(*inputs, axis=axis)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -1198,6 +1204,7 @@ def _mx_one_hot(inputs, attrs):
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
"contrib_fifo_buffer" : _mx_contrib_fifo_buffer,
}

# set identity list
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute('nn.fifo_buffer')
def compute_fifo_buffer(attrs, inputs, out_type, target):
    """Compute definition of fifo_buffer: delegates to the TOPI kernel."""
    data, buffer = inputs
    return [topi.nn.fifo_buffer(data, buffer, axis=attrs.get_int('axis'))]


@reg.register_schedule('nn.fifo_buffer')
def schedule_fifo_buffer(attrs, outputs, target):
    """Schedule definition of fifo_buffer: injective schedule for the target."""
    with target:
        sched = topi.generic.schedule_injective(outputs)
    return sched


reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,36 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def fifo_buffer(data, buffer, axis):
    """FIFO buffer op.

    Compute equivalent of::

        concat(buffer, data, axis=axis) \\
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated
      on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in
      Fast WaveNet.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data

    buffer : tvm.relay.Expr
        Previous value of the FIFO buffer

    axis : int
        Specify which axis should be used for buffering

    Returns
    -------
    result : tvm.relay.Expr
        Updated value for the buffer
    """
    return _make.fifo_buffer(data, buffer, axis)


def relu(data):
"""Rectified linear unit.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class DenseAttrs(Attrs):
"""Attributes for nn.dense"""


@register_relay_attr_node
class FIFOBufferAttrs(Attrs):
    """Attributes for nn.fifo_buffer

    Python mirror of the C++ node registered as relay.attrs.FIFOBufferAttrs;
    carries the `axis` field selecting the buffering axis.
    """


@register_relay_attr_node
class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""
Expand Down
67 changes: 67 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,73 @@ RELAY_REGISTER_OP("nn.bias_add")
});


// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);

// Build a Call node invoking nn.fifo_buffer on (input, buffer) with the
// requested buffering axis.
Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
  static const Op& op = Op::Get("nn.fifo_buffer");
  auto param = make_node<FIFOBufferAttrs>();
  param->axis = axis;
  return CallNode::make(op, {input, buffer}, Attrs(param), {});
}

// Type relation for nn.fifo_buffer.
// types = {data, buffer, output}; the output always takes the buffer's
// shape and dtype.
bool FIFOBufferRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* input = types[0].as<TensorTypeNode>();
  const auto* buffer = types[1].as<TensorTypeNode>();
  const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
  if (input == nullptr || buffer == nullptr) {
    return false;
  }
  CHECK(param != nullptr);
  CHECK_EQ(input->shape.size(), buffer->shape.size());

  // Normalize a negative axis into [0, ndim).
  const size_t buffer_axis
      = static_cast<size_t>(param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis
                                            : param->axis);

  reporter->Assert(buffer_axis < buffer->shape.size());
  // Every non-buffered dimension must match exactly.
  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    if (i != buffer_axis) {
      reporter->AssertEQ(input->shape[i], buffer->shape[i]);
    }
  }
  // Along the buffered axis the incoming chunk may be at most as long as the
  // buffer. Use <= (not <) to stay consistent with the TOPI implementation,
  // which asserts data.shape[axis] <= buffer.shape[axis] and handles the
  // equal case (buffer fully replaced by data).
  reporter->Assert(input->shape[buffer_axis] <= buffer->shape[buffer_axis]);

  // Output mirrors the buffer operand.
  Array<tvm::Expr> oshape = buffer->shape;

  reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
  return true;
}

// Expose the nn.fifo_buffer constructor to the Python frontend.
TVM_REGISTER_API("relay.op.nn._make.fifo_buffer")
.set_body_typed(MakeFIFOBuffer);

// Register the nn.fifo_buffer operator: two tensor inputs (latest data chunk
// and the current buffer contents), typed by FIFOBufferRel above.
RELAY_REGISTER_OP("nn.fifo_buffer")
.describe(R"code(FIFO buffer
Compute equivalent of
```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```
Useful for
* Encoding explicit re-use of computation in convolution ops operated on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.FIFOBufferAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Latest input")
.add_argument("buffer", "Tensor",
              "Buffer storing latest [length_buffer] inputs")
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);


// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);

Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .batch_matmul import *
from .sparse import *
from .pad import *
from .fifo_buffer import *
127 changes: 127 additions & 0 deletions topi/python/topi/nn/fifo_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""FIFO buffer op"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from ..transform import concatenate, strided_slice

@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
def fifo_buffer(data, buffer, axis):
    """FIFO buffer to enable computation reuse with a sliding-window input.

    Compute equivalent of::

        concat(buffer, data, axis=axis) \\
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated
      on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in
      Fast WaveNet.

    Parameters
    ----------
    data : tvm.Tensor
        The input data; must match `buffer` in every dimension except `axis`,
        where it may be at most as long as the buffer.
    buffer : tvm.Tensor
        Previous value of the FIFO buffer.
    axis : int
        The axis along which buffering takes place (0 <= axis < ndim).

    Returns
    -------
    result : tvm.Tensor
        Updated value for the buffer, with the same shape as `buffer`.
    """
    assert len(data.shape) == len(buffer.shape), \
        'buffer and data must have same number of dimensions, ' + \
        'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
    assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported'
    assert 0 <= axis < len(buffer.shape), 'buffer axis out of range'
    for i in range(len(data.shape)):
        if i == axis:
            assert int(str(data.shape[i])) <= int(str(buffer.shape[i]))
        else:
            assert int(str(data.shape[i])) == int(str(buffer.shape[i]))

    buflen = buffer.shape[axis]
    data_size = data.shape[axis]

    if len(buffer.shape) <= 4:
        # Single compute covering 1D-4D, replacing the previous per-ndim /
        # per-axis copies of the same formula. Positions with index along
        # `axis` below (buflen - data_size) are shifted forward from the old
        # buffer; the remaining tail is filled from the new data chunk.
        def _shifted(*indices):
            buffer_idx = list(indices)
            buffer_idx[axis] = indices[axis] + data_size
            data_idx = list(indices)
            data_idx[axis] = indices[axis] - buflen + data_size
            return tvm.if_then_else(indices[axis] < buflen - data_size,
                                    buffer(*buffer_idx),
                                    data(*data_idx))
        return tvm.compute(buffer.shape, _shifted, name='new_buffer')

    # For 5D and higher, implement the FIFO buffer as a concat+slice combo:
    # append data after buffer, then keep the trailing buffer-sized window.
    begin = [0] * len(buffer.shape)
    begin[axis] = data.shape[axis]
    end = list(buffer.shape[:])
    end[axis] += data.shape[axis]
    return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)
Loading

0 comments on commit 83eb7e9

Please sign in to comment.