From 1b8ddc9ad9742d010d692e00e0ca2c4fd1d4802d Mon Sep 17 00:00:00 2001
From: Shibo Tao <62922815+T8T9@users.noreply.github.com>
Date: Mon, 26 Apr 2021 15:07:00 +0800
Subject: [PATCH] CINNRT support load params from file and add predictor
 (#371)

---
 build.sh                                  |  16 ++
 cinn/frontend/paddle/model_parser.h       |   2 +
 cinnrt/CMakeLists.txt                     |   1 +
 cinnrt/api/CMakeLists.txt                 |   7 +
 cinnrt/api/cinnrt_api.cc                  | 210 ++++++++++++++++++++++
 cinnrt/api/cinnrt_api.h                   |  46 +++++
 cinnrt/api/cinnrt_api_test.cc             |  58 ++++++
 cinnrt/common/global.h                    |   1 +
 cinnrt/dialect/basic_kernels.cc           |   2 +
 cinnrt/dialect/basic_kernels.td           |  24 +++
 cinnrt/dialect/cinn_base.cc               |  21 +++
 cinnrt/dialect/cinn_base.td               |   8 +-
 cinnrt/dialect/dense_tensor.cc            |   8 +
 cinnrt/dialect/dense_tensor.h             |  14 ++
 cinnrt/dialect/dense_tensor.td            |  32 ++++
 cinnrt/dialect/mlir_tests/basic.mlir      |   8 +
 cinnrt/dialect/mlir_tests/tensor_map.mlir |  31 ++++
 cinnrt/host_context/value.cc              |   2 +
 cinnrt/host_context/value.h               |   3 +
 cinnrt/kernel/basic_kernels.cc            |  13 ++
 cinnrt/kernel/tensor_kernels.cc           |  19 +-
 cinnrt/tensor/CMakeLists.txt              |  10 ++
 cinnrt/tensor/dense_host_tensor.cc        |   9 +
 cinnrt/tensor/dense_host_tensor.h         |   5 +
 cinnrt/tensor/tensor_map.cc               |  79 ++++++++
 cinnrt/tensor/tensor_map.h                |  12 ++
 26 files changed, 636 insertions(+), 5 deletions(-)
 create mode 100644 cinnrt/api/CMakeLists.txt
 create mode 100644 cinnrt/api/cinnrt_api.cc
 create mode 100644 cinnrt/api/cinnrt_api.h
 create mode 100644 cinnrt/api/cinnrt_api_test.cc
 create mode 100644 cinnrt/dialect/mlir_tests/tensor_map.mlir
 create mode 100644 cinnrt/tensor/tensor_map.cc
 create mode 100644 cinnrt/tensor/tensor_map.h

diff --git a/build.sh b/build.sh
index 04788a733c2f0..9292d2c95e72e 100755
--- a/build.sh
+++ b/build.sh
@@ -94,6 +94,22 @@ function prepare_model {
     wget http://paddle-inference-dist.bj.bcebos.com/CINN/EfficientNet.tar
     tar -xvf EfficientNet.tar
   fi
+  mkdir -p $build_dir/paddle
+  cd $build_dir/paddle
+  if [[ ! -f "libexternal_kernels.so.tgz" ]]; then
+    wget https://github.com/T8T9/files/raw/main/libexternal_kernels.so.tgz
+  fi
+  tar -zxvf libexternal_kernels.so.tgz
+  if [[ ! -f "paddle_1.8_fc_model.tgz" ]]; then
+    wget https://github.com/T8T9/files/raw/main/paddle_1.8_fc_model.tgz
+  fi
+  tar -zxvf paddle_1.8_fc_model.tgz
+  if [[ ! -f "mkldnn.tgz" ]]; then
+    wget https://github.com/T8T9/files/raw/main/mkldnn.tgz
+  fi
+  tar -zxvf mkldnn.tgz
+  export LD_LIBRARY_PATH=$build_dir/paddle/mkldnn:$build_dir/thirds/install/mklml/lib:$LD_LIBRARY_PATH
+  cd -
   python3 $workspace/python/tests/fake_model/naive_mul.py
   python3 $workspace/python/tests/fake_model/naive_multi_fc.py
   python3 $workspace/python/tests/fake_model/resnet_model.py
diff --git a/cinn/frontend/paddle/model_parser.h b/cinn/frontend/paddle/model_parser.h
index 76ffb099df17f..185c78961b3e7 100644
--- a/cinn/frontend/paddle/model_parser.h
+++ b/cinn/frontend/paddle/model_parser.h
@@ -28,6 +28,8 @@ void LoadModelPb(const std::string& model_dir,
 // Read a __model__ file.
 std::unique_ptr<paddle::framework::proto::ProgramDesc> LoadProgram(const std::string& path, bool program_from_memory = false);
 
+void LoadLoDTensor(std::istream& is, hlir::framework::Variable* var, const common::Target& target);
+
 // Read a single file containing all the parameters.
 void LoadParams(const std::string& path);
diff --git a/cinnrt/CMakeLists.txt b/cinnrt/CMakeLists.txt
index 8cc2aed7f2086..b97539c76bb77 100644
--- a/cinnrt/CMakeLists.txt
+++ b/cinnrt/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(api)
 add_subdirectory(common)
 add_subdirectory(dialect)
 add_subdirectory(host_context)
diff --git a/cinnrt/api/CMakeLists.txt b/cinnrt/api/CMakeLists.txt
new file mode 100644
index 0000000000000..a089d99aa48b3
--- /dev/null
+++ b/cinnrt/api/CMakeLists.txt
@@ -0,0 +1,7 @@
+core_gather_headers()
+
+core_gather_srcs(SRCS
+    cinnrt_api.cc
+    )
+
+cc_test(test_cinnrt_predictor SRCS cinnrt_api_test.cc DEPS cinncore ${MLIR_IR_LIBS})
diff --git a/cinnrt/api/cinnrt_api.cc b/cinnrt/api/cinnrt_api.cc
new file mode 100644
index 0000000000000..124ee639b5670
--- /dev/null
+++ b/cinnrt/api/cinnrt_api.cc
@@ -0,0 +1,210 @@
+#include "cinnrt/api/cinnrt_api.h"
+
+#include <iostream>
+
+#include "cinnrt/common/global.h"
+#include "cinnrt/dialect/dense_tensor.h"
+#include "cinnrt/dialect/mlir_loader.h"
+#include "cinnrt/host_context/core_runtime.h"
+#include "cinnrt/host_context/kernel_registry.h"
+#include "cinnrt/host_context/mlir_function_executable.h"
+#include "cinnrt/host_context/mlir_to_runtime_translate.h"
+#include "cinnrt/host_context/op_executable.h"
+#include "cinnrt/host_context/value.h"
+#include "cinnrt/kernel/basic_kernels.h"
+#include "cinnrt/kernel/control_flow_kernels.h"
+#include "cinnrt/kernel/tensor_kernels.h"
+#include "cinnrt/kernel/tensor_shape_kernels.h"
+#include "cinnrt/kernel/test_kernels.h"
+#include "cinnrt/tensor/tensor_map.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Parser.h"
+
+using namespace cinnrt::host_context;
+using namespace cinnrt::tensor;
+using namespace cinnrt::tensor;
+using cinnrt::dt::TensorMapType;
+using cinnrt::dt::TensorType;
+
+namespace cinnrt {
+
+template <typename T>
+std::string DumpToString(T& op) {  // NOLINT
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  op.print(os);
+  os.flush();
+  return buffer;
+}
+
+struct MlirToRuntimeTranslator::Impl {
+  mlir::ModuleOp module;
+  // The runtime for a function call.
+  CoreRuntimeBuilder* runtime{};
+  // The current working op; the translator processes the ops one by one, and
+  // each time it updates `cur_op` here to the op currently being translated.
+  OpExecutableBuilder* cur_op{};
+
+  // Record the current function name.
+  std::string cur_func_name;
+
+  // Name to function definitions.
+  std::unordered_map<std::string, mlir::FuncOp> func_defs;
+
+  // Map from an operation to its results.
+  std::unordered_map<mlir::Operation*, std::vector<ValueRef>> op_results;
+  llvm::DenseMap<mlir::Value, ValueRef> value_map;
+};
+
+/**
+ * Execute the mlir program in predict mode.
+ */
+class PredictExecutor : public MlirToRuntimeTranslator {
+ public:
+  CoreRuntimeBuilder core_runtime;
+
+  PredictExecutor(mlir::ModuleOp module, KernelRegistry* registry, TensorMap* map)
+      : core_runtime(registry), MlirToRuntimeTranslator(module, &core_runtime), registry_(registry) {
+    CHECK(registry_);
+    Init(map);
+  }
+
+  void Run() {
+    auto arguments = llvm::makeArrayRef(arguments_);
+    auto results = llvm::makeMutableArrayRef(results_.begin(), results_.size());
+    function_executable_->Execute(arguments, results);
+  }
+
+  int GetInputNum() { return inputs_.size(); }
+
+  DenseHostTensor* GetInput(int i) { return inputs_[i]; }
+
+  int GetOutputNum() { return outputs_.size(); }
+
+  DenseHostTensor* GetOutput(int i) { return outputs_[i]; }
+
+ private:
+  void Init(TensorMap* map) {
+    EmitFunctions();
+    llvm::Optional<mlir::FuncOp> predict_func_ = llvm::None;
+    for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
+      if (func_op.getName().str() != "predict") continue;
+      predict_func_ = func_op;
+      break;
+    }
+    if (!predict_func_) {
+      std::cout << "ERROR: init failed, no predict function found in mlir." << std::endl;
+      return;
+    }
+    auto& predict_func = predict_func_.getValue();
+    function_executable_ = new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs);
+
+    // process parameters
+    for (int i = 0; i < predict_func.getNumArguments(); ++i) {
+      auto arg = predict_func.getArgument(i);
+      auto type = arg.getType();
+      // this param is the TensorMap
+      if (type.isa<TensorMapType>()) {
+        auto* value = new host_context::Value(std::move(*map));
+        arguments_.push_back(value);
+        AddValue(predict_func.getArgument(i), value);
+      } else {
+        // this param is an input Tensor
+        auto dht = DenseHostTensor();
+        auto* value = new host_context::Value(std::move(dht));
+        arguments_.push_back(value);
+        inputs_.push_back(&(value->get<DenseHostTensor>()));
+      }
+    }
+
+    // process results
+    auto& last_op = predict_func.front().back();
+    if (last_op.getName().getStringRef() == "cinn.return") {
+      for (int i = 0; i < last_op.getNumOperands(); ++i) {
+        auto* value = AddValue(mlir::Value(last_op.getOperand(i)));
+        results_.push_back(ValueRef(value));
+        outputs_.push_back(&(value->get<DenseHostTensor>()));
+      }
+    }
+  }
+
+ protected:
+  std::unordered_map<std::string, mlir::FuncOp> func_def_table;
+
+  void EmitFunction(mlir::FuncOp op) override {
+    auto it = impl_->func_defs.try_emplace(op.getName().str(), op);
+    CHECK(it.second) << "Duplicate function definition found for function [" << op.getName().str() << "]";
+  }
+
+ private:
+  KernelRegistry* registry_{};
+  MlirFunctionExecutable* function_executable_;
+  llvm::SmallVector<DenseHostTensor*, 1> inputs_;
+  llvm::SmallVector<host_context::Value*, 2> arguments_;
+  llvm::SmallVector<DenseHostTensor*, 1> outputs_;
+  llvm::SmallVector<ValueRef, 1> results_;
+};
+
+std::shared_ptr<CinnRtPredictor> CreateCinnRtPredictor(const CinnRtConfig& config) {
+  auto x = std::make_shared<CinnRtPredictor>();
+  x->Init(config);
+  return x;
+}
+
+struct CinnRtPredictor::Impl {
+  mlir::OwningModuleRef module_ref;
+  PredictExecutor* executor;
+};
+
+CinnRtPredictor::CinnRtPredictor() : impl_(new Impl) {}
+CinnRtPredictor::~CinnRtPredictor() {}
+
+void CinnRtPredictor::Run() { impl_->executor->Run(); }
+
+int CinnRtPredictor::Init(const CinnRtConfig& config) {
+  mlir::MLIRContext* context = cinnrt::Global::getMLIRContext();
+  auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context);
+
+  KernelRegistry* registry = new KernelRegistry();
+
+  kernel::RegisterBasicKernels(registry);
+  kernel::RegisterTestKernels(registry);
+  kernel::RegisterTensorShapeKernels(registry);
+  kernel::RegisterTensorKernels(registry);
+  kernel::RegisterControlFlowKernels(registry);
+
+  impl_->module_ref = std::move(module_ref);
+
+  // load extra shared library
+  for (const std::string& lib_path : config.shared_libs()) {
+    std::string err;
+    llvm::sys::DynamicLibrary dynLib = llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err);
+    if (!dynLib.isValid()) {
+      llvm::errs() << "Load shared library failed. Error: " << err << "\n";
+      return 1;
+    }
+    if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) {
+      auto reg_func = reinterpret_cast<void (*)(KernelRegistry*)>(reg_sym);
+      reg_func(registry);
+    } else {
+      llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path << "\". Skip.\n";
+    }
+  }
+  // Load params
+  TensorMap* map = LoadParams(config.model_dir());
+  // Create PredictExecutor
+  impl_->executor = new PredictExecutor(impl_->module_ref.get(), registry, map);
+  return 0;
+}
+
+int CinnRtPredictor::GetInputNum() { return impl_->executor->GetInputNum(); }
+
+DenseHostTensor* CinnRtPredictor::GetInput(int i) { return impl_->executor->GetInput(i); }
+
+int CinnRtPredictor::GetOutputNum() { return impl_->executor->GetOutputNum(); }
+
+DenseHostTensor* CinnRtPredictor::GetOutput(int i) { return impl_->executor->GetOutput(i); }
+
+}  // namespace cinnrt
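
The loop in CinnRtPredictor::Init above resolves a single well-known symbol, RegisterKernels, in every library passed through CinnRtConfig::set_shared_libs and calls it with the predictor's KernelRegistry. An external kernel library such as the test's libexternal_kernels.so therefore only needs an unmangled entry point. A minimal sketch (the kernel name and body are illustrative, not the shipped external.* kernels):

    // Minimal external kernel library, loadable via config.set_shared_libs().
    // CinnRtPredictor::Init resolves the unmangled "RegisterKernels" symbol
    // and calls it with its KernelRegistry, so extern "C" is required.
    #include "cinnrt/host_context/kernel_registry.h"
    #include "cinnrt/host_context/kernel_utils.h"
    #include "cinnrt/tensor/dense_host_tensor.h"

    static void MyNoop(cinnrt::tensor::DenseHostTensor *tensor) {
      // A real kernel would read or write the tensor here.
      (void)tensor;
    }

    extern "C" void RegisterKernels(cinnrt::host_context::KernelRegistry *registry) {
      // Matches the reinterpret_cast<void (*)(KernelRegistry*)> call site above.
      registry->AddKernel("external.my_noop", CINN_KERNEL(MyNoop));
    }
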
diff --git a/cinnrt/api/cinnrt_api.h b/cinnrt/api/cinnrt_api.h
new file mode 100644
index 0000000000000..ee60e4b3c5a2e
--- /dev/null
+++ b/cinnrt/api/cinnrt_api.h
@@ -0,0 +1,46 @@
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinnrt/tensor/dense_host_tensor.h"
+
+namespace cinnrt {
+
+class CinnRtConfig {
+  std::string model_dir_;
+  std::string mlir_path_;
+  std::vector<std::string> shared_libs_;
+
+ public:
+  CinnRtConfig() = default;
+  void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
+  const std::string& model_dir() const { return model_dir_; }
+
+  void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; }
+  const std::string& mlir_path() const { return mlir_path_; }
+
+  void set_shared_libs(const std::vector<std::string>& shared_libs) { shared_libs_ = shared_libs; }
+  const std::vector<std::string>& shared_libs() const { return shared_libs_; }
+
+  virtual ~CinnRtConfig() = default;
+};
+
+class CinnRtPredictor {
+ public:
+  CinnRtPredictor();
+  ~CinnRtPredictor();
+  void Run();
+  int Init(const CinnRtConfig& config);
+  int GetInputNum();
+  tensor::DenseHostTensor* GetInput(int i);
+  int GetOutputNum();
+  tensor::DenseHostTensor* GetOutput(int i);
+
+ protected:
+  struct Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+std::shared_ptr<CinnRtPredictor> CreateCinnRtPredictor(const CinnRtConfig& config);
+
+}  // namespace cinnrt
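
End to end, the public API is used as in the sketch below, condensed from the cinnrt_api_test.cc that follows (the paths are the artifacts that build.sh prepare_model stages under build/paddle; adjust them to your layout):

    #include <iostream>

    #include "cinnrt/api/cinnrt_api.h"
    #include "cinnrt/common/dtype.h"

    int main() {
      cinnrt::CinnRtConfig config;
      config.set_shared_libs({"./paddle/libexternal_kernels.so"});
      config.set_model_dir("./paddle/paddle_1.8_fc_model");
      config.set_mlir_path("./cinnrt/dialect/mlir_tests/tensor_map.mlir");

      auto predictor = cinnrt::CreateCinnRtPredictor(config);

      // Shape the single input and fill it with ones.
      auto* input = predictor->GetInput(0);
      input->Init({3, 3}, cinnrt::GetDType<float>());
      auto* data = reinterpret_cast<float*>(input->buffer()->data()->memory);
      for (int i = 0; i < input->shape().GetNumElements(); ++i) data[i] = 1.0f;

      predictor->Run();

      auto* output = predictor->GetOutput(0);
      std::cout << "output elements: " << output->shape().GetNumElements() << std::endl;
      return 0;
    }
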
diff --git a/cinnrt/api/cinnrt_api_test.cc b/cinnrt/api/cinnrt_api_test.cc
new file mode 100644
index 0000000000000..8682c103c55d2
--- /dev/null
+++ b/cinnrt/api/cinnrt_api_test.cc
@@ -0,0 +1,58 @@
+#include "cinnrt/api/cinnrt_api.h"
+
+#include <gtest/gtest.h>
+
+#include <iostream>
+#include <vector>
+
+#include "cinn/hlir/framework/buffer.h"
+#include "cinnrt/common/dtype.h"
+#include "llvm/Support/raw_ostream.h"
+
+using cinnrt::CinnRtConfig;
+using cinnrt::CinnRtPredictor;
+using cinnrt::CreateCinnRtPredictor;
+
+namespace cinnrt {
+
+TEST(CinnRtPredictor, predictor) {
+  std::vector<std::string> shared_libs;
+  shared_libs.push_back("../../paddle/libexternal_kernels.so");
+
+  CinnRtConfig config;
+
+  // set external shared libraries that contain kernels.
+  config.set_shared_libs(shared_libs);
+  // set model dir
+  config.set_model_dir("../../paddle/paddle_1.8_fc_model");
+  // set mlir path
+  config.set_mlir_path("../../../cinnrt/dialect/mlir_tests/tensor_map.mlir");
+
+  std::shared_ptr<CinnRtPredictor> predictor = CreateCinnRtPredictor(config);
+
+  // std::cout << "input num: " << predictor->GetInputNum() << std::endl;
+  // std::cout << "output num: " << predictor->GetOutputNum() << std::endl;
+  auto* input = predictor->GetInput(0);
+  std::vector<int64_t> shape = {3, 3};
+  input->Init(shape, cinnrt::GetDType<float>());
+  llvm::outs() << input->shape() << "\n";
+
+  // init input tensor
+  auto* input_data = reinterpret_cast<float*>(input->buffer()->data()->memory);
+  for (int i = 0; i < input->shape().GetNumElements(); i++) input_data[i] = 1.0;
+
+  predictor->Run();
+
+  // get and print output tensor
+  auto* output = predictor->GetOutput(0);
+  auto* output_data = reinterpret_cast<float*>(output->buffer()->data()->memory);
+
+  std::vector<float> ans = {0.428458, 0.244493, 0.572342, 0.572008, 0.509771, 0.495599, 0.651287, 0.326426, 0.404649};
+
+  ASSERT_EQ(output->shape().GetNumElements(), ans.size());
+  for (int i = 0; i < output->shape().GetNumElements(); ++i) {
+    ASSERT_NEAR(output_data[i], ans[i], 0.000001);
+  }
+}
+
+}  // namespace cinnrt
diff --git a/cinnrt/common/global.h b/cinnrt/common/global.h
index e59eb9b20b5b9..8eeb66a22bc5d 100644
--- a/cinnrt/common/global.h
+++ b/cinnrt/common/global.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "cinnrt/tensor/dense_host_tensor.h"
 #include "mlir/IR/MLIRContext.h"
 
 namespace cinnrt {
diff --git a/cinnrt/dialect/basic_kernels.cc b/cinnrt/dialect/basic_kernels.cc
index 3194894496a26..48ed3a99bd110 100644
--- a/cinnrt/dialect/basic_kernels.cc
+++ b/cinnrt/dialect/basic_kernels.cc
@@ -11,6 +11,8 @@
 #include <mlir/IR/StandardTypes.h>
 #include <mlir/IR/TypeUtilities.h>
 
+#include "cinnrt/dialect/dense_tensor.h"
+
 namespace cinnrt::dialect {
 using namespace mlir;
diff --git a/cinnrt/dialect/basic_kernels.td b/cinnrt/dialect/basic_kernels.td
index 03cf8f2aaffe2..3b34fbd7db8e1 100644
--- a/cinnrt/dialect/basic_kernels.td
+++ b/cinnrt/dialect/basic_kernels.td
@@ -112,4 +112,28 @@ class PrintOp<string suffix, Type type> : CINN_Op<"print." # suffix> {
 def PrintF32Op : PrintOp<"f32", F32>;
 //def PrintF64Op : PrintOp<"f64", F64>;
 
+def GetStringOp : CINN_Op<"get_string"> {
+  let summary = "cinn.get_string";
+  let description = [{
+    Get a !cinn.string value from the given string attribute.
+  }];
+
+  let arguments = (ins StrAttr:$value);
+  let results = (outs StringType);
+  let assemblyFormat = "`(` $value `)` attr-dict";
+  let verifier = ?;
+}
+
+def PrintStringOp : CINN_Op<"print_string"> {
+  let summary = "cinn.print_string";
+  let description = [{
+    An operation that prints a string.
+  }];
+
+  let arguments = (ins StringType:$input);
+  let results = (outs);
+  let assemblyFormat = "`(` $input `)` attr-dict";
+  let verifier = ?;
+}
+
 #endif  // basic kernels
diff --git a/cinnrt/dialect/cinn_base.cc b/cinnrt/dialect/cinn_base.cc
index 4c036828d3a1c..6e686dc266f03 100644
--- a/cinnrt/dialect/cinn_base.cc
+++ b/cinnrt/dialect/cinn_base.cc
@@ -11,7 +11,9 @@ void CINNDialect::initialize() {
   allowUnknownTypes();
   allowUnknownOperations();
 
+  addTypes<cinnrt::dt::StringType>();
   addTypes<cinnrt::dt::TensorType>();
+  addTypes<cinnrt::dt::TensorMapType>();
 
   addOperations<
 #define GET_OP_LIST
@@ -67,6 +69,15 @@ mlir::Type CINNDialect::parseType(mlir::DialectAsmParser &parser) const {
     return cinnrt::dt::TensorType::get(*targetType, *layoutType, *precisionType);
   }
 
+  // parse TensorMapType, for example: !cinn.tensor_map
+  if (keyword == "tensor_map") {
+    return cinnrt::dt::TensorMapType::get();
+  }
+  // parse StringType, for example: !cinn.string
+  if (keyword == "string") {
+    return cinnrt::dt::StringType::get();
+  }
+
   parser.emitError(parser.getCurrentLocation(), "unknown cinn type: ") << keyword;
   return mlir::Type();
 }
@@ -78,6 +89,16 @@ void CINNDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &printer) c
     printer << "tensor<" << tensorType.target() << ", " << tensorType.layout() << ", " << tensorType.precision() << ">";
     return;
   }
+  // print TensorMapType, for example: !cinn.tensor_map
+  if (type.isa<cinnrt::dt::TensorMapType>()) {
+    printer << "tensor_map";
+    return;
+  }
+  // print StringType, for example: !cinn.string
+  if (type.isa<cinnrt::dt::StringType>()) {
+    printer << "string";
+    return;
+  }
 
   llvm_unreachable("unknown cinn type.");
 }
diff --git a/cinnrt/dialect/cinn_base.td b/cinnrt/dialect/cinn_base.td
index 38d76f8d53bef..70427fe67c403 100644
--- a/cinnrt/dialect/cinn_base.td
+++ b/cinnrt/dialect/cinn_base.td
@@ -15,11 +15,17 @@ def CINN_Dialect : Dialect {
 }
 
 // Type definitions
-def StringType : OpaqueType<"cinn", "string", "!cinn.string type">;
+def StringType :
+    Type<CPred<"$_self.isa<::cinnrt::dt::StringType>()">, "!cinn.string type">,
+    BuildableType<"$_builder.getType<::cinnrt::dt::StringType>()">;
 
 def TensorType :
     Type<CPred<"$_self.isa<::cinnrt::dt::TensorType>()">, "!cinn.tensor type">;
 
+def TensorMapType :
+    Type<CPred<"$_self.isa<::cinnrt::dt::TensorMapType>()">, "!cinn.tensor_map type">,
+    BuildableType<"$_builder.getType<::cinnrt::dt::TensorMapType>()">;
+
 def BufferType : OpaqueType<"b", "buffer", "buffer">;
 
 #endif  // CINN_BASE
diff --git a/cinnrt/dialect/dense_tensor.cc b/cinnrt/dialect/dense_tensor.cc
index ab2ff56f25648..2f6d6180a38ec 100644
--- a/cinnrt/dialect/dense_tensor.cc
+++ b/cinnrt/dialect/dense_tensor.cc
@@ -91,6 +91,14 @@ raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
   return os;
 }
 
+TensorMapType TensorMapType::get() { return Base::get(::cinnrt::Global::getMLIRContext()); }
+
+TensorMapType TensorMapType::get(mlir::MLIRContext *context) { return Base::get(context); }
+
+StringType StringType::get() { return Base::get(::cinnrt::Global::getMLIRContext()); }
+
+StringType StringType::get(mlir::MLIRContext *context) { return Base::get(context); }
+
 raw_ostream &operator<<(raw_ostream &os, TargetType type) {
   switch (type) {
     case (TargetType::X86):
diff --git a/cinnrt/dialect/dense_tensor.h b/cinnrt/dialect/dense_tensor.h
index dd7f684c26f4f..89672da9357d5 100644
--- a/cinnrt/dialect/dense_tensor.h
+++ b/cinnrt/dialect/dense_tensor.h
@@ -36,6 +36,20 @@ class TensorType : public mlir::Type::TypeBase<TensorType, mlir::Type, detail::TensorTypeStorage> {
   using Base::Base;
 };
 
+class TensorMapType : public mlir::Type::TypeBase<TensorMapType, mlir::Type, mlir::TypeStorage> {
+ public:
+  using Base::Base;
+  static TensorMapType get();
+  static TensorMapType get(mlir::MLIRContext *context);
+};
+
+class StringType : public mlir::Type::TypeBase<StringType, mlir::Type, mlir::TypeStorage> {
+ public:
+  using Base::Base;
+  static StringType get();
+  static StringType get(mlir::MLIRContext *context);
+};
+ #include "cinnrt/dialect/dense_tensor_dialect.hpp.inc" #define GET_OP_CLASSES diff --git a/cinnrt/dialect/dense_tensor.td b/cinnrt/dialect/dense_tensor.td index 20c3952d4faa1..572bdc93adc57 100644 --- a/cinnrt/dialect/dense_tensor.td +++ b/cinnrt/dialect/dense_tensor.td @@ -81,6 +81,38 @@ class SetTensorOp : let printer = [{ return cinnrt::dt::printSetTensorOp(p, *this); }]; } +def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> { + let summary = "dt.load_params operation"; + + let description = [{ + An operation that can load tensors to TensorMap. + }]; + + // input path of model params. + let arguments = (ins StringType:$path); + let results = (outs TensorMapType); + + let assemblyFormat = "`(` operands `)` attr-dict"; + let verifier = ?; +} + +def GetParamOp : DT_Op<"get_param", [NoSideEffect]> { + let summary = "dt.get_param operation"; + + let description = [{ + An operation that can get a tensor from TensorMap. + }]; + + // input path of model params. + let arguments = (ins + TensorMapType:$map, + StrAttr:$name + ); + let results = (outs TensorType:$output); + let assemblyFormat = "`(` $map `,` $name `)` attr-dict `->` type($output)"; + let verifier = ?; +} + def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> { let summary = "dt.get_tensor_shape operation"; diff --git a/cinnrt/dialect/mlir_tests/basic.mlir b/cinnrt/dialect/mlir_tests/basic.mlir index 8dae3c4bbe1fe..2f6ae4783dbb4 100644 --- a/cinnrt/dialect/mlir_tests/basic.mlir +++ b/cinnrt/dialect/mlir_tests/basic.mlir @@ -30,3 +30,11 @@ func @caller.add.f32() -> f32 { cinn.return %z : f32 } /// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// CHECK-LABEL: @string_test +func @string_test() { + %path = cinn.get_string("this is get_string op.") + // CHECK-LABEL: string = this is get_string op. 
+  cinn.print_string(%path)
+  cinn.return
+}
diff --git a/cinnrt/dialect/mlir_tests/tensor_map.mlir b/cinnrt/dialect/mlir_tests/tensor_map.mlir
new file mode 100644
index 0000000000000..dbb163769eba9
--- /dev/null
+++ b/cinnrt/dialect/mlir_tests/tensor_map.mlir
@@ -0,0 +1,31 @@
+// CHECK-LABEL: @predict
+func @predict(%input:!cinn.tensor<X86, NCHW, F32>, %map: !cinn.tensor_map) -> (!cinn.tensor<X86, NCHW, F32>) {
+  %w = dt.get_param(%map, "create_parameter_0.w_0") -> !cinn.tensor<X86, NCHW, F32>
+  %bias = dt.get_param(%map, "create_parameter_1.w_0") -> !cinn.tensor<X86, NCHW, F32>
+
+  %out = dt.create_uninit_tensor.f32 [3, 3] -> !cinn.tensor<X86, NCHW, F32>
+
+  // fc
+  "external.matmul"(%input, %w, %out) {}: (!cinn.tensor<X86, NCHW, F32>, !cinn.tensor<X86, NCHW, F32>, !cinn.tensor<X86, NCHW, F32>) -> ()
+  "external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!cinn.tensor<X86, NCHW, F32>, !cinn.tensor<X86, NCHW, F32>, !cinn.tensor<X86, NCHW, F32>) -> ()
+  "external.sigmoid"(%out, %out) {}: (!cinn.tensor<X86, NCHW, F32>, !cinn.tensor<X86, NCHW, F32>) -> ()
+  //dt.print_tensor (%out : !cinn.tensor<X86, NCHW, F32>)
+
+  cinn.return %out : !cinn.tensor<X86, NCHW, F32>
+}
+
+// CHECK-LABEL: @main
+func @main() {
+  %input = dt.create_uninit_tensor.f32 [3, 3] -> !cinn.tensor<X86, NCHW, F32>
+  dt.fill_tensor_with_constant.f32 (%input : !cinn.tensor<X86, NCHW, F32>) {value=1.0:f32}
+
+  %path = cinn.get_string("/cinn/build/paddle/paddle_1.8_fc_model")
+  // CHECK-LABEL: loading params
+  %map = dt.load_params(%path)
+
+  %out = cinn.call @predict(%input, %map): (!cinn.tensor<X86, NCHW, F32>, !cinn.tensor_map) -> (!cinn.tensor<X86, NCHW, F32>)
+  dt.print_tensor (%out : !cinn.tensor<X86, NCHW, F32>)
+
+  cinn.return
+}
diff --git a/cinnrt/host_context/value.cc b/cinnrt/host_context/value.cc
index ce00c1fd4d89f..13f038375ecef 100644
--- a/cinnrt/host_context/value.cc
+++ b/cinnrt/host_context/value.cc
@@ -43,6 +43,8 @@ void CopyTo(const Value& from, Value* to) {
           to->data = arg;
         else if constexpr (std::is_same_v<T, std::vector<int16_t>>)
           to->data = arg;
+        else if constexpr (std::is_same_v<T, tensor::TensorMap>)
+          to->data = arg;
         else
           LOG(FATAL) << "Not supported Value copy: " << typeid(T).name();
       },
diff --git a/cinnrt/host_context/value.h b/cinnrt/host_context/value.h
index 2d8629013d55b..56505142e7bbe 100644
--- a/cinnrt/host_context/value.h
+++ b/cinnrt/host_context/value.h
@@ -12,6 +12,7 @@
 #include "cinnrt/support/variant.h"
 #include "cinnrt/tensor/dense_host_tensor.h"
 #include "cinnrt/tensor/dense_tensor_view.h"
+#include "cinnrt/tensor/tensor_map.h"
 #include "cinnrt/tensor/tensor_shape.h"
 
 namespace cinnrt {
@@ -29,6 +30,7 @@ using ValueVariantType = cinnrt::Variant<int16_t,
                                          tensor::TensorShape,
                                          tensor::DenseHostTensor,
                                          MlirFunctionExecutable*,
+                                         tensor::TensorMap,
                                          std::vector<int16_t>,
                                          std::vector<int32_t>,
                                          std::vector<int64_t>,
@@ -52,6 +54,7 @@ class Value : public cinnrt::common::Object {
   explicit Value(double x) : data(x) {}
   explicit Value(bool x) : data(x) {}
   explicit Value(std::string x) : data(x) {}
+  explicit Value(tensor::TensorMap&& x) : data(x) {}
   explicit Value(std::vector<int16_t>&& x) : data(x) {}
   explicit Value(std::vector<int32_t>&& x) : data(x) {}
   explicit Value(std::vector<int64_t>&& x) : data(x) {}
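
With tensor::TensorMap now one of the ValueVariantType alternatives, the whole parameter map travels through the runtime as a single host_context::Value, which is what lets @predict in the test above take %map as an ordinary argument. A sketch of the round trip (names are from this patch; error handling omitted):

    #include "cinnrt/host_context/value.h"
    #include "cinnrt/tensor/tensor_map.h"

    using cinnrt::host_context::Value;
    using cinnrt::tensor::TensorMap;

    Value* WrapParams(TensorMap* map) {
      // Move the loaded map into the variant, as PredictExecutor::Init
      // does for the predict function's tensor_map argument.
      auto* value = new Value(std::move(*map));
      // A kernel such as dt.get_param later pulls it back out.
      TensorMap& params = value->get<TensorMap>();
      (void)params;
      return value;
    }
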
diff --git a/cinnrt/kernel/basic_kernels.cc b/cinnrt/kernel/basic_kernels.cc
index 30637bb993e15..b56e4c8a76d2f 100644
--- a/cinnrt/kernel/basic_kernels.cc
+++ b/cinnrt/kernel/basic_kernels.cc
@@ -1,9 +1,13 @@
 #include "cinnrt/kernel/basic_kernels.h"
 
 #include <iostream>
+#include <string>
 
 #include "cinnrt/host_context/kernel_registry.h"
 #include "cinnrt/host_context/kernel_utils.h"
+#include "llvm/Support/raw_ostream.h"
+
+using cinnrt::host_context::Attribute;
 
 namespace cinnrt::kernel {
 
@@ -32,9 +36,18 @@ void print(T a) {
   std::cout << a << std::endl;
 }
 
+static std::string GetString(Attribute<std::string> value) { return value.get(); }
+
+static void PrintString(const std::string &str) {
+  llvm::outs() << "string = " << str << '\n';
+  llvm::outs().flush();
+}
+
 void RegisterBasicKernels(host_context::KernelRegistry *registry) {
   RegisterIntBasicKernels(registry);
   RegisterFloatBasicKernels(registry);
+  registry->AddKernel("cinn.get_string", CINN_KERNEL(GetString));
+  registry->AddKernel("cinn.print_string", CINN_KERNEL(PrintString));
 }
 
 void RegisterIntBasicKernels(host_context::KernelRegistry *registry) {
diff --git a/cinnrt/kernel/tensor_kernels.cc b/cinnrt/kernel/tensor_kernels.cc
index 6c1867b5f7fe3..da68ebe0e06ef 100644
--- a/cinnrt/kernel/tensor_kernels.cc
+++ b/cinnrt/kernel/tensor_kernels.cc
@@ -3,10 +3,12 @@
 #include <iostream>
 #include <vector>
 
+#include "cinnrt/common/global.h"
 #include "cinnrt/host_context/kernel_registry.h"
 #include "cinnrt/host_context/kernel_utils.h"
 #include "cinnrt/tensor/dense_host_tensor.h"
 #include "cinnrt/tensor/dense_tensor_view.h"
+#include "cinnrt/tensor/tensor_map.h"
 #include "cinnrt/tensor/tensor_shape.h"
 
 namespace cinnrt::kernel {
@@ -17,27 +19,36 @@ using namespace tensor;  // NOLINT
 
 template <typename T>
 DenseHostTensor CreateUninitTensor(Attribute<std::vector<int64_t>> shape) {
-  const auto& shape_data = shape.get();
+  const auto &shape_data = shape.get();
   auto array = llvm::ArrayRef<int64_t>(shape_data.data(), shape_data.size());
   auto type = GetDType<T>();
   return DenseHostTensor(TensorShape(array), type);
 }
 
-void PrintTensor(const DenseHostTensor& tensor) { std::cout << tensor << std::endl; }
+void PrintTensor(const DenseHostTensor &tensor) { std::cout << tensor << std::endl; }
 
 template <typename T>
-void FillTensorWithConstant(DenseHostTensor* tensor, Attribute<T> v) {
+void FillTensorWithConstant(DenseHostTensor *tensor, Attribute<T> v) {
   MutableDTArrayView<T>(tensor).Fill(v.get());
 }
 
+TensorMap LoadParams(const std::string &path) { return *(cinnrt::tensor::LoadParams(path)); }
+
+DenseHostTensor GetParam(TensorMap map, Attribute<std::string> nameAttr) {
+  auto &name = nameAttr.get();
+  return *(map[name]);
+}
+
 /// ===== Kernel end ====
 
-void RegisterTensorKernels(host_context::KernelRegistry* registry) {
+void RegisterTensorKernels(host_context::KernelRegistry *registry) {
   registry->AddKernel("dt.create_uninit_tensor.f32", CINN_KERNEL(CreateUninitTensor<float>));
   registry->AddKernelAttrNameList("dt.create_uninit_tensor.f32", {"shape"});
   registry->AddKernel("dt.print_tensor", CINN_KERNEL(PrintTensor));
   registry->AddKernel("dt.fill_tensor_with_constant.f32", CINN_KERNEL(FillTensorWithConstant<float>));
   registry->AddKernel("dt.fill_tensor_with_constant.f64", CINN_KERNEL(FillTensorWithConstant<double>));
+  registry->AddKernel("dt.load_params", CINN_KERNEL(LoadParams));
+  registry->AddKernel("dt.get_param", CINN_KERNEL(GetParam));
 }
 
 }  // namespace cinnrt::kernel
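
The registration pattern above is uniform: a free C++ function whose Attribute<T> parameters are fed from MLIR attributes rather than SSA operands, wrapped by CINN_KERNEL and bound to a dialect op name. A hypothetical extra kernel in the same style (dt.scale_tensor.f32 and ScaleTensor are illustrative names, not part of this patch):

    #include "cinnrt/host_context/kernel_registry.h"
    #include "cinnrt/host_context/kernel_utils.h"
    #include "cinnrt/tensor/dense_host_tensor.h"

    namespace cinnrt::kernel {

    // Scales a tensor in place; `factor` arrives as an MLIR attribute.
    static void ScaleTensor(tensor::DenseHostTensor *tensor, host_context::Attribute<float> factor) {
      auto *data = reinterpret_cast<float *>(tensor->raw_data());
      for (int i = 0; i < tensor->shape().GetNumElements(); ++i) data[i] *= factor.get();
    }

    void RegisterMyTensorKernels(host_context::KernelRegistry *registry) {
      registry->AddKernel("dt.scale_tensor.f32", CINN_KERNEL(ScaleTensor));
      registry->AddKernelAttrNameList("dt.scale_tensor.f32", {"factor"});
    }

    }  // namespace cinnrt::kernel
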
diff --git a/cinnrt/tensor/CMakeLists.txt b/cinnrt/tensor/CMakeLists.txt
index 88f54eada5588..9bcb4de3fb7a6 100644
--- a/cinnrt/tensor/CMakeLists.txt
+++ b/cinnrt/tensor/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(srcs
+    tensor_map.cc
     tensor_shape.cc
     tensor_metadata.cc
     dense_host_tensor.cc
@@ -16,3 +17,12 @@ file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
 foreach(header ${includes})
   set(core_includes "${core_includes};${header}" CACHE INTERNAL "")
 endforeach()
+
+set(tensor_map_mlir "${CMAKE_SOURCE_DIR}/cinnrt/dialect/mlir_tests/tensor_map.mlir")
+set(external_kernels_lib "${CMAKE_BINARY_DIR}/paddle/libexternal_kernels.so")
+message(STATUS "tensor_map_mlir: ${tensor_map_mlir}")
+message(STATUS "external_kernels_lib: ${external_kernels_lib}")
+add_test(
+  NAME run_and_check_tensor_map
+  COMMAND sh -c "${CMAKE_BINARY_DIR}/cinnrt/host_context/cinn-exec -i ${tensor_map_mlir} --shared_libs=${external_kernels_lib} | FileCheck-10 ${tensor_map_mlir}"
+)
diff --git a/cinnrt/tensor/dense_host_tensor.cc b/cinnrt/tensor/dense_host_tensor.cc
index 7fa7a088dcb71..c95c2403aca66 100644
--- a/cinnrt/tensor/dense_host_tensor.cc
+++ b/cinnrt/tensor/dense_host_tensor.cc
@@ -13,6 +13,15 @@ DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype) : HostTe
 }
 
 const TensorShape& DenseHostTensor::shape() const { return metadata().shape; }
+
+void DenseHostTensor::Init(const std::vector<int64_t>& shape, DType dtype) {
+  auto shape_array = llvm::ArrayRef<int64_t>(shape.data(), shape.size());
+  auto metadata = TensorMetadata(dtype, shape_array);
+  setTensorMetadata(metadata);
+  buffer_.reset(new cinn::hlir::framework::Buffer(cinn::common::DefaultHostTarget()));
+  buffer_->ResizeLazy(dtype.GetHostSize() * metadata.shape.GetNumElements());
+}
+
 const cinn::hlir::framework::Buffer* DenseHostTensor::buffer() const { return buffer_.get(); }
 
 template <typename T>
diff --git a/cinnrt/tensor/dense_host_tensor.h b/cinnrt/tensor/dense_host_tensor.h
index bd1703b8f1dc1..3933764266deb 100644
--- a/cinnrt/tensor/dense_host_tensor.h
+++ b/cinnrt/tensor/dense_host_tensor.h
@@ -24,6 +24,8 @@ class Tensor {
   const TensorMetadata& metadata() const { return metadata_; }
 
  protected:
+  Tensor() = default;
+  void setTensorMetadata(TensorMetadata& metadata) { metadata_ = metadata; }
   explicit Tensor(const TensorMetadata& metadata) : metadata_(metadata) {}
   explicit Tensor(TensorMetadata&& metadata) : metadata_(std::move(metadata)) {}
 
@@ -36,6 +38,7 @@ class HostTensor : public Tensor {
   bool IsHostTensor() const override { return true; }
 
  protected:
+  HostTensor() = default;
   explicit HostTensor(const TensorMetadata& metadata) : Tensor(metadata) {}
   explicit HostTensor(TensorMetadata&& metadata) : Tensor(std::move(metadata)) {}
 };
@@ -46,8 +49,10 @@ class HostTensor : public Tensor {
  */
 class DenseHostTensor : public HostTensor {
  public:
+  DenseHostTensor() = default;
   DenseHostTensor(const TensorShape& shape, DType dtype);
 
+  void Init(const std::vector<int64_t>& shape, DType dtype);
   const TensorShape& shape() const;
   const cinn::hlir::framework::Buffer* buffer() const;
 
diff --git a/cinnrt/tensor/tensor_map.cc b/cinnrt/tensor/tensor_map.cc
new file mode 100644
index 0000000000000..02936411ac152
--- /dev/null
+++ b/cinnrt/tensor/tensor_map.cc
@@ -0,0 +1,79 @@
+#include "cinnrt/tensor/tensor_map.h"
+
+#include <fstream>
+#include <iostream>
+
+#include "cinn/frontend/paddle/compatible_pb.h"
+#include "cinn/frontend/paddle/model_parser.h"
+
+using Scope = cinn::hlir::framework::Scope;
+using ProgramDesc = cinn::frontend::paddle::cpp::ProgramDesc;
+using Target = cinn::common::Target;
+
+namespace cinnrt {
+namespace tensor {
+
+cinnrt::DType CinnType2DType_(cinn::common::Type type) {
+  if (type.is_bool()) return GetDType<bool>();
+  if (type.is_int(8)) return GetDType<int8_t>();
+  if (type.is_int(16)) return GetDType<int16_t>();
+  if (type.is_int(32)) return GetDType<int32_t>();
+  if (type.is_int(64)) return GetDType<int64_t>();
+  if (type.is_uint(8)) return GetDType<uint8_t>();
+  if (type.is_uint(16)) return GetDType<uint16_t>();
+  if (type.is_uint(32)) return GetDType<uint32_t>();
+  if (type.is_uint(64)) return GetDType<uint64_t>();
+  if (type.is_float(32)) return GetDType<float>();
+  if (type.is_float(64)) return GetDType<double>();
+  if (type.is_string()) return GetDType<std::string>();
+  return cinnrt::DType(cinnrt::DType::Kind::Unk);
+}
+
+TensorMap *LoadParams(const std::string &path) {
+  std::cout << "loading params from: " << path << std::endl;
+  TensorMap *map = new TensorMap();
+  Scope scope;
+  ProgramDesc cpp_prog;
+  const Target &target = cinn::common::DefaultHostTarget();
+
+  std::string model_path = path + "/__model__";
+  paddle::framework::proto::ProgramDesc pb_proto_prog =
+      *cinn::frontend::paddle::LoadProgram(model_path);
+  // cinn::frontend::paddle::pb::ProgramDesc pb_prog_desc(&pb_proto_prog);
+  // cinn::frontend::paddle::TransformProgramDescAnyToCpp(pb_prog_desc, cpp_prog);
+  auto main_block = pb_proto_prog.blocks(0);
+  for (auto &var : main_block.vars()) {
+    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) continue;
+    std::string param_path = path + "/" + var.name();
+    std::ifstream param_file(param_path, std::ios::binary);
+    switch (var.type().type()) {
+      case paddle::framework::proto::VarType_Type_LOD_TENSOR: {
+        using CinnTensor = cinn::hlir::framework::Tensor;
+        using namespace cinn::utils;
+        auto var_name = TransValidVarName(var.name());
+        // std::cout << "var name: " << var.name() << " " << var_name << std::endl;
+        auto *_var = scope.Var<CinnTensor>(var_name);
+        cinn::frontend::paddle::LoadLoDTensor(param_file, _var, target);
+        auto tensor = scope.GetTensor(var_name);
+        auto *src_data = tensor->data<float>();
+        auto &cinn_type = tensor->type();
+        std::vector<int64_t> shape;
+        for (int dim : tensor->shape().data()) shape.push_back(dim);
+        auto shape_array = llvm::ArrayRef<int64_t>(shape.data(), shape.size());
+        auto dtype = CinnType2DType_(cinn_type);
+        auto *dht = new DenseHostTensor(TensorShape(shape_array), dtype);
+        int num_elements = dht->shape().GetNumElements();
+        auto *dst_data = reinterpret_cast<float *>(dht->raw_data());
+        for (int i = 0; i < num_elements; ++i) dst_data[i] = src_data[i];
+        (*map)[var.name()] = dht;
+        break;
+      }
+      default:
+        std::cout << "unknown weight type" << std::endl;
+        break;
+    }
+  }
+  return map;
+}
+
+}  // namespace tensor
+}  // namespace cinnrt
diff --git a/cinnrt/tensor/tensor_map.h b/cinnrt/tensor/tensor_map.h
new file mode 100644
index 0000000000000..a18967eaa98ed
--- /dev/null
+++ b/cinnrt/tensor/tensor_map.h
@@ -0,0 +1,12 @@
+#include <string>
+#include <unordered_map>
+
+#include "cinnrt/tensor/dense_host_tensor.h"
+
+namespace cinnrt {
+namespace tensor {
+
+using TensorMap = std::unordered_map<std::string, DenseHostTensor*>;
+
+TensorMap* LoadParams(const std::string& path);
+
+}  // namespace tensor
+}  // namespace cinnrt
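
For reference, tensor::LoadParams can also be driven directly from C++, mirroring what the dt.load_params and dt.get_param kernels do at runtime. A minimal sketch (the model path and parameter name come from the test assets above; error handling omitted):

    #include <iostream>

    #include "cinnrt/tensor/dense_host_tensor.h"
    #include "cinnrt/tensor/tensor_map.h"

    int main() {
      using namespace cinnrt::tensor;
      // Reads __model__ plus every persistable variable stored beside it.
      TensorMap* params = LoadParams("./paddle/paddle_1.8_fc_model");
      // Look up one parameter by its Paddle variable name.
      DenseHostTensor* w = (*params)["create_parameter_0.w_0"];
      std::cout << "w elements: " << w->shape().GetNumElements() << std::endl;
      return 0;
    }
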