Skip to content

Commit

Permalink
Merge pull request #1 from Superjomn/fea/add-ir-description
Browse files Browse the repository at this point in the history
add IrNodeTy ostream support
  • Loading branch information
Superjomn authored Jan 13, 2020
2 parents 4df5863 + bb41534 commit a97a458
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 13 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ include(cmake/external/gtest.cmake)
include(cmake/external/protobuf.cmake)
include(cmake/external/mklml.cmake)

find_package(Threads REQUIRED)
add_subdirectory(cinn)
1 change: 1 addition & 0 deletions cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest)

add_subdirectory(utils)
add_subdirectory(ir)
add_subdirectory(dsl)
add_subdirectory(optim)
4 changes: 3 additions & 1 deletion cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
cc_library(ir SRCS
type.cc node.cc ir.cc
type.cc node.cc ir.cc node.cc
ir_visitor.cc
)

cc_test(test_ir SRCS ir_test.cc DEPS ir)
2 changes: 1 addition & 1 deletion cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ struct Block : public ExprNode<Block> {
struct Builder {
template <typename IRType, typename... Args>
Expr make(Args... args) {
return std::shared_ptr<IRType>(args...);
return Expr(std::make_shared<IRType>(args...));
}
};

Expand Down
22 changes: 22 additions & 0 deletions cinn/ir/ir_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "cinn/ir/ir.h"
#include <gtest/gtest.h>

#include "cinn/utils/string.h"

namespace cinn {
namespace ir {

TEST(ir, Add) {
Builder builder;

auto one = builder.make<IntImm>(Int(32), 1);
auto two = builder.make<IntImm>(Int(32), 2);

auto add = builder.make<Add>(one, two);

auto cnt = utils::GetStreamCnt(add.node_type());
ASSERT_EQ(cnt, "<node: Add>");
}

} // namespace ir
} // namespace cinn
28 changes: 26 additions & 2 deletions cinn/ir/node.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
#include "cinn/ir/node.h"
#include "cinn/ir/ir.h"

namespace cinn {
namespace ir {

template <>
void ExprNode<IntImm>::Accept(cinn::ir::IRVisitor *v) const {}
//! Implementations for Ir nodes.
// @{
#define __m(t__) \
template <> \
void ExprNode<t__>::Accept(cinn::ir::IRVisitor *v) const {}
NODETY_FORALL(__m)
#undef __m
// @}

std::ostream &operator<<(std::ostream &os, IrNodeTy type) {
switch (type) {
#define __m(t__) \
case IrNodeTy::t__: \
os << "<node: " << #t__ << ">"; \
break;

NODETY_FORALL(__m)
#undef __m

default:
LOG(FATAL) << "unknown IrNodeTy found";
}

return os;
}

} // namespace ir
} // namespace cinn
17 changes: 9 additions & 8 deletions cinn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ class IRVisitor;

#define NODETY_FORALL(macro__) \
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
NODETY_OP_FOR_EACH(__m) \
NODETY_CONTROL_OP_FOR_EACH(__m)

NODETY_OP_FOR_EACH(__m) \
NODETY_CONTROL_OP_FOR_EACH(__m)
// clang-format on

//! Define IrNodeTy
Expand All @@ -56,12 +55,14 @@ enum class IrNodeTy { NODETY_FORALL(__m) };
#undef __m
// @}

std::ostream& operator<<(std::ostream& os, IrNodeTy type);

class IRNode : public std::enable_shared_from_this<IRNode> {
public:
IRNode() = default;
virtual ~IRNode() = default;

virtual void Accept(IRVisitor* v) const = 0;
virtual void Accept(IRVisitor* v) const {}
virtual IrNodeTy node_type() = 0;
virtual const Type& type() = 0;

Expand All @@ -88,7 +89,7 @@ class IRHandle : public std::enable_shared_from_this<IRHandle> {
}
template <typename T>
T* As() {
if (node_type() == T::_type_info_) return static_cast<T*>(ptr_.get());
if (node_type() == T::_node_type_) return static_cast<T*>(ptr_.get());
return nullptr;
}

Expand All @@ -110,11 +111,9 @@ struct ExprNode : public IRNode {
T* self() { return static_cast<T*>(this); }
const T* const_self() const { return static_cast<const T*>(this); }

IrNodeTy node_type() { return _node_type_; }
IrNodeTy node_type() { return T::_node_type_; }
const Type& type() { return type_; }

static IrNodeTy _node_type_;

private:
Type type_;
};
Expand Down Expand Up @@ -158,6 +157,8 @@ struct Expr : public IRHandle {
public:
Expr() = default;
Expr(const Expr& other) : IRHandle(other.ptr()) {}
Expr(const std::shared_ptr<IRNode>& p) : IRHandle(p) {}
void operator=(const std::shared_ptr<IRNode>& p) { ptr_ = p; }

//! Helper function to construct numeric constants of various types.
// @{
Expand Down
1 change: 1 addition & 0 deletions cinn/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cc_library(utils SRCS string.cc)
5 changes: 5 additions & 0 deletions cinn/utils/string.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "cinn/utils/string.h"

namespace cinn {
namespace utils {} // namespace utils
} // namespace cinn
17 changes: 17 additions & 0 deletions cinn/utils/string.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once
#include <sstream>
#include <string>

namespace cinn {
namespace utils {

//! Get the content of a stream.
template <typename T>
std::string GetStreamCnt(const T& x) {
std::stringstream os;
os << x;
return os.str();
}

} // namespace utils
} // namespace cinn
2 changes: 1 addition & 1 deletion cmake/core.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function(cc_library TARGET_NAME)
(NOT ("${TARGET_NAME}" STREQUAL "cinn_lib"))
)
message(STATUS "xxxxx target:${TARGET_NAME}")
target_link_libraries(${TARGET_NAME} ${isl_lib})
target_link_libraries(${TARGET_NAME} ${isl_lib} Threads::Threads)

endif (
(NOT ("${TARGET_NAME}" STREQUAL "cinn_gtest_main")) AND
Expand Down

0 comments on commit a97a458

Please sign in to comment.