Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jul 23, 2020
1 parent 1d4778f commit 99e9ab3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 31 deletions.
6 changes: 3 additions & 3 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class AccessAnalyzer : public ObjectRef {
* \brief Get all consumers of on operation
* \param state The current loop state
* \param op The operation
* \return The return consumer set
* \return The set of consumers
* \note This function propagates the relation for inlined ops
*/
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
Expand All @@ -125,7 +125,7 @@ class AccessAnalyzer : public ObjectRef {
* \brief Get all producers of on operation
* \param state The current loop state
* \param op The operation
* \param producers The return producer set
* \return The set of producers
* \note This function propagates the relation for inlined ops
*/
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
Expand All @@ -134,7 +134,7 @@ class AccessAnalyzer : public ObjectRef {
/*!
* \brief Get all direct producers of on operation
* \param op The operation
* \param producers The return producer set
* \return The set of direct producers
* \note This function DOES NOT propagate the relation for inlined ops
*/
TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
Expand Down
54 changes: 26 additions & 28 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <vector>

#include "utils.h"
#include "../arith/pattern_match.h"

namespace tvm {
namespace auto_scheduler {
Expand Down Expand Up @@ -131,8 +132,8 @@ class ReadAccessExtractor : public StmtExprVisitor {
}

void VisitExpr_(const ProducerLoadNode* op) final {
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
op->indices.end());
read_access[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
op->indices.end());
StmtExprVisitor::VisitExpr_(op);
}

Expand All @@ -149,28 +150,25 @@ class ReadAccessExtractor : public StmtExprVisitor {
// All read accesses to all operations
// The innermost vector stores mulit-dimentional indices.
// The middle vector stores possible multiple accesses
OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
OperationMap<std::vector<std::vector<PrimExpr>>> read_access;
// Whether this expression has branch
bool has_branch{false};
};

// Returns whether the expr equals to the var with an optional const shift
bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
if (auto pv = expr.as<VarNode>()) {
return pv == var.get();
} else if (auto padd = expr.as<AddNode>()) {
return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
(padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
} else if (auto psub = expr.as<SubNode>()) {
return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
(psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
} else {
return false;
arith::PVar<PrimExpr> x;
arith::PVar<IntImm> c;

if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) &&
x.Eval().same_as(var)) {
return true;
}
return false;
}

// Return whether the access is injective
bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& indices, bool* axis_missing,
bool* axis_duplicated, bool* same_order) {
auto cop = op.as<te::ComputeOpNode>();
if (cop == nullptr) {
Expand All @@ -180,7 +178,7 @@ bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bo
std::vector<int> index_to_var_idx;
std::vector<int> var_idx_ct(cop->axis.size(), 0);

for (const auto& expr : index) {
for (const auto& expr : indices) {
if (!is_const_int(expr)) {
bool found = false;
for (size_t i = 0; i < cop->axis.size(); ++i) {
Expand Down Expand Up @@ -248,7 +246,7 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {

arith::Analyzer analyzer;

// build read & write access map
// Build read & write access map
for (const auto& op : node->ops_topo_order) {
if (op->IsInstance<te::PlaceholderOpNode>()) {
node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
Expand All @@ -259,12 +257,12 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
}

// read_by and read_from map
for (const auto& iter : extractor.buf_accesses) {
for (const auto& iter : extractor.read_access) {
std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
}

node->read_from[op] = std::move(extractor.buf_accesses);
node->read_from[op] = std::move(extractor.read_access);
has_branch[op] = extractor.has_branch;

// compute number of common outer iterators
Expand All @@ -282,15 +280,15 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
break;
}

bool direct_access = true;
bool injective = true;
for (const auto& access : access_list) {
if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
direct_access = false;
injective = false;
break;
}
}

if (!direct_access) {
if (!injective) {
break;
}
}
Expand All @@ -303,7 +301,7 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
}
}

// do some static analysis
// Do some static analysis on ComputeOps
for (const auto& op : node->ops_topo_order) {
if (op->IsInstance<te::PlaceholderOpNode>()) {
node->is_injective[op] = true;
Expand All @@ -317,9 +315,9 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {

bool axis_missing, axis_duplicated, same_order;
for (const auto& pair : node->read_from[op]) {
const std::vector<std::vector<PrimExpr>>& access = pair.second;
for (const auto& index : access) {
if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
for (const auto& access : access_list) {
if (!auto_scheduler::IsInjective(op, access, &axis_missing, &axis_duplicated,
&same_order)) {
is_injective = false;
is_strict_inlineable = false;
Expand Down Expand Up @@ -352,10 +350,10 @@ AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
int n_missing = 0;

for (const auto& pair : node->read_from[op]) {
const std::vector<std::vector<PrimExpr>>& access = pair.second;
const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
std::unordered_set<const VarNode*> vars;
for (const std::vector<PrimExpr>& indices : access) {
for (const PrimExpr& expr : indices) {
for (const std::vector<PrimExpr>& access : access_list) {
for (const PrimExpr& expr : access) {
GatherVars(expr, &vars);
}
}
Expand Down

0 comments on commit 99e9ab3

Please sign in to comment.