Skip to content

Commit

Permalink
[Runtime][Relay][Cleanup] Clean up for memory pass to enable heteroge…
Browse files Browse the repository at this point in the history
…nous execution support. (apache#5324)

* Cleanup type pack and unpack for tuples.

* Clean up the memory_pass using common helpers

* Clean up memory.cc

* Refactor pass

* Add doc strings

* Fix CPPlint

* Fix PyLint

* Fix

* Apply suggestions from code review

Co-Authored-By: Zhi <[email protected]>

* Fix typo

Co-authored-by: Zhi <[email protected]>
  • Loading branch information
2 people authored and dhruvaray committed Apr 28, 2020
1 parent d7c977c commit ab5afbc
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 172 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
"The virtual device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
"The virtual device/context type where the op copies data to.")
.set_default(0);
}
};
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/attrs/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,37 @@
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

std::vector<TensorType> FlattenTupleType(const Type& type);
std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
Expr ToTupleType(const Type& t, const Array<Expr>& exprs);

/*!
* \brief Options for allocating storage.
*/
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
DataType dtype;
int device_id;
int device_type;

TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
TVM_ATTR_FIELD(device_id)
.describe(
"The device id on which to allocate memory.");
TVM_ATTR_FIELD(device_type)
.describe(
"The device type on which to allocate memory.");
}
};

/*!
* \brief Options for allocating tensors.
*/
Expand Down
61 changes: 59 additions & 2 deletions python/tvm/relay/op/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""Operators for manipulating low-level memory."""
from __future__ import absolute_import as _abs
from . import _make
Expand All @@ -23,6 +24,9 @@ def invoke_tvm_op(func, inputs, outputs):
Parameters
----------
func : tvm.relay.Expr
The input expr.
inputs : tvm.relay.Expr
A tuple of the inputs to pass to the TVM function.
Expand Down Expand Up @@ -59,7 +63,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
"""
return _make.alloc_tensor(storage, shape, dtype, assert_shape)

def alloc_storage(size, alignment, dtype_hint='float32'):
def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
"""Allocate a piece of tensor storage.
Parameters
Expand All @@ -76,7 +80,7 @@ def alloc_storage(size, alignment, dtype_hint='float32'):
result : tvm.relay.Expr
The alloc_storage expression.
"""
return _make.alloc_storage(size, alignment, dtype_hint)
return _make.alloc_storage(size, alignment, ctx, dtype_hint)

def shape_func(func, inputs, outputs, dependent=False):
"""Invoke the shape function of the passed function.
Expand All @@ -96,3 +100,56 @@ def shape_func(func, inputs, outputs, dependent=False):
The shape function expression.
"""
return _make.shape_func(func, inputs, outputs, dependent)

def flatten_tuple_type(ty):
"""Return a sequence of the types contained in the tuple type in order.
Parameters
----------
ty: tvm.Type
The type to flatten.
Returns
-------
result: List[tvm.Type]
The types in their linear order.
"""
return _make.FlattenTupleType(ty)

def from_tuple_type(ty, expr):
"""Convert an expression with the given type into a sequence of expressions.
Each expression maps to a field of the tuple or nested tuples in linear
order.
Parameters
----------
ty: tvm.Type
The type to unpack.
expr: tvm.relay.Expr
The expression from which to extract each sub-field.
Returns
-------
result: List[tvm.relay.Expr]
The list of sub-expressions.
"""
return _make.FromTupleType(ty, expr)

def to_tuple_type(ty, exprs):
"""Pack the sequence of expressions into the nested tuple type.
Parameters
----------
ty: tvm.Type
The type to pack with.
exprs: tvm.relay.Expr
The expressions to pack back into the nested tuple type.
Returns
-------
result: List[tvm.relay.Expr]
The packed tuple expression.
"""
return _make.ToTupleType(ty, exprs)
Loading

0 comments on commit ab5afbc

Please sign in to comment.