diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index d6f9bee112a5..793b43ad2bb3 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -376,6 +376,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {} }; +/*! \brief Attributes for FIFO buffer operator */ +struct FIFOBufferAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(FIFOBufferAttrs, "relay.attrs.FIFOBufferAttrs") { + TVM_ATTR_FIELD(axis).set_default(0); + } +}; + /*! \brief Attributes for upsampling operator */ struct UpSamplingAttrs : public tvm::AttrsNode { int scale; diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0febfdd85c4a..05dd691f1c60 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1026,6 +1026,12 @@ def _mx_one_hot(inputs, attrs): return _op.one_hot(indices, on_value, off_value, depth, -1, dtype) +def _mx_contrib_fifo_buffer(inputs, attrs): + new_attrs = {} + new_attrs['axis'] = attrs.get_int('axis') + return _op.nn.fifo_buffer(*inputs, **new_attrs) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -1198,6 +1204,7 @@ def _mx_one_hot(inputs, attrs): # TODO(tvm-tvm): support all operators. # # "broadcast_to", + "contrib_fifo_buffer" : _mx_contrib_fifo_buffer, } # set identity list diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 8c09390b4deb..b8572349fb9d 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -69,6 +69,20 @@ def schedule_dense(attrs, outputs, target): reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_compute('nn.fifo_buffer') +def compute_fifo_buffer(attrs, inputs, out_type, target): + return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))] + + +@reg.register_schedule('nn.fifo_buffer') +def schedule_fifo_buffer(attrs, outputs, target): + with target: + return topi.generic.schedule_injective(outputs) + + +reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE) + + # batch_matmul @reg.register_compute("nn.batch_matmul") def compute_batch_matmul(attrs, inputs, out_type, target): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 31c1006e8f4d..9ddb3ece4ce2 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -601,6 +601,36 @@ def dense(data, weight, units=None, out_dtype=""): return _make.dense(data, weight, units, out_dtype) +def fifo_buffer(data, buffer, axis): + """FIFO buffer + + Compute equivalent of + ``` + concat(buffer, data, axis=axis) \ + .slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis]) + ``` + + Useful for + * Encoding explicit re-use of computation in convolution ops operated on a sliding window input + * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet. + + Parameters + ---------- + data : tvm.relay.Expr + The input data + buffer : tvm.relay.Expr + Previous value of the FIFO buffer + axis : int + Specify which axis should be used for buffering + + Returns + ------- + result : tvm.relay.Expr + Updated value for the buffer + """ + return _make.fifo_buffer(data, buffer, axis) + + def relu(data): """Rectified linear unit. diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 11f8ad1611cd..2de0257aa841 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -54,6 +54,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@register_relay_attr_node +class FIFOBufferAttrs(Attrs): + """Attributes for nn.fifo_buffer""" + + @register_relay_attr_node class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index a875ffc6293b..6d8c3acf1e00 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -100,6 +100,73 @@ RELAY_REGISTER_OP("nn.bias_add") }); +// relay.nn.fifo_buffer +TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); + +Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { + auto attrs = make_node(); + attrs->axis = axis; + static const Op& op = Op::Get("nn.fifo_buffer"); + return CallNode::make(op, {input, buffer}, Attrs(attrs), {}); +} + +bool FIFOBufferRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* input = types[0].as(); + const auto* buffer = types[1].as(); + const FIFOBufferAttrs* param = attrs.as(); + if (input == nullptr || buffer == nullptr) { + return false; + } + CHECK(param != nullptr); + CHECK_EQ(input->shape.size(), buffer->shape.size()); + + const size_t buffer_axis + = static_cast(param->axis < 0 ? static_cast(buffer->shape.size()) + param->axis + : param->axis); + + reporter->Assert(buffer_axis < buffer->shape.size()); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + if (i != buffer_axis) { + reporter->AssertEQ(input->shape[i], buffer->shape[i]); + } + } + reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]); + + Array oshape = buffer->shape; + + reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype)); + return true; +} + +TVM_REGISTER_API("relay.op.nn._make.fifo_buffer") +.set_body_typed(MakeFIFOBuffer); + +RELAY_REGISTER_OP("nn.fifo_buffer") +.describe(R"code(FIFO buffer +Compute equivalent of + +``` +concat(buffer, data, axis=axis) \ +.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis]) +``` + +Useful for +* Encoding explicit re-use of computation in convolution ops operated on a sliding window input +* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet. +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.FIFOBufferAttrs") +.set_num_inputs(2) +.add_argument("data", "Tensor", "Latest input") +.add_argument("buffer", "Tensor", + "Buffer storing latest [length_buffer] inputs") +.set_support_level(3) +.add_type_rel("FIFOBuffer", FIFOBufferRel); + + // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 41d9de2da7d0..5362436652af 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -22,3 +22,4 @@ from .batch_matmul import * from .sparse import * from .pad import * +from .fifo_buffer import * diff --git a/topi/python/topi/nn/fifo_buffer.py b/topi/python/topi/nn/fifo_buffer.py new file mode 100644 index 000000000000..5467a90fbaae --- /dev/null +++ b/topi/python/topi/nn/fifo_buffer.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""FIFO buffer op""" +from __future__ import absolute_import as _abs +import tvm +from .. import tag +from ..transform import concatenate, strided_slice + +@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer") +def fifo_buffer(data, buffer, axis): + """ + Implements the FIFO buffer + """ + assert len(data.shape) == len(buffer.shape), \ + 'buffer and data must have same number of dimensions, ' + \ + 'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape) + assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported' + assert 0 <= axis < len(buffer.shape), 'buffer axis out of range' + for i in range(len(data.shape)): + if i == axis: + assert int(str(data.shape[i])) <= int(str(buffer.shape[i])) + else: + assert int(str(data.shape[i])) == int(str(buffer.shape[i])) + + buflen = buffer.shape[axis] + data_size = data.shape[axis] + + # Explicitly write out formula up to 4D, and then use concat+slice combo for 5D and higher + if len(buffer.shape) == 1: + return tvm.compute(buffer.shape, + lambda i: + tvm.if_then_else(i < buflen - data_size, + buffer[i + data_size], + data[i - buflen + data_size]), + name='new_buffer') + elif len(buffer.shape) == 2: + if axis == 0: + return tvm.compute(buffer.shape, + lambda i, j: + tvm.if_then_else(i < buflen - data_size, + buffer[i + data_size, j], + data[i - buflen + data_size, j]), + name='new_buffer') + if axis == 1: + return tvm.compute(buffer.shape, + lambda i, j: + tvm.if_then_else(j < buflen - data_size, + buffer[i, j + data_size], + data[i, j - buflen + data_size]), + name='new_buffer') + assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape)) + elif len(buffer.shape) == 3: + if axis == 0: + return tvm.compute(buffer.shape, + lambda i, j, k: + tvm.if_then_else(i < buflen - data_size, + buffer[i + data_size, j, k], + data[i - buflen + data_size, j, k]), + name='new_buffer') + if axis == 1: + return tvm.compute(buffer.shape, + lambda i, j, k: + tvm.if_then_else(j < buflen - data_size, + buffer[i, j + data_size, k], + data[i, j - buflen + data_size, k]), + name='new_buffer') + if axis == 2: + return tvm.compute(buffer.shape, + lambda i, j, k: + tvm.if_then_else(k < buflen - data_size, + buffer[i, j, k + data_size], + data[i, j, k - buflen + data_size]), + name='new_buffer') + assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape)) + elif len(buffer.shape) == 4: + if axis == 0: + return tvm.compute(buffer.shape, + lambda i, j, k, l: + tvm.if_then_else(i < buflen - data_size, + buffer[i + data_size, j, k, l], + data[i - buflen + data_size, j, k, l]), + name='new_buffer') + if axis == 1: + return tvm.compute(buffer.shape, + lambda i, j, k, l: + tvm.if_then_else(j < buflen - data_size, + buffer[i, j + data_size, k, l], + data[i, j - buflen + data_size, k, l]), + name='new_buffer') + if axis == 2: + return tvm.compute(buffer.shape, + lambda i, j, k, l: + tvm.if_then_else(k < buflen - data_size, + buffer[i, j, k + data_size, l], + data[i, j, k - buflen + data_size, l]), + name='new_buffer') + if axis == 3: + return tvm.compute(buffer.shape, + lambda i, j, k, l: + tvm.if_then_else(l < buflen - data_size, + buffer[i, j, k, l + data_size], + data[i, j, k, l - buflen + data_size]), + name='new_buffer') + assert False, 'Invalid value for axis; it should be at most {}'.format(len(buffer.shape)) + else: + # Implement FIFO buffer as combination of concat and slice + begin = [0] * len(buffer.shape) + begin[axis] = data.shape[axis] + end = list(buffer.shape[:]) + end[axis] += data.shape[axis] + return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end) + return None diff --git a/topi/tests/python/test_fifo_buffer.py b/topi/tests/python/test_fifo_buffer.py new file mode 100644 index 000000000000..022272f6c4da --- /dev/null +++ b/topi/tests/python/test_fifo_buffer.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for FIFO buffer""" + +import tvm +import topi +import numpy as np +from common import get_all_backend +from tvm.contrib.pickle_memoize import memoize + +def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'): + buffer = tvm.placeholder(buffer_shape, name='buffer', dtype=dtype) + data = tvm.placeholder(data_shape, name='data', dtype=dtype) + + # Use memoize, pickle the test data for next time use + @memoize('topi.tests.test_fifo_buffer') + def get_ref_data(): + buffer_np = np.random.uniform(size=buffer_shape).astype(dtype) + data_np = np.random.uniform(size=data_shape).astype(dtype) + + # Reference implementation of FIFO queue + begin = data_np.shape[axis] + end = buffer_np.shape[axis] + data_np.shape[axis] + ndim = len(buffer_np.shape) + ss = tuple((slice(begin, end, 1) if x == axis else slice(None)) for x in range(ndim)) + out_np = np.concatenate((buffer_np, data_np), axis=axis)[ss] + return (buffer_np, data_np, out_np) + + # Get the test data + buffer_np, data_np, out_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print(' Skip because %s is not enabled' % device) + return + print(' Running on target: {}'.format(device)) + + with tvm.target.create(device): + out = topi.nn.fifo_buffer(data, buffer, axis=axis) + s = topi.generic.schedule_injective([out]) + + buffer_tvm = tvm.nd.array(buffer_np, ctx=ctx) + data_tvm = tvm.nd.array(data_np, ctx=ctx) + out_tvm = tvm.nd.empty(shape=buffer_shape, ctx=ctx, dtype=dtype) + f = tvm.build(s, [data, buffer, out], device, name='fifo') + f(data_tvm, buffer_tvm, out_tvm) + tvm.testing.assert_allclose(out_tvm.asnumpy(), out_np) + + for device in get_all_backend(): + check_device(device) + +def verify_conv1d_integration(): + batch_size = 1 + num_channel = 1 + num_filter = 1 + + # Note: TVM doesn't have a separate op for 1D convolution, so we use conv2d instead. + # We set height=1 to indicate that convolution is really 1D. + stride = (1, 1) + dilate = (1, 1) + padding = (0, 0) + + kernel_size = (1, 3) + input_window_size = (1, 10) + inc_input_size = (1, 2) + context_size = (1, 4) + inc_output_size = (1, 2) + output_window_size = (1, 8) + + num_iteration = 20 + buffer_axis = 3 + + kernel_shape = (num_filter, num_channel, kernel_size[0], kernel_size[1]) + input_window_shape = (batch_size, num_channel, input_window_size[0], input_window_size[1]) + inc_input_shape = (batch_size, num_channel, inc_input_size[0], inc_input_size[1]) + inc_output_shape = (batch_size, num_filter, inc_output_size[0], inc_output_size[1]) + context_shape = (batch_size, num_channel, context_size[0], context_size[1]) + output_window_shape = (batch_size, num_filter, output_window_size[0], output_window_size[1]) + # Rule: Convolution of Tensor[context_shape] and Tensor[kernel_shape] + # produces Tensor[inc_input_shape] + + dtype = 'float32' + + inc_input = tvm.placeholder(inc_input_shape, name='inc_input', dtype=dtype) + input_window = tvm.placeholder(input_window_shape, name='input_window', dtype=dtype) + context = tvm.placeholder(context_shape, name='context', dtype=dtype) + kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=dtype) + inc_output = tvm.placeholder(inc_input_shape, name='inc_output', dtype=dtype) + output_window = tvm.placeholder(output_window_shape, name='output_window', dtype=dtype) + + # Use memoize, pickle the test data for next time use + @memoize('topi.tests.test_fifo_buffer_conv1d_integration') + def get_data(): + # Generate [num_iteration] slices of input + inc_input_np = np.random.uniform(size=tuple([num_iteration] + list(inc_input_shape)))\ + .astype(dtype) + input_window_np = np.zeros(input_window_shape, dtype=dtype) + kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) + context_np = np.zeros(context_shape, dtype=dtype) + output_window_np = np.zeros(output_window_shape, dtype=dtype) + + return (inc_input_np, input_window_np, kernel_np, context_np, output_window_np) + + # Get the test data + inc_input_np, input_window_np, kernel_np, context_np, output_window_np = get_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print(' Skip because %s is not enabled' % device) + return + print(' Running on target: {}'.format(device)) + + with tvm.target.create(device): + out = topi.nn.fifo_buffer(inc_input, context, axis=buffer_axis) + s = topi.generic.schedule_injective([out]) + update_context = tvm.build(s, [inc_input, context, out], device, name='update_context') + + out = topi.nn.conv2d(context, kernel, strides=stride, padding=padding, dilation=dilate, + layout='NCHW', out_dtype=dtype) + s = topi.generic.schedule_conv2d_nchw([out]) + conv2d_inc = tvm.build(s, [context, kernel, out], device, name='conv2d_inc') + + out = topi.nn.fifo_buffer(inc_output, output_window, axis=buffer_axis) + s = topi.generic.schedule_injective([out]) + update_output_window = tvm.build(s, [inc_output, output_window, out], device, + name='update_output_window') + + out = topi.nn.fifo_buffer(inc_input, input_window, axis=buffer_axis) + s = topi.generic.schedule_injective([out]) + update_input_window = tvm.build(s, [inc_input, input_window, out], device, + name='update_input_window') + + out = topi.nn.conv2d(input_window, kernel, strides=stride, padding=padding, + dilation=dilate, layout='NCHW', out_dtype=dtype) + s = topi.generic.schedule_conv2d_nchw([out]) + conv2d = tvm.build(s, [input_window, kernel, out], device, name='conv2d') + + input_window_tvm = tvm.nd.array(input_window_np, ctx=ctx) + new_input_window_tvm = tvm.nd.empty(shape=input_window_shape, ctx=ctx, dtype=dtype) + kernel_tvm = tvm.nd.array(kernel_np, ctx=ctx) + context_tvm = tvm.nd.array(context_np, ctx=ctx) + new_context_tvm = tvm.nd.empty(shape=context_shape, ctx=ctx, dtype=dtype) + inc_output_tvm = tvm.nd.empty(shape=inc_output_shape, ctx=ctx, dtype=dtype) + output_window_tvm = tvm.nd.array(output_window_np, ctx=ctx) + new_output_window_tvm = tvm.nd.empty(shape=output_window_shape, ctx=ctx, dtype=dtype) + output_window_ref_tvm = tvm.nd.empty(shape=output_window_shape, ctx=ctx, dtype=dtype) + + for i in range(num_iteration): + # Take i-th slice of inc_input_np + inc_input_tvm = tvm.nd.array(inc_input_np[i], ctx=ctx) + + # Compute new output window incrementally, using the FIFO buffer op + update_context(inc_input_tvm, context_tvm, new_context_tvm) + conv2d_inc(new_context_tvm, kernel_tvm, inc_output_tvm) + update_output_window(inc_output_tvm, output_window_tvm, new_output_window_tvm) + context_tvm = new_context_tvm + output_window_tvm = new_output_window_tvm + + # Compute full input window, so that we have a baseline + update_input_window(inc_input_tvm, input_window_tvm, new_input_window_tvm) + input_window_tvm = new_input_window_tvm + conv2d(input_window_tvm, kernel_tvm, output_window_ref_tvm) + # Incrementally updating the output window should be equivalent to computing it from + # scratch using the input window + tvm.testing.assert_allclose(output_window_tvm.asnumpy(), + output_window_ref_tvm.asnumpy()) + + for device in get_all_backend(): + check_device(device) + +def test_fifo_buffer(): + for ndim in [1, 2, 3, 4, 5, 6]: + for axis in range(ndim): + buffer_shape = tuple(7 for _ in range(ndim)) + data_shape = tuple((2 if i == axis else 7) for i in range(ndim)) + print('Testing FIFO buffer op: buffer_shape = {}, data_shape = {}, axis = {}' + .format(buffer_shape, data_shape, axis)) + verify_fifo_buffer(buffer_shape, data_shape, axis) + +def test_conv1d_integration(): + print('Testing FIFO buffer with 1D convolution') + verify_conv1d_integration() + +if __name__ == '__main__': + test_fifo_buffer() + test_conv1d_integration()