Skip to content

Commit

Permalink
[Refactor] Buffer flatten (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and Hzfengsy committed Mar 26, 2021
1 parent e46b917 commit 9f0b3a1
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 432 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
Expand Down
19 changes: 2 additions & 17 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,23 +338,8 @@ inline std::ostream& StdCout(int verbose, int setting = 1) {

/**************** String Manipulation ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}
using tir::FindCharPos;
using tir::StartsWith;

/**************** Target Hardware Concurrency ****************/

Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const V
return found;
}

std::unordered_set<const VarNode*> Vars(const ObjectRef& stmt_or_expr) {
std::unordered_set<const VarNode*> result;
auto f_visit = [&result](const ObjectRef& obj) -> void {
if (const auto* var = obj.as<VarNode>()) {
result.insert(var);
}
};
PostOrderVisit(stmt_or_expr, f_visit);
return result;
}

bool ValidateBlockBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges) {
arith::Analyzer analyzer;
Array<arith::IterSumExpr> results = arith::DetectIterMap(
Expand Down
6 changes: 6 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var);
* \return A boolean indicating if any var in the list is found in stmt/expr
*/
bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const VarNode*>& var);
/*!
* \brief Collect the variables that appear in the specific Stmt or Expr
* \param stmt_or_expr The Stmt or Expr
* \return All variables that appear
*/
std::unordered_set<const VarNode*> Vars(const ObjectRef& stmt_or_expr);

/******** Verification ********/
/*!
Expand Down
45 changes: 45 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace tir {

Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars) {
std::vector<PrimExpr> conditions;
for (const IterVar& var : iter_vars) {
if (var->iter_type == IterVarType::kCommReduce) {
conditions.push_back(equal(var->var, var->dom->min));
}
}
int n = conditions.size();
// Handle the case where there is no condition
if (n == 0) {
return init;
}
// Concate the conditions with logical and (&&)
PrimExpr cond = conditions[0];
for (int i = 1; i < n; ++i) {
cond = logical_and(cond, conditions[i]);
}
return IfThenElse(cond, init);
}

} // namespace tir
} // namespace tvm
40 changes: 40 additions & 0 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_

#include <tvm/tir/stmt.h>

#include <unordered_set>

namespace tvm {
namespace tir {

/*!
* \brief Transform the init block into actual computation
* \param init The init block
* \param iter_vars The block variables
* \return The actual computation
*/
Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars);

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_
42 changes: 42 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_COMMON_H_

#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
Expand All @@ -35,6 +36,7 @@
#include <vector>

#include "./analysis.h"
#include "./transform.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -455,6 +457,46 @@ static DefaultReducer default_reducers[4] = {

} // namespace default_reducer

/**************** String ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}

inline bool StartsWith(const String& str, const String& prefix) {
int n = prefix.size();
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix.data());
}

inline bool StartsWith(const String& str, const char* prefix) {
int n = strlen(prefix);
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix);
}

/**************** Loop extents ****************/

inline int64_t GetLoopIntExtent(const ForNode* loop) {
const auto* int_extent = loop->extent.as<IntImmNode>();
return int_extent ? int_extent->value : -1;
Expand Down
Loading

0 comments on commit 9f0b3a1

Please sign in to comment.