Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Buffer flatten #340

Merged
1 change: 1 addition & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ class BufferRegion : public ObjectRef {
TVM_DLL explicit BufferRegion(Buffer buffer);
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
};

/*!
Expand Down Expand Up @@ -1058,8 +1059,8 @@ class Block : public Stmt {
TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
Array<BufferRegion> writes, Array<Buffer> alloc_buffers,
Map<String, ObjectRef> annotations, Array<MatchBufferRegion> match_buffers,
String exec_scope, String name_hint,
Stmt body, Optional<Stmt> init, Span = Span());
String exec_scope, String name_hint, Stmt body, Optional<Stmt> init,
Span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
Expand Down
19 changes: 2 additions & 17 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,23 +337,8 @@ inline std::ostream& StdCout(int verbose, int setting = 1) {

/**************** String Manipulation ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}
using tir::FindCharPos;
using tir::StartsWith;

/**************** Target Hardware Concurrency ****************/

Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const V
return found;
}

std::unordered_set<const VarNode*> Vars(const ObjectRef& stmt_or_expr) {
std::unordered_set<const VarNode*> result;
auto f_visit = [&result](const ObjectRef& obj) -> void {
if (const auto* var = obj.as<VarNode>()) {
result.insert(var);
}
};
PostOrderVisit(stmt_or_expr, f_visit);
return result;
}

bool ValidateBlockBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges) {
arith::Analyzer analyzer;
Array<arith::IterSumExpr> results = arith::DetectIterMap(
Expand Down
6 changes: 6 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var);
* \return A boolean indicating if any var in the list is found in stmt/expr
*/
bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const VarNode*>& var);
/*!
* \brief Collect the variables that appear in the specific Stmt or Expr
* \param stmt_or_expr The Stmt or Expr
* \return All variables that appear
*/
std::unordered_set<const VarNode*> Vars(const ObjectRef& stmt_or_expr);

/******** Verification ********/
/*!
Expand Down
45 changes: 45 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace tir {

Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars) {
std::vector<PrimExpr> conditions;
for (const IterVar& var : iter_vars) {
if (var->iter_type == IterVarType::kCommReduce) {
conditions.push_back(equal(var->var, var->dom->min));
}
}
int n = conditions.size();
// Handle the case where there is no condition
if (n == 0) {
return init;
}
// Concate the conditions with logical and (&&)
PrimExpr cond = conditions[0];
for (int i = 1; i < n; ++i) {
cond = logical_and(cond, conditions[i]);
}
return IfThenElse(cond, init);
}

} // namespace tir
} // namespace tvm
40 changes: 40 additions & 0 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_

#include <tvm/tir/stmt.h>

#include <unordered_set>

namespace tvm {
namespace tir {

/*!
* \brief Transform the init block into actual computation
* \param init The init block
* \param iter_vars The block variables
* \return The actual computation
*/
Stmt RealizeInitBlock(const Stmt& init, const Array<IterVar>& iter_vars);

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_
42 changes: 42 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_COMMON_H_

#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
Expand All @@ -35,6 +36,7 @@
#include <vector>

#include "./analysis.h"
#include "./transform.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -445,6 +447,46 @@ static DefaultReducer default_reducers[4] = {

} // namespace default_reducer

/**************** String ****************/

/*!
* \brief Find all positions that the specific char occurs in the string
* \param str The string to be examined
* \param c The specific char
* \return A list of integers indicating the occurrence position
*/
inline std::vector<int> FindCharPos(const String& str, char c) {
std::vector<int> result;
const char* data = str.data();
int n = str.length();
for (int i = 0; i < n; ++i) {
if (data[i] == c) {
result.push_back(i);
}
}
return result;
}

inline bool StartsWith(const String& str, const String& prefix) {
int n = prefix.size();
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix.data());
}

inline bool StartsWith(const String& str, const char* prefix) {
int n = strlen(prefix);
if (static_cast<int>(str.size()) < n) {
return false;
}
const char* data = str.data();
return std::equal(data, data + n, prefix);
}

/**************** Loop extents ****************/

inline int64_t GetLoopIntExtent(const ForNode* loop) {
const auto* int_extent = loop->extent.as<IntImmNode>();
return int_extent ? int_extent->value : -1;
Expand Down
Loading