Skip to content

Commit

Permalink
[TIR][REFACTOR] std::string -> String Migration in data_layout.h
Browse files Browse the repository at this point in the history
  • Loading branch information
cchung100m committed May 14, 2020
1 parent 3af9ab8 commit d4c5e64
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
22 changes: 12 additions & 10 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class LayoutAxis {
static const LayoutAxis& Get(const tir::IterVar& itvar);

// Get the singleton LayoutAxis using name[0] (size of name must be 1).
static const LayoutAxis& make(const std::string& name);
static const LayoutAxis& make(const String& name);

inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
inline std::string name() const { return std::string(1, name_); }
inline String name() const { return String(std::string(1, name_)); }

// if current axis is primal, switch the axis to its subordinate one,
// else switch to the primal.
Expand Down Expand Up @@ -88,7 +88,7 @@ class Layout;
class LayoutNode : public Object {
public:
/*! \brief string representation of layout, "" for scalar. */
std::string name;
String name;
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
Expand All @@ -102,7 +102,7 @@ class LayoutNode : public Object {
v->Visit("axes", &axes);
}

TVM_DLL static Layout make(const std::string& layout);
TVM_DLL static Layout make(const String& layout);

static constexpr const char* _type_key = "Layout";
TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
Expand All @@ -128,7 +128,8 @@ class Layout : public ObjectRef {
explicit Layout(const Array<tir::IterVar>& axes);

/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
Layout(const char* name) : Layout(String(name)) {} // NOLINT(*)
Layout(const std::string& name) : Layout(String(name)) {} // NOLINT(*)

/*!
* \brief construct from a string.
Expand All @@ -138,7 +139,7 @@ class Layout : public ObjectRef {
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& name); // NOLINT(*)
Layout(const String& name); // NOLINT(*)

/*!
* \brief access the internal node container
Expand Down Expand Up @@ -206,16 +207,17 @@ class Layout : public ObjectRef {
inline Layout ExpandPrimal(const Layout& dst_layout) {
Layout new_src_layout;
// 1) Find the axis which are missing in the current layout. Make them the prefix.
std::string new_src_layout_str = "";
String new_src_layout_str = "";
for (auto dst_axis : dst_layout->axes) {
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
new_src_layout_str += dst_axis->var->name_hint;
new_src_layout_str.operator std::string() +=
dst_axis->var->name_hint.operator std::string();
}
}
}
// 2) Now, add the primal axis of the current layout.
new_src_layout_str += this->name();
new_src_layout_str.operator std::string() += this->name().operator std::string();
new_src_layout = Layout(new_src_layout_str);
return new_src_layout;
}
Expand Down Expand Up @@ -269,7 +271,7 @@ class Layout : public ObjectRef {
}

/*! \return the string description of the layout */
inline std::string name() const {
inline String name() const {
if (!defined()) return "__undef__";
return operator->()->name;
}
Expand Down
26 changes: 13 additions & 13 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ const LayoutAxis& LayoutAxis::Get(const char name) {
}

const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
const std::string axis = itvar->var.get()->name_hint;
const String axis = itvar->var.get()->name_hint;
CHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis;
return LayoutAxis::Get(axis[0]);
return LayoutAxis::Get(axis.operator std::string()[0]);
}

const LayoutAxis& LayoutAxis::make(const std::string& name) {
const LayoutAxis& LayoutAxis::make(const String& name) {
CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
return LayoutAxis::Get(name[0]);
return LayoutAxis::Get(name.operator std::string()[0]);
}

Layout::Layout(const Array<IterVar>& axes) {
Expand All @@ -90,7 +90,7 @@ Layout::Layout(const Array<IterVar>& axes) {
data_ = std::move(node);
}

Layout::Layout(const std::string& name) { // NOLINT(*)
Layout::Layout(const String& name) { // NOLINT(*)
if (name == "__undef__") return;

auto node = make_object<LayoutNode>();
Expand All @@ -100,12 +100,12 @@ Layout::Layout(const std::string& name) { // NOLINT(*)

// parse layout string
int32_t factor = 0;
for (char c : name) {
for (char c : name.operator std::string()) {
if (c >= 'A' && c <= 'Z') {
CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
String shape_name("_shape");
shape_name.operator std::string().insert(0, 1, c);
IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)),
tir::kDataPar);
node->axes.push_back(axis);
Expand Down Expand Up @@ -144,7 +144,7 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
data_ = std::move(node);
}

Layout LayoutNode::make(const std::string& layout) { return Layout(layout); }
Layout LayoutNode::make(const String& layout) { return Layout(layout); }

Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef();
Expand All @@ -160,7 +160,7 @@ Layout Layout::SubLayout(size_t pos, size_t len) const {

Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const {
if (!defined()) return Layout::Undef();
const std::string& name = operator->()->name;
const String& name = operator->()->name;
const auto axes = operator->()->axes;
CHECK(target_pos <= this->ndim())
<< "Invalid split position " << target_pos << " for layout " << name;
Expand Down Expand Up @@ -367,20 +367,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make);

TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, String axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});

TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
.set_body_typed([](Layout layout, String axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});

TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int {
return layout.ndim();
});

TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string {
TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> String {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
Expand Down

0 comments on commit d4c5e64

Please sign in to comment.