diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h new file mode 100644 index 0000000000000..bcbea9affe2cc --- /dev/null +++ b/include/tvm/relay/feature.h @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/feature.h + * \brief Detect features used in Expr/Module. + */ +#ifndef TVM_RELAY_FEATURE_H_ +#define TVM_RELAY_FEATURE_H_ + +#include + +namespace tvm { +namespace relay { + +enum Feature : int { + fVar = 0, + fGlobalVar = 1, + fConstant = 2, + fTuple = 3, + fTupleGetItem = 4, + fFunction = 5, + fOp = 6, + fCall = 7, + fLet = 8, + fIf = 9, + fRefCreate = 10, + fRefRead = 11, + fRefWrite = 12, + fConstructor = 13, + fMatch = 14, + fGraph = 15, + fLetRec = 16 +}; + +constexpr size_t feature_count = 17; + +class FeatureSet { + public: + FeatureSet(const FeatureSet&) = default; + explicit FeatureSet(Feature ft) { + bs_.set(static_cast(ft)); + } + static FeatureSet AllFeature() { + FeatureSet fs; + return fs; + } + static FeatureSet NoFeature() { + FeatureSet fs; + fs.bs_.flip(); + return fs; + } + template + FeatureSet& operator+=(const T& rhs) { + bs_ |= FeatureSet(rhs).bs_; + return *this; + } + template + FeatureSet operator+(const T& rhs) const { + FeatureSet fs(*this); + fs += rhs; + return fs; + } + template + FeatureSet& operator-=(const T& rhs) { + bs_ &= ~(FeatureSet(rhs)).bs_; + return *this; + } + template + FeatureSet operator-(const T& rhs) const { + FeatureSet fs(*this); + fs -= rhs; + return fs; + } + bool is_subset_of(const FeatureSet& rhs) const { + return ((*this) - rhs).bs_.none(); + } + + private: + std::bitset bs_; + FeatureSet() = default; + explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } +}; + +class Expr; +FeatureSet DetectFeature(const Expr& expr); +struct Module; +FeatureSet DetectFeature(const Module& mod); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_FEATURE_H_ diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index f51c201d0b2a9..9df39c6d6165a 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file alter_op_layout.cc * \brief Alternate the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc new file mode 100644 index 0000000000000..d4206e4a8f2be --- /dev/null +++ b/src/relay/pass/feature.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file feature.cc + * \brief Detect features used in Expr/Module + */ +#include +#include +#include +#include +#include +#include "pass_util.h" + +namespace tvm { +namespace relay { + +FeatureSet DetectFeature(const Expr& expr) { + struct FeatureDetector : ExprVisitor { + std::unordered_set visited_; + FeatureSet fs = FeatureSet::NoFeature(); + void VisitExpr(const Expr& expr) final { + if (visited_.count(expr) == 0) { + ExprVisitor::VisitExpr(expr); + } else { + if (!IsAtomic(expr)) { + fs += fGraph; + } + } + } +#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ + void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \ + STMT \ + fs += f##CONSTRUCT_NAME; \ + ExprVisitor::VisitExpr_(op); \ + } +#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {}) + DETECT_DEFAULT_CONSTRUCT(Var) + DETECT_DEFAULT_CONSTRUCT(GlobalVar) + DETECT_DEFAULT_CONSTRUCT(Constant) + DETECT_DEFAULT_CONSTRUCT(Tuple) + DETECT_DEFAULT_CONSTRUCT(TupleGetItem) + DETECT_DEFAULT_CONSTRUCT(Function) + DETECT_DEFAULT_CONSTRUCT(Op) + DETECT_DEFAULT_CONSTRUCT(Call) + DETECT_CONSTRUCT(Let, { + for (const Var& v : FreeVars(op->value)) { + if (op->var == v) { + fs += fLetRec; + } + } + }) + DETECT_DEFAULT_CONSTRUCT(If) + DETECT_DEFAULT_CONSTRUCT(RefCreate) + DETECT_DEFAULT_CONSTRUCT(RefRead) + DETECT_DEFAULT_CONSTRUCT(RefWrite) + DETECT_DEFAULT_CONSTRUCT(Constructor) + DETECT_DEFAULT_CONSTRUCT(Match) +#undef DETECT_DEFAULT_CONSTRUCT + } fd; + fd(expr); + return fd.fs; +} + +FeatureSet DetectFeature(const Module& mod) { + FeatureSet fs = FeatureSet::NoFeature(); + if (mod.defined()) { + for (const auto& f : mod->functions) { + fs += DetectFeature(f.second); + } + } + return fs; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index 38d8b0bd9040a..e2956e2343ef9 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -97,6 +97,10 @@ inline Expr TransformF(const std::function& func, const Expr& } } +inline bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index 4eaaa934e78bb..dea9374812896 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY