Skip to content

Commit

Permalink
[Relay] Feature Detection (#3238)
Browse files Browse the repository at this point in the history
* init

init

lint

rename

ci

fix

add

add some doc

save

add some test

add some test

lint

lint

lint

* fix build
  • Loading branch information
MarisaKirisame authored and vinx13 committed Jun 28, 2019
1 parent 329378c commit 813a3d5
Show file tree
Hide file tree
Showing 17 changed files with 448 additions and 30 deletions.
2 changes: 1 addition & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Integer : public Expr {
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
<< " Trying to reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
Expand Down
170 changes: 170 additions & 0 deletions include/tvm/relay/feature.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/feature.h
* \brief Detect features used in Expr/Module.
*/
#ifndef TVM_RELAY_FEATURE_H_
#define TVM_RELAY_FEATURE_H_

#include <tvm/node/container.h>
#include <tvm/expr.h>
#include <bitset>

namespace tvm {
namespace relay {

/*! \brief Different kinds of relay feature a program might use. */
enum Feature : int {
fVar = 0,
fGlobalVar = 1,
fConstant = 2,
fTuple = 3,
fTupleGetItem = 4,
fFunction = 5,
fOp = 6,
fCall = 7,
fLet = 8,
fIf = 9,
fRefCreate = 10,
fRefRead = 11,
fRefWrite = 12,
fConstructor = 13,
fMatch = 14,
/*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */
fGraph = 15,
/*! \brief Whether there is local fixpoint in the program. */
fLetRec = 16
};

constexpr size_t feature_count = 17;

/*!
* \brief A finite set of Feature.
*/
class FeatureSet {
public:
FeatureSet(const FeatureSet&) = default;
/*! \brief A singleton set containing a single Feature. */
explicit FeatureSet(Feature ft) {
bs_.set(static_cast<size_t>(ft));
}
explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) {
for (Integer i : ft) {
(*this) += Feature(static_cast<int>(i));
}
}
explicit operator Array<Integer>() const {
Array<Integer> ret;
for (size_t i = 0; i < feature_count; ++i) {
if (bs_[i]) {
ret.push_back(Integer(i));
}
}
return ret;
}
/*! \brief A set that contain all the Feature. */
static FeatureSet AllFeature() {
FeatureSet fs;
fs.bs_.flip();
return fs;
}
/*! \brief The empty set. Contain no Feature. */
static FeatureSet NoFeature() {
FeatureSet fs;
return fs;
}
template<typename T>
FeatureSet& operator+=(const T& rhs) {
bs_ |= FeatureSet(rhs).bs_;
return *this;
}
/*! \brief Set union. */
template<typename T>
FeatureSet operator+(const T& rhs) const {
FeatureSet fs(*this);
fs += rhs;
return fs;
}
template<typename T>
FeatureSet& operator-=(const T& rhs) {
bs_ &= ~(FeatureSet(rhs)).bs_;
return *this;
}
/*! \brief Set difference. */
template<typename T>
FeatureSet operator-(const T& rhs) const {
FeatureSet fs(*this);
fs -= rhs;
return fs;
}
/*!
* \brief Is this a subset of rhs?
*
* \param rhs another FeatureSet.
*
* \return true only if this is a subset of rhs.
*/
bool is_subset_of(const FeatureSet& rhs) const {
return ((*this) - rhs).bs_.none();
}

private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
};

class Expr;
/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Expr& expr);

struct Module;
/*!
* \brief Calculate the feature of the program.
*
* \param mod The module.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Module& mod);

/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
* \param mod The module.
*
* \return The FeatureSet.
*/
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_FEATURE_H_
4 changes: 2 additions & 2 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ class TensorTypeNode : public BaseTensorTypeNode {

RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);

/*! \brief possible kinds of Type */
/*! \brief Possible kinds of Type. */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum

class Feature(IntEnum):
""" The features a program might contain. """
fVar = 0
fGlobalVar = 1
fConstant = 2
fTuple = 3
fTupleGetItem = 4
fFunction = 5
fOp = 6
fCall = 7
fLet = 8
fIf = 9
fRefCreate = 10
fRefRead = 11
fRefWrite = 12
fConstructor = 13
fMatch = 14
""" Whether any non-atom fragment of the program is shared, making the program a graph. """
fGraph = 15
""" Whether there is local fixpoint in the program. """
fLetRec = 16
27 changes: 26 additions & 1 deletion python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .expr import Expr
from .ty import Type
from .module import Module
from .feature import Feature


def post_order_visit(expr, fvisit):
Expand Down Expand Up @@ -604,7 +605,6 @@ def gradient(expr, mod=None, mode='higher_order'):
raise Exception('unknown mode')



def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down Expand Up @@ -641,6 +641,7 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)


def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Expand All @@ -660,6 +661,7 @@ def partial_evaluate(expr, mod=None):
"""
return _ir_pass.partial_evaluate(expr, mod)


def unmatched_cases(match, mod=None):
"""
Finds cases that the match expression does not catch, if any.
Expand All @@ -677,3 +679,26 @@ def unmatched_cases(match, mod=None):
Patterns that the match expression does not catch.
"""
return _ir_pass.unmatched_cases(match, mod)


def detect_feature(a, b=None):
"""
Detect the feature used in a relay program.
Parameters
----------
a : Union[tvm.relay.Expr, tvm.relay.Module]
The input expression or module.
b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]]
The input expression or module.
The two arguments cannot both be expression or module.
Returns
-------
features : Set[Feature]
Features used in the program.
"""
if isinstance(a, Module):
a, b = b, a
return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)])
6 changes: 4 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from .parser import fromtext

__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module

class Prelude:
"""Contains standard definitions."""
Expand Down Expand Up @@ -486,7 +486,9 @@ def load_prelude(self):
self.compose = self.mod.get_global_var("compose")


def __init__(self, mod):
def __init__(self, mod=None):
if mod is None:
mod = Module()
self.mod = mod
self.load_prelude()
self.define_list_adt()
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
6 changes: 3 additions & 3 deletions src/relay/pass/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file alter_op_layout.cc
* \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in
Expand Down
Loading

0 comments on commit 813a3d5

Please sign in to comment.