Skip to content

Commit

Permalink
[Refactor] Buffer flatten (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and jinhongyii committed Jul 29, 2021
1 parent e7b6a16 commit 66b989f
Show file tree
Hide file tree
Showing 6 changed files with 670 additions and 432 deletions.
19 changes: 2 additions & 17 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,23 +338,8 @@ inline std::ostream& StdCout(int verbose, int setting = 1) {

/**************** String Manipulation ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}
using tir::FindCharPos;
using tir::StartsWith;

/**************** Target Hardware Concurrency ****************/

Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const V
return found;
}

std::unordered_set<const VarNode*> Vars(const ObjectRef& stmt_or_expr) {
std::unordered_set<const VarNode*> result;
auto f_visit = [&result](const ObjectRef& obj) -> void {
if (const auto* var = obj.as<VarNode>()) {
result.insert(var);
}
};
PostOrderVisit(stmt_or_expr, f_visit);
return result;
}

bool ValidateBlockBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges) {
arith::Analyzer analyzer;
Array<arith::IterSumExpr> results = arith::DetectIterMap(
Expand Down
45 changes: 45 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace tir {

Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars) {
std::vector<PrimExpr> conditions;
for (const IterVar& var : iter_vars) {
if (var->iter_type == IterVarType::kCommReduce) {
conditions.push_back(equal(var->var, var->dom->min));
}
}
int n = conditions.size();
// Handle the case where there is no condition
if (n == 0) {
return init;
}
// Concate the conditions with logical and (&&)
PrimExpr cond = conditions[0];
for (int i = 1; i < n; ++i) {
cond = logical_and(cond, conditions[i]);
}
return IfThenElse(cond, init);
}

} // namespace tir
} // namespace tvm
40 changes: 40 additions & 0 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_

#include <tvm/tir/stmt.h>

#include <unordered_set>

namespace tvm {
namespace tir {

/*!
* \brief Transform the init block into actual computation
* \param init The init block
* \param iter_vars The block variables
* \return The actual computation
*/
Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars);

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_
42 changes: 42 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "./primitive.h"

#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
Expand All @@ -55,6 +56,7 @@
#include <vector>

#include "./analysis.h"
#include "./transform.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -598,6 +600,46 @@ static DefaultReducer default_reducers[4] = {

} // namespace default_reducer

/**************** String ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}

inline bool StartsWith(const String& str, const String& prefix) {
int n = prefix.size();
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix.data());
}

inline bool StartsWith(const String& str, const char* prefix) {
int n = strlen(prefix);
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix);
}

/**************** Loop extents ****************/

inline int64_t GetLoopIntExtent(const ForNode* loop) {
const auto* int_extent = loop->extent.as<IntImmNode>();
return int_extent ? int_extent->value : -1;
Expand Down
Loading

0 comments on commit 66b989f

Please sign in to comment.