diff --git a/CMakeLists.txt b/CMakeLists.txt index 3921fa12835e3..80d1e5638b54b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,4 +17,5 @@ include(cmake/external/gtest.cmake) include(cmake/external/protobuf.cmake) include(cmake/external/mklml.cmake) +find_package(Threads REQUIRED) add_subdirectory(cinn) diff --git a/cinn/CMakeLists.txt b/cinn/CMakeLists.txt index 0123b99c73744..e90962fce60f9 100644 --- a/cinn/CMakeLists.txt +++ b/cinn/CMakeLists.txt @@ -1,5 +1,6 @@ cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest) +add_subdirectory(utils) add_subdirectory(ir) add_subdirectory(dsl) add_subdirectory(optim) diff --git a/cinn/ir/CMakeLists.txt b/cinn/ir/CMakeLists.txt index 7ea43510c17e2..ec7fdec2c99d3 100644 --- a/cinn/ir/CMakeLists.txt +++ b/cinn/ir/CMakeLists.txt @@ -1,4 +1,6 @@ cc_library(ir SRCS - type.cc node.cc ir.cc + type.cc node.cc ir.cc node.cc ir_visitor.cc ) + +cc_test(test_ir SRCS ir_test.cc DEPS ir) diff --git a/cinn/ir/ir.h b/cinn/ir/ir.h index cac0b6ca045ce..5b218d2cf7eea 100644 --- a/cinn/ir/ir.h +++ b/cinn/ir/ir.h @@ -200,7 +200,7 @@ struct Block : public ExprNode { struct Builder { template Expr make(Args... args) { - return std::shared_ptr(args...); + return Expr(std::make_shared(args...)); } }; diff --git a/cinn/ir/ir_test.cc b/cinn/ir/ir_test.cc new file mode 100644 index 0000000000000..b98f614c6e72d --- /dev/null +++ b/cinn/ir/ir_test.cc @@ -0,0 +1,22 @@ +#include "cinn/ir/ir.h" +#include + +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +TEST(ir, Add) { + Builder builder; + + auto one = builder.make(Int(32), 1); + auto two = builder.make(Int(32), 2); + + auto add = builder.make(one, two); + + auto cnt = utils::GetStreamCnt(add.node_type()); + ASSERT_EQ(cnt, ""); +} + +} // namespace ir +} // namespace cinn diff --git a/cinn/ir/node.cc b/cinn/ir/node.cc index 665c9e3d79510..adee28a05eb6d 100644 --- a/cinn/ir/node.cc +++ b/cinn/ir/node.cc @@ -1,10 +1,34 @@ #include "cinn/ir/node.h" +#include "cinn/ir/ir.h" namespace cinn { namespace ir { -template <> -void ExprNode::Accept(cinn::ir::IRVisitor *v) const {} +//! Implementations for Ir nodes. +// @{ +#define __m(t__) \ + template <> \ + void ExprNode::Accept(cinn::ir::IRVisitor *v) const {} +NODETY_FORALL(__m) +#undef __m +// @} + +std::ostream &operator<<(std::ostream &os, IrNodeTy type) { + switch (type) { +#define __m(t__) \ + case IrNodeTy::t__: \ + os << ""; \ + break; + + NODETY_FORALL(__m) +#undef __m + + default: + LOG(FATAL) << "unknown IrNodeTy found"; + } + + return os; +} } // namespace ir } // namespace cinn diff --git a/cinn/ir/node.h b/cinn/ir/node.h index 4b9e5b6c8471f..dbeea4f5b903e 100644 --- a/cinn/ir/node.h +++ b/cinn/ir/node.h @@ -44,9 +44,8 @@ class IRVisitor; #define NODETY_FORALL(macro__) \ NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ -NODETY_OP_FOR_EACH(__m) \ -NODETY_CONTROL_OP_FOR_EACH(__m) - + NODETY_OP_FOR_EACH(__m) \ + NODETY_CONTROL_OP_FOR_EACH(__m) // clang-format on //! Define IrNodeTy @@ -56,12 +55,14 @@ enum class IrNodeTy { NODETY_FORALL(__m) }; #undef __m // @} +std::ostream& operator<<(std::ostream& os, IrNodeTy type); + class IRNode : public std::enable_shared_from_this { public: IRNode() = default; virtual ~IRNode() = default; - virtual void Accept(IRVisitor* v) const = 0; + virtual void Accept(IRVisitor* v) const {} virtual IrNodeTy node_type() = 0; virtual const Type& type() = 0; @@ -88,7 +89,7 @@ class IRHandle : public std::enable_shared_from_this { } template T* As() { - if (node_type() == T::_type_info_) return static_cast(ptr_.get()); + if (node_type() == T::_node_type_) return static_cast(ptr_.get()); return nullptr; } @@ -110,11 +111,9 @@ struct ExprNode : public IRNode { T* self() { return static_cast(this); } const T* const_self() const { return static_cast(this); } - IrNodeTy node_type() { return _node_type_; } + IrNodeTy node_type() { return T::_node_type_; } const Type& type() { return type_; } - static IrNodeTy _node_type_; - private: Type type_; }; @@ -158,6 +157,8 @@ struct Expr : public IRHandle { public: Expr() = default; Expr(const Expr& other) : IRHandle(other.ptr()) {} + Expr(const std::shared_ptr& p) : IRHandle(p) {} + void operator=(const std::shared_ptr& p) { ptr_ = p; } //! Helper function to construct numeric constants of various types. // @{ diff --git a/cinn/utils/CMakeLists.txt b/cinn/utils/CMakeLists.txt new file mode 100644 index 0000000000000..49a0cb46668ef --- /dev/null +++ b/cinn/utils/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(utils SRCS string.cc) diff --git a/cinn/utils/string.cc b/cinn/utils/string.cc new file mode 100644 index 0000000000000..c6a02896da1b5 --- /dev/null +++ b/cinn/utils/string.cc @@ -0,0 +1,5 @@ +#include "cinn/utils/string.h" + +namespace cinn { +namespace utils {} // namespace utils +} // namespace cinn diff --git a/cinn/utils/string.h b/cinn/utils/string.h new file mode 100644 index 0000000000000..35eb0eee0b0bd --- /dev/null +++ b/cinn/utils/string.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include + +namespace cinn { +namespace utils { + +//! Get the content of a stream. +template +std::string GetStreamCnt(const T& x) { + std::stringstream os; + os << x; + return os.str(); +} + +} // namespace utils +} // namespace cinn diff --git a/cmake/core.cmake b/cmake/core.cmake index 0830cd1888def..31157d28ca3ac 100644 --- a/cmake/core.cmake +++ b/cmake/core.cmake @@ -39,7 +39,7 @@ function(cc_library TARGET_NAME) (NOT ("${TARGET_NAME}" STREQUAL "cinn_lib")) ) message(STATUS "xxxxx target:${TARGET_NAME}") - target_link_libraries(${TARGET_NAME} ${isl_lib}) + target_link_libraries(${TARGET_NAME} ${isl_lib} Threads::Threads) endif ( (NOT ("${TARGET_NAME}" STREQUAL "cinn_gtest_main")) AND