diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4cb661faf..cad0bb5bc 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,7 +30,7 @@ find_package(Threads REQUIRED)
find_package(CUDA QUIET)
include(simd)
-
+# SET(CMAKE_BUILD_TYPE "Debug")
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
diff --git a/README.md b/README.md
index ab6b1c014..6c6d0924b 100644
--- a/README.md
+++ b/README.md
@@ -175,9 +175,12 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
-| `thread` | int | `4` | Concurrency of current service |
+| `thread` | int | `2` | Number of brpc service thread |
+| `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
+| `op_max_batch` | int[]| `0` | Batch Number for each model in asynchronous mode |
+| `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
| `port` | int | `9292` | Exposed port of current service to users |
-| `model` | str | `""` | Path of paddle model directory to be served |
+| `model` | str[]| `""` | Path of paddle model directory to be served |
| `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
@@ -186,7 +189,24 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| `use_calib` | bool | False | Only for deployment with TensorRT |
-
+| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
+
+#### Description of asynchronous model
+ Asynchronous mode is suitable for 1. When the number of requests is very large, 2. When multiple models are concatenated and you want to specify the concurrency number of each model.
+ Asynchronous mode helps to improve the throughput (QPS) of service, but for a single request, the delay will increase slightly.
+ In asynchronous mode, each model will start n threads of the number you specify, and each thread contains a model instance. In other words, each model is equivalent to a thread pool containing N threads, and the task is taken from the task queue of the thread pool to execute.
+ In asynchronous mode, each RPC server thread is only responsible for putting the request into the task queue of the model thread pool. After the task is executed, the completed task is removed from the task queue.
+ In the above table, the number of RPC server threads is specified by --thread, and the default value is 2.
+ --op_num specifies the number of threads in the thread pool of each model. The default value is 0, indicating that asynchronous mode is not used.
+ --op_max_batch specifies the number of batches for each model. The default value is 32. It takes effect when --op_num is not 0.
+#### When you want a model to use multiple GPU cards.
+python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
+#### When you want 2 models.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
+#### When you want 2 models, and want each of them use multiple GPU cards.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
+#### When a service contains two models, and each model needs to specify multiple GPU cards, and needs asynchronous mode, each model specifies different concurrency number.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
```python
diff --git a/README_CN.md b/README_CN.md
index d728071db..a1bb9f9e7 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -172,19 +172,40 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
```
-| Argument | Type | Default | Description |
-| ---------------------------------------------- | ---- | ------- | ------------------------------------------------------ |
-| `thread` | int | `4` | Concurrency of current service |
-| `port` | int | `9292` | Exposed port of current service to users |
-| `name` | str | `""` | Service name, can be used to generate HTTP request url |
-| `model` | str | `""` | Path of paddle model directory to be served |
-| `mem_optim_off` | - | - | Disable memory optimization |
-| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
-| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
-| `use_trt` (Only for Cuda>=10.1 version) | - | - | Run inference with TensorRT |
-| `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
-| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
-| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
+| Argument | Type | Default | Description |
+| ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
+| `thread` | int | `2` | Number of brpc service thread |
+| `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
+| `op_max_batch` | int[]| `32` | Batch Number for each model in asynchronous mode |
+| `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
+| `port` | int | `9292` | Exposed port of current service to users |
+| `model` | str[]| `""` | Path of paddle model directory to be served |
+| `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
+| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
+| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
+| `use_trt` (Only for trt version) | - | - | Run inference with TensorRT |
+| `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
+| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
+| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
+| `use_calib` | bool | False | Only for deployment with TensorRT |
+| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
+
+#### 异步模型的说明
+ 异步模式适用于1、请求数量非常大的情况,2、多模型串联,想要分别指定每个模型的并发数的情况。
+ 异步模式有助于提高Service服务的吞吐(QPS),但对于单次请求而言,时延会有少量增加。
+ 异步模式中,每个模型会启动您指定个数的N个线程,每个线程中包含一个模型实例,换句话说每个模型相当于包含N个线程的线程池,从线程池的任务队列中取任务来执行。
+ 异步模式中,各个RPC Server的线程只负责将Request请求放入模型线程池的任务队列中,等任务被执行完毕后,再从任务队列中取出已完成的任务。
+ 上表中通过 --thread 10 指定的是RPC Server的线程数量,默认值为2,--op_num 指定的是各个模型的线程池中线程数N,默认值为0,表示不使用异步模式。
+ --op_max_batch 指定的各个模型的batch数量,默认值为32,该参数只有当--op_num不为0时才生效。
+
+#### 当您的某个模型想使用多张GPU卡部署时.
+python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
+#### 当您的一个服务包含两个模型部署时.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
+#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡部署时.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
+#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡,且需要异步模式每个模型指定不同的并发数时.
+python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
diff --git a/core/configure/proto/server_configure.proto b/core/configure/proto/server_configure.proto
index 24fb62806..5cace0642 100755
--- a/core/configure/proto/server_configure.proto
+++ b/core/configure/proto/server_configure.proto
@@ -21,11 +21,12 @@ message EngineDesc {
required string reloadable_meta = 3;
required string reloadable_type = 4;
required string model_dir = 5;
- required int32 runtime_thread_num = 6;
- required int32 batch_infer_size = 7;
- required int32 enable_batch_align = 8;
- optional string version_file = 9;
- optional string version_type = 10;
+ repeated int32 gpu_ids = 6;
+ required int32 runtime_thread_num = 7;
+ required int32 batch_infer_size = 8;
+ required int32 enable_batch_align = 9;
+ optional string version_file = 10;
+ optional string version_type = 11;
/*
* Sparse Parameter Service type. Valid types are:
@@ -38,16 +39,17 @@ message EngineDesc {
LOCAL = 1;
REMOTE = 2;
}
- optional SparseParamServiceType sparse_param_service_type = 11;
- optional string sparse_param_service_table_name = 12;
- optional bool enable_memory_optimization = 13;
- optional bool enable_ir_optimization = 14;
- optional bool use_trt = 15;
- optional bool use_lite = 16;
- optional bool use_xpu = 17;
- optional bool use_gpu = 18;
- optional bool combined_model = 19;
- optional bool encrypted_model = 20;
+ optional SparseParamServiceType sparse_param_service_type = 12;
+ optional string sparse_param_service_table_name = 13;
+ optional bool enable_memory_optimization = 14;
+ optional bool enable_ir_optimization = 15;
+ optional bool use_trt = 16;
+ optional bool use_lite = 17;
+ optional bool use_xpu = 18;
+ optional bool use_gpu = 19;
+ optional bool combined_model = 20;
+ optional bool encrypted_model = 21;
+ optional bool gpu_multi_stream = 22;
};
// model_toolkit conf
diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
old mode 100644
new mode 100755
index 0ade573de..d3dd5d9f7
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -166,6 +166,8 @@ int PredictorClient::numpy_predict(
batch_size = batch_size > string_feed_batch.size() ? batch_size
: string_feed_batch.size();
VLOG(2) << "batch size: " << batch_size;
+ // batch_size must be 1, cause batch is already in Tensor.
+ // I suggest to remove the outside vector<>.
predict_res_batch.clear();
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
@@ -188,6 +190,8 @@ int PredictorClient::numpy_predict(
}
int vec_idx = 0;
+ // batch_size can only be 1, cause batch is already in Tensor.
+ // if batch_size is not 1, error will occur in C++ part.
for (int bi = 0; bi < batch_size; bi++) {
VLOG(2) << "prepare batch " << bi;
std::vector tensor_vec;
diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp
index 3e1091dd8..8e1904f81 100644
--- a/core/general-server/op/general_reader_op.cpp
+++ b/core/general-server/op/general_reader_op.cpp
@@ -93,6 +93,9 @@ int GeneralReaderOp::inference() {
res->SetLogId(log_id);
Timer timeline;
int64_t start = timeline.TimeStampUS();
+ // only get insts(0), cause batch is already in Tensor.
+ // req can only include 1 inst.
+ // var_num means the number of feed_var.
int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "(logid=" << log_id << ") var num: " << var_num
@@ -178,7 +181,10 @@ int GeneralReaderOp::inference() {
VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i
<< "]: " << data_len;
databuf_size = data_len * elem_size;
- out->at(i).data.Resize(databuf_size);
+ void *databuf_char = MempoolWrapper::instance().malloc(databuf_size);
+ paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
+ out->at(i).data = paddleBuf;
+ // out->at(i).data.Resize(databuf_size);
if (out->at(i).lod.size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] has lod_tensor and len=" << out->at(i).lod[0].back();
diff --git a/core/predictor/framework/bsf-inl-tensor.h b/core/predictor/framework/bsf-inl-tensor.h
deleted file mode 100644
index b7c725b44..000000000
--- a/core/predictor/framework/bsf-inl-tensor.h
+++ /dev/null
@@ -1,373 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#ifdef BCLOUD
-#include
-#else
-#include
-#endif
-
-#include
-#include
-#include
-#include
-#include "core/predictor/common/inner_common.h"
-#include "core/predictor/framework/infer_data.h"
-#include "core/predictor/framework/memory.h"
-
-#include
-
-namespace im {
-namespace bsf {
-
-template <>
-struct Task {
- typedef Task
- TaskT;
- typedef baidu::paddle_serving::predictor::Tensor Tensor;
- typedef baidu::paddle_serving::predictor::Tensor InType;
- typedef baidu::paddle_serving::predictor::Tensor OutType;
- typedef baidu::paddle_serving::predictor::BatchTensor BatchTensor;
- typedef baidu::paddle_serving::predictor::BatchTensor InArrayT;
- typedef baidu::paddle_serving::predictor::BatchTensor OutArrayT;
-
- struct Segment {
- Segment(void* p, size_t b, size_t s) : ptr(p), begin(b), size(s) {}
- void* ptr;
- size_t begin;
- size_t size;
- };
-
- int read_fd;
- int write_fd;
-
- pid_t owner_tid;
-
- const InArrayT* in;
- OutArrayT* out;
-
- size_t rem;
- size_t size;
-
- butil::atomic index;
-
- const BatchTensor* get(bool is_in) const {
- if (is_in) {
- return in;
- } else {
- return out;
- }
- }
-
- BatchTensor* get(bool is_in) {
- if (is_in) {
- return const_cast(in);
- } else {
- return out;
- }
- }
-
- Task() {
- read_fd = -1;
- write_fd = -1;
- owner_tid = -1;
- in = NULL;
- out = NULL;
- rem = -1;
- size = -1;
- index.store(0, butil::memory_order_relaxed);
- }
-};
-
-template <>
-class BatchTasks> {
- public:
- typedef baidu::paddle_serving::predictor::Tensor Tensor;
- typedef baidu::paddle_serving::predictor::Tensor InType;
- typedef baidu::paddle_serving::predictor::Tensor OutType;
- typedef baidu::paddle_serving::predictor::DataBuf DataBuf;
- typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
-
- typedef Task
- TaskT;
- typedef TaskMeta TaskMetaT;
- typedef TaskT::InArrayT InArrayT;
- typedef TaskT::OutArrayT OutArrayT;
-
- explicit BatchTasks(size_t batch_size, bool batch_align = false)
- : _batch_size(batch_size),
- _rem_size(batch_size),
- _batch_align(batch_align) {
- _batch_in.clear();
- _batch_out.clear();
- _tasks.clear();
- }
-
- ~BatchTasks() {
- _batch_in.clear();
- _batch_out.clear();
- _tasks.clear();
- }
-
- static bool check_valid(const InArrayT& in,
- OutArrayT& out, // NOLINT
- bool align) { // NOLINT
- if (align) {
- if (out.count() <= 0 || out.size() <= 0) {
- LOG(ERROR) << "Out tensor is empty, when aligned";
- return false;
- }
-
- if (out.size() != in.size()) {
- LOG(ERROR) << "In/Out tensor size not eq: " << out.size()
- << "!=" << in.size();
- return false;
- }
-
- for (size_t fi = 0, shape0 = 0; fi < out.count(); ++fi) {
- if (!out[fi].valid()) {
- LOG(ERROR) << "Out[" << fi << "] tensor not valid";
- return false;
- }
-
- if (out.size() != out[fi].shape0()) {
- LOG(ERROR) << "Shape0 not consistency, " << out.size()
- << "!=" << out[fi].shape0() << ", " << fi;
- return false;
- }
- }
- }
-
- return true;
- }
-
- size_t append_task(TaskT* task) {
- size_t add = std::min(task->rem, _rem_size);
- if (!_batch_align) {
- add = task->rem;
- }
- TaskMetaT tm(task, task->in->size() - task->rem, add);
- _tasks.push_back(tm);
-
- task->rem -= add;
- _rem_size -= add;
- return _rem_size;
- }
-
- void merge_tasks() {
- merge_input();
- merge_output();
- }
-
- void merge_input() {
- if (_tasks.size() <= 0 || _tasks[0].task->in->count() <= 0) {
- return;
- }
-
- if (_tasks.size() == 1 && !_batch_align) {
- TaskMetaT& tm = _tasks[0];
- _batch_in = *(tm.task->in);
- return;
- }
-
- merge_tensor(true);
- }
-
- void merge_output() {
- if (_batch_align) {
- if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
- return;
- }
- }
-
- if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
- return;
- }
-
- TaskMetaT& tm = _tasks[0];
- if (_tasks.size() == 1 && !_batch_align) {
- _batch_out = *(tm.task->out);
- return;
- }
-
- if (tm.task->out->size() <= 0) {
- // shape is empty
- _batch_out = *(tm.task->out);
- return;
- }
-
- if ((*tm.task->out)[0].data.data() == 0 ||
- (*tm.task->out)[0].data.size() == 0) {
- _batch_out = *(tm.task->out);
- return;
- }
-
- merge_tensor(false);
- }
-
- void merge_tensor(bool is_in) {
- // accumulate batch size from fetched tasks
- size_t batch_size = 0;
- for (size_t ti = 0; ti < _tasks.size(); ++ti) {
- TaskMetaT& tm = _tasks[ti];
- size_t add = tm.end - tm.begin;
- batch_size += add;
- }
-
- // merge all instanses in each tensor data
- size_t tensor_count = _tasks[0].task->get(is_in)->count();
- for (size_t fi = 0; fi < tensor_count; ++fi) {
- const Tensor& head = (*(_tasks[0].task->get(is_in)))[fi];
- Tensor batch_tensor;
- batch_tensor.name = head.name;
- batch_tensor.type = head.type;
- batch_tensor.shape.push_back(batch_size);
-
- size_t ins_ele_count = 1;
- for (size_t si = 1; si < head.shape.size(); ++si) {
- batch_tensor.shape.push_back(head.shape[si]);
- ins_ele_count *= head.shape[si];
- }
-
- size_t tensor_ele_count = ins_ele_count * batch_size;
- size_t ins_byte = ins_ele_count * head.ele_byte();
-
- size_t tensor_byte = tensor_ele_count * head.ele_byte();
- void* data_buf = MempoolWrapper::instance().malloc(tensor_byte);
- if (!data_buf) {
- LOG(ERROR) << "Malloc failed, size: " << tensor_byte;
- return;
- }
-
- size_t data_byte = 0;
- for (size_t ti = 0; ti < _tasks.size(); ++ti) {
- TaskMetaT& tm = _tasks[ti];
- size_t acc_byte = ins_byte * (tm.end - tm.begin);
- if (data_byte + acc_byte > tensor_byte) {
- LOG(ERROR) << "Invalid bytes: " << data_byte << " + " << acc_byte
- << " >= " << tensor_byte;
- return;
- }
-
- const Tensor& tensor = (*(tm.task->get(is_in)))[fi];
- memcpy(
- reinterpret_cast(data_buf) + data_byte,
- reinterpret_cast(tensor.data.data()) + tm.begin * ins_byte,
- acc_byte);
- data_byte += acc_byte;
- }
-
- if (data_byte != tensor_byte) {
- LOG(ERROR) << "Invalid tensor byte: " << data_byte
- << " != " << tensor_byte;
- return;
- }
-
- batch_tensor.data =
- DataBuf(reinterpret_cast(data_buf), tensor_byte);
- if (is_in) {
- _batch_in.push_back(batch_tensor);
- } else {
- _batch_out.push_back(batch_tensor);
- }
- }
-
- LOG(INFO) << "merge input(" << is_in << ") samples: " << batch_size
- << " from " << _tasks.size() << " pvs";
- }
-
- void notify_tasks() {
- if (_batch_out.size() != _batch_in.size()) {
- LOG(ERROR) << "batch size not consistency: " << _batch_out.size()
- << " != " << _batch_in.size();
- return;
- }
-
- size_t tensor_count = _batch_out.count();
- size_t batch_size = _batch_out.size();
- for (size_t fi = 0; fi < tensor_count; ++fi) {
- const Tensor& tensor = _batch_out[fi];
- size_t ins_byte = tensor.ele_byte();
- for (size_t si = 1; si < tensor.shape.size(); ++si) {
- ins_byte *= tensor.shape[si];
- }
-
- for (size_t ti = 0, bi = 0, add = 0; ti < _tasks.size();
- ++ti, bi += add) {
- OutArrayT* dst = _tasks[ti].task->out;
- add = _tasks[ti].end - _tasks[ti].begin;
- size_t offset_src = ins_byte * bi;
- size_t add_byte = add * ins_byte;
-
- if (_batch_align) { // merge all batchs
- size_t offset_dst = ins_byte * _tasks[ti].begin;
- void* ptr = const_cast((*dst)[fi].data.data());
- memcpy(
- reinterpret_cast(ptr) + offset_dst,
- reinterpret_cast(_batch_out[fi].data.data()) + offset_src,
- add_byte);
- } else { // overwrite
- if (dst->count() <= 0) {
- dst->push_back(_batch_out[fi]);
- } else {
- (*dst)[fi] = _batch_out[fi];
- }
-
- (*dst)[fi].shape[0] = add;
- (*dst)[fi].data = DataBuf(
- reinterpret_cast(_batch_out[fi].data.data()) + offset_src,
- add_byte);
- }
- }
- }
-
- for (size_t ti = 0; ti < _tasks.size(); ++ti) {
- TaskT* task = _tasks[ti].task;
- size_t begin = _tasks[ti].begin;
- size_t end = _tasks[ti].end;
- size_t add = end - begin;
-
- size_t index = task->index.fetch_add(add);
- if ((index + add) >= task->in->size()) {
- char c = 0;
- while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
- }
- butil::return_object(task);
- }
- }
- }
-
- const typename TaskT::InArrayT& in() const { return _batch_in; }
-
- typename TaskT::OutArrayT& out() { return _batch_out; }
-
- size_t task_size() { return _tasks.size(); }
-
- private:
- std::vector _tasks;
- InArrayT _batch_in;
- OutArrayT _batch_out;
- size_t _batch_size;
- size_t _rem_size;
- bool _batch_align;
-};
-
-} // namespace bsf
-} // namespace im
diff --git a/core/predictor/framework/bsf-inl.h b/core/predictor/framework/bsf-inl.h
index 1193ce486..4dc0baa3a 100644
--- a/core/predictor/framework/bsf-inl.h
+++ b/core/predictor/framework/bsf-inl.h
@@ -24,6 +24,7 @@
#include
#include "core/predictor/common/inner_common.h"
+#include "core/predictor/framework/memory.h"
namespace im {
namespace bsf {
@@ -35,7 +36,7 @@ void* TaskExecutor::thread_entry(void* args) {
static_cast*>(context->executor);
executor->work(context);
- return NULL;
+ return nullptr;
}
template
@@ -125,18 +126,21 @@ void TaskExecutor::stop() {
}
template
-TaskHandler TaskExecutor::schedule(const InArrayT& in,
- OutArrayT& out) { // NOLINT
+TaskHandler TaskExecutor::schedule(
+ const void* inVectorT_ptr,
+ void* outVectorT_ptr) { // NOLINT
TaskT* task = butil::get_object();
if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool";
return TaskHandler::valid_handle();
}
+ /*
if (!BatchTasks::check_valid(in, out, _batch_align)) {
LOG(ERROR) << "Invalid input & output";
return TaskHandler::valid_handle();
}
+ */
int fds[2];
int rc = pipe(fds);
@@ -150,10 +154,9 @@ TaskHandler TaskExecutor::schedule(const InArrayT& in,
task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid);
- task->in = ∈
- task->out = &out;
- task->rem = in.size();
- task->size = in.size();
+ task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
+ task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
+ task->rem = task->batch_size();
task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut);
@@ -163,8 +166,13 @@ TaskHandler TaskExecutor::schedule(const InArrayT& in,
return TaskHandler(*task);
}
+// this function is accessed by multi thread.
+// so AutoMutex at first.
+// so batch.append_task is thread safe.
+// you dont need to add extra lock in append_task()
template
-bool TaskExecutor::fetch_batch(BatchTasks& batch) { // NOLINT
+bool TaskExecutor::move_task_to_batch(
+ BatchTasks& batch) { // NOLINT
AutoMutex lock(_mut);
while (_task_queue.empty()) {
THREAD_COND_WAIT(&_cond, &_mut);
@@ -187,8 +195,30 @@ bool TaskExecutor::fetch_batch(BatchTasks& batch) { // NOLINT
return true;
}
+// this function is accessed by multi thread.
+// move_task_to_batch have add lock inside the function.
+// Packaging 1 TaskT as 1 or Several TaskMeta.
+// TaskT is from the SingleTon TaskExecutor`s _task_queue
+// although TaskMeta is a local variable, but several TaskMeta may points to
+// the same TaskT which is get from the SingleTon TaskExecutor`s _task_queue.
+// put TaskMeta to the local variable BatchTasks batch.
+
+// batch.merge_tasks() and batch.notify_tasks() has no lock.
+// BatchTasks batch itself is a local variable, it`s thread safe.
+// If batch.merge_tasks() and batch.notify_tasks() do something to TaskMeta
+// you need to pay attention to that.
+// Multi-Thread deal with different TaskMeta(cause it`s created as local
+// variable)
+// But different TaskMeta may points to the same TaskT
+// which is get from the SingleTon TaskExecutor`s _task_queue.
+
template
int TaskExecutor::work(ThreadContext* context) {
+ if (MempoolWrapper::instance().thread_initialize() != 0) {
+ LOG(ERROR) << "Failed thread initialize mempool";
+ return -1;
+ }
+
if (_thread_init_fn != NULL) {
if (_thread_init_fn(context->user_thread_context) != 0) {
LOG(ERROR) << "execute thread init thunk failed, BSF thread will exit";
@@ -207,10 +237,15 @@ int TaskExecutor::work(ThreadContext* context) {
}
}
+ if (MempoolWrapper::instance().thread_clear() != 0) {
+ LOG(ERROR) << "Failed thread clear mempool";
+ return -1;
+ }
+
BatchTasks batch(_batch_size, _batch_align);
- if (fetch_batch(batch)) {
+ if (move_task_to_batch(batch)) {
batch.merge_tasks();
- _fn(batch.in(), batch.out());
+ _fn(&batch.in(), &batch.out());
batch.notify_tasks();
}
}
@@ -219,9 +254,10 @@ int TaskExecutor::work(ThreadContext* context) {
}
template
-bool TaskManager::schedule(const InArrayT& in,
- OutArrayT& out) { // NOLINT
- TaskHandler handler = _executor.schedule(in, out);
+bool TaskManager::schedule(const void* in,
+ void* out) { // NOLINT
+ TaskHandler handler =
+ TaskExecutorVector::instance()[_model_index].schedule(in, out);
if (handler.valid()) {
_task_owned = handler;
diff --git a/core/predictor/framework/bsf.h b/core/predictor/framework/bsf.h
index 36a00c381..75cce3002 100644
--- a/core/predictor/framework/bsf.h
+++ b/core/predictor/framework/bsf.h
@@ -16,7 +16,7 @@
#include
#include
-#include
+#include
#include
#ifdef BCLOUD
@@ -29,46 +29,186 @@
#include "boost/function.hpp"
+#include "core/predictor/framework/memory.h"
+#include "paddle_inference_api.h"
+
namespace im {
namespace bsf {
static const size_t DEFAULT_BATCH_SIZE = 100;
+// InItemT is paddle::PaddleTensor
+// InVectorT std::vector
+// InVectorT means different feedvar, but not batch.
+// Batch is already inside the paddle::PaddleTensor.
+
+// size_t `rem` records how many batch have not been put in BatchTasks.
+// `rem` don`t need to be atomic, cause the operation `put` is synchronous.
+// actually, the reason is that lock have been added outside the operation
+// `put`.
+
+// size_t `index` records how many batch have been processing completed.
+// `index` need to be atomic, cause the operation 'notify' is asynchronous.
template
struct Task {
- typedef std::vector InArrayT;
- typedef std::vector OutArrayT;
+ typedef std::vector InVectorT;
+ typedef std::vector OutVectorT;
typedef InItemT InType;
typedef OutItemT OutType;
typedef Task TaskT;
+ typedef std::vector ShapeVector;
+ typedef std::vector VectorOfShapeVector;
int read_fd;
int write_fd;
-
pid_t owner_tid;
-
- const InArrayT* in;
- OutArrayT* out;
-
+ const InVectorT* inVectorT_ptr;
+ OutVectorT* outVectorT_ptr;
size_t rem;
- size_t size;
-
- size_t batch_size() { return in->size(); }
-
butil::atomic index;
Task() {
read_fd = -1;
write_fd = -1;
owner_tid = -1;
- in = NULL;
- out = NULL;
+ inVectorT_ptr = NULL;
+ outVectorT_ptr = NULL;
rem = -1;
- size = -1;
index.store(0, butil::memory_order_relaxed);
}
+
+ bool check_feedvar_valid(int feedvar_index) {
+ if (feedvar_index < 0 || inVectorT_ptr->size() <= feedvar_index) {
+ LOG(ERROR) << "feedvar doesnt exsit or feedvar_index error";
+ return 0;
+ }
+
+ if ((*inVectorT_ptr)[feedvar_index].shape.size() <= 0) {
+ LOG(ERROR) << "feedvar[" << feedvar_index << "].shape.size()<=0,error";
+ return 0;
+ }
+
+ return 1;
+ }
+
+ // Now, it simply assume that the first dimension of data is batch.
+ // so the batch is PaddleTensor.shape[0]
+
+ // If batch information is added into feedvar.prototxt.
+ // we can get the information from the feedvar.prototxt instead of assume.
+ size_t feedvar_batch_size(int feedvar_index) {
+ if (!check_feedvar_valid(feedvar_index)) {
+ return 0;
+ }
+
+ return (*inVectorT_ptr)[feedvar_index].shape[0];
+ }
+
+ size_t feedvar_element_bytesize(int feedvar_index) {
+ if (!check_feedvar_valid(feedvar_index)) {
+ return 0;
+ }
+ int dtype = (*inVectorT_ptr)[feedvar_index].dtype;
+ if (dtype == paddle::PaddleDType::INT64) {
+ return sizeof(int64_t);
+ }
+ if (dtype == paddle::PaddleDType::FLOAT32) {
+ return sizeof(float);
+ }
+ if (dtype == paddle::PaddleDType::INT32) {
+ return sizeof(int32_t);
+ }
+ if (dtype == paddle::PaddleDType::UINT8) {
+ return sizeof(char);
+ }
+ return 0;
+ }
+
+ // Now, the implementation of this function is based on assumption
+ // that shape [0] = batch_size.
+ size_t feedvar_element_num(int feedvar_index) {
+ if (!check_feedvar_valid(feedvar_index)) {
+ return 0;
+ }
+ int element_num = 1;
+ if ((*inVectorT_ptr)[feedvar_index].shape.size() == 1) {
+ // cause shape[0] is batch_size.
+ // [10,1] = [10], so if shape[1] doesn`t exist.
+ // should return 1.
+ return 1;
+ }
+ // start from shape[1], cause shape[0] = batch_size.
+ for (int i = 1; i < (*inVectorT_ptr)[feedvar_index].shape.size(); ++i) {
+ element_num *= (*inVectorT_ptr)[feedvar_index].shape[i];
+ }
+ return element_num;
+ }
+
+ size_t feedvar_bytesize(int feedvar_index) {
+ return feedvar_element_num(feedvar_index) *
+ feedvar_element_bytesize(feedvar_index);
+ }
+
+ ShapeVector feedvar_shape_nobatch(int feedvar_index) {
+ if (!check_feedvar_valid(feedvar_index)) {
+ return ShapeVector();
+ }
+ return ShapeVector{(*inVectorT_ptr)[feedvar_index].shape.begin() + 1,
+ (*inVectorT_ptr)[feedvar_index].shape.end()};
+ }
+
+ VectorOfShapeVector feedvar_shape_nobatch() {
+ VectorOfShapeVector vector_of_feedvar_shape_nobatch(inVectorT_ptr->size());
+ for (int index = 0; index < inVectorT_ptr->size(); ++index) {
+ vector_of_feedvar_shape_nobatch.push_back(feedvar_shape_nobatch(index));
+ }
+ return vector_of_feedvar_shape_nobatch;
+ }
+
+ // At present, it is considered that the batch of all feedvar is consistent.
+ // so for each feedvar, PaddleTensor.shape[0] should be the same.
+ bool check_batch_align() {
+ int batch_size_align = feedvar_batch_size(0);
+ for (int feedvar_index = 0; feedvar_index < inVectorT_ptr->size();
+ ++feedvar_index) {
+ if (feedvar_batch_size(feedvar_index) != batch_size_align) {
+ return 0;
+ }
+ }
+ /*
+ for(int fetchvar_index = 0; fetchvar_index < outVectorT_ptr->size();
+ ++fetchvar_index) {
+ if(fetchvar_batch_size(fetchvar_index) != batch_size_align) {
+ return 0;
+ }
+ }
+ */
+ return 1;
+ }
+
+ size_t batch_size() {
+ if (check_batch_align()) {
+ return feedvar_batch_size(0);
+ }
+ return 0;
+ }
};
+// `Several Task` or `part of batch in Task` can be a TaskMeta.
+// Task is the original Request from User.
+// For example, the batch of Task is 30. There are 4 Requests.
+// The batch of BatchTasks is 100, which means we can deal 100 batch 1 time.
+// TaskMeta-1:{task-1,0,30} TaskMeta-2:{task-2,0,30} TaskMeta-3:{task-3,0,30}
+// but the last Task will be divided to 2 TaskMeta.
+// TaskMeta-4:{task-4,0,10} TaskMeta-5:{task-4,10,30}.
+// TaskMeta-1 ~ TaskMeta-4 will be inside BatchTasks-1.
+// TaskMeta-5 will be inside BatchTasks-2.
+
+// TaskMeta is necessary.
+// cause we need know the the corresponding relationship between
+// `batch_out`(which is in BatchTasks) and `outVectorT_ptr`(which is in Task).
+// especially when 1 Task be divided into several TaskMeta and be put into
+// several different BatchTasks.
template
struct TaskMeta {
TaskMeta(TaskT* ptr, size_t start, size_t add)
@@ -79,6 +219,11 @@ struct TaskMeta {
size_t end;
};
+// each TaskT is already include batch in itself
+// BatchTasks need to combine several `small TaskMeta` into a new `big TaskT`.
+// The only difference between the `big TaskT` and `small TaskT` is that
+// the TaskT.inVectorT_ptr->[feedvar_index].shape[0]
+// which is actually batch_size is different.
template
class BatchTasks {
public:
@@ -91,33 +236,38 @@ class BatchTasks {
_rem_size(batch_size),
_batch_align(batch_align) {
_batch_in.clear();
+ _batch_in_offset.clear();
_batch_out.clear();
- _tasks.clear();
+ _batch_out_offset.clear();
+ _taskmeta_vector.clear();
}
~BatchTasks() {
_batch_in.clear();
+ _batch_in_offset.clear();
_batch_out.clear();
- _tasks.clear();
+ _batch_out_offset.clear();
+ _taskmeta_vector.clear();
}
// synchronized operation
+ // because Upper level callers of this function have already locked.
size_t append_task(TaskT* task) {
size_t add = std::min(task->rem, _rem_size);
if (!_batch_align) {
add = task->rem;
}
-
- TaskMetaT tm(task, task->in->size() - task->rem, add);
- _tasks.push_back(tm);
+ int start_index = task->batch_size() - task->rem;
+ TaskMetaT tm(task, start_index, add);
+ _taskmeta_vector.push_back(tm);
task->rem -= add;
_rem_size -= add;
return _rem_size;
}
- static bool check_valid(const typename TaskT::InArrayT& in,
- const typename TaskT::OutArrayT& out,
+ static bool check_valid(const typename TaskT::InVectorT& in,
+ const typename TaskT::OutVectorT& out,
bool align) {
(void)in;
(void)out;
@@ -125,40 +275,221 @@ class BatchTasks {
return true;
}
+ // this should be modified totally.
+ // maybe we don`t need to do this inside the BatchTasks.
+ // we can do the copy work outside the BatchTasks.
+ // cause maybe next time we don`t need to do the extra copy.
+ // directly copy the every Task into the Predictor.
+
+ // lod is not supported.
+ // if lod is set, we should not allow to use the bsf task.
+
+ // batch.merge_tasks() is thread-safe function
+ // cause batch is a local variable and Task is just read, not written.
void merge_tasks() {
- for (size_t ti = 0; ti < _tasks.size(); ++ti) {
- TaskMetaT& tm = _tasks[ti];
- for (size_t vi = tm.begin; vi < tm.end; ++vi) {
- _batch_in.push_back((*tm.task->in)[vi]);
- _batch_out.push_back((*tm.task->out)[vi]);
+ if (_taskmeta_vector.size() <= 0) {
+ return;
+ }
+
+ // Temporarily, the batch of each feedvar is consistent
+ // If not consistent, use feedvar_batch_size instead of task->batch_size().
+ int temp_batch = 0;
+ for (size_t ti = 0; ti < _taskmeta_vector.size(); ++ti) {
+ TaskMetaT& tm = _taskmeta_vector[ti];
+ temp_batch += tm.task->batch_size();
+ }
+ if (temp_batch > _batch_size) {
+ LOG(ERROR) << "_realNumber_batch_in >_batch_size, error.";
+ return;
+ }
+
+ int feedvar_num = _taskmeta_vector[0].task->inVectorT_ptr->size();
+ if (_batch_in_offset.size() == 0) {
+ _batch_in_offset.resize(feedvar_num, 0);
+ _realNumber_batch_in.resize(feedvar_num, temp_batch);
+ }
+
+ for (size_t ti = 0; ti < _taskmeta_vector.size(); ++ti) {
+ TaskMetaT& tm = _taskmeta_vector[ti];
+
+ for (int index = 0; index < feedvar_num; ++index) {
+ const paddle::PaddleTensor& feedVarTensor =
+ (*tm.task->inVectorT_ptr)[index];
+ int feedvar_bytesize = tm.task->feedvar_bytesize(index);
+
+ if (ti == 0) {
+ if (feedVarTensor.lod.size() > 0 && feedVarTensor.lod[0].size() > 0) {
+ LOG(ERROR) << "lod Tensor is not supported now.";
+ return;
+ }
+ // for now, we assume that every task feedvar_bytesize is the same.
+ // which means we dont support auto embedding.
+ // but for different feedvar, it is different.
+ paddle::PaddleTensor paddleTensor;
+ paddleTensor.dtype = feedVarTensor.dtype;
+ paddleTensor.name = feedVarTensor.name;
+ paddleTensor.lod = feedVarTensor.lod;
+ paddleTensor.shape = feedVarTensor.shape;
+ paddleTensor.shape[0] = _realNumber_batch_in[index];
+ paddleTensor.data.Resize(feedvar_bytesize *
+ _realNumber_batch_in[index]);
+ _batch_in.push_back(paddleTensor);
+ }
+
+ void* dst_ptr = _batch_in[index].data.data() +
+ feedvar_bytesize * _batch_in_offset[index];
+ void* source_ptr =
+ feedVarTensor.data.data() + feedvar_bytesize * tm.begin;
+ int length = feedvar_bytesize * (tm.end - tm.begin);
+ memcpy(dst_ptr, source_ptr, length);
+ _batch_in_offset[index] += length;
}
}
}
+ bool check_fetchvar_valid(int fetchvar_index) {
+ if (fetchvar_index < 0 || _batch_out.size() <= fetchvar_index) {
+ LOG(ERROR) << "fetchvar doesnt exsit or fetchvar_index error";
+ return 0;
+ }
+
+ if (_batch_out[fetchvar_index].shape.size() <= 0) {
+ LOG(ERROR) << "fetchvar[" << fetchvar_index << "].shape.size()<=0,error";
+ return 0;
+ }
+
+ return 1;
+ }
+
+ size_t fetchvar_batch_size(int fetchvar_index) {
+ if (!check_fetchvar_valid(fetchvar_index)) {
+ return 0;
+ }
+
+ return _batch_out[fetchvar_index].shape[0];
+ }
+
+ size_t fetchvar_element_bytesize(int fetchvar_index) {
+ if (!check_fetchvar_valid(fetchvar_index)) {
+ return 0;
+ }
+ int dtype = _batch_out[fetchvar_index].dtype;
+ if (dtype == paddle::PaddleDType::INT64) {
+ return sizeof(int64_t);
+ }
+ if (dtype == paddle::PaddleDType::FLOAT32) {
+ return sizeof(float);
+ }
+ if (dtype == paddle::PaddleDType::INT32) {
+ return sizeof(int32_t);
+ }
+ if (dtype == paddle::PaddleDType::UINT8) {
+ return sizeof(char);
+ }
+ return 0;
+ }
+
+ // Now, the implementation of this function is based on assumption
+ // that shape [0] = batch_size.
+ size_t fetchvar_element_num(int fetchvar_index) {
+ if (!check_fetchvar_valid(fetchvar_index)) {
+ return 0;
+ }
+ int element_num = 1;
+ if (_batch_out[fetchvar_index].shape.size() == 1) {
+ // cause shape[0] is batch_size.
+ return 1;
+ }
+ // start from shape[1], cause shape[0] = batch_size.
+ for (int i = 1; i < _batch_out[fetchvar_index].shape.size(); ++i) {
+ element_num *= _batch_out[fetchvar_index].shape[i];
+ }
+ return element_num;
+ }
+
+ size_t fetchvar_bytesize(int fetchvar_index) {
+ return fetchvar_element_num(fetchvar_index) *
+ fetchvar_element_bytesize(fetchvar_index);
+ }
+
+ bool check_fetchvar_batch_align() {
+ int batch_size_align = fetchvar_batch_size(0);
+
+ for (int fetchvar_index = 0; fetchvar_index < _batch_out.size();
+ ++fetchvar_index) {
+ if (fetchvar_batch_size(fetchvar_index) != batch_size_align) {
+ return 0;
+ }
+ }
+
+ return 1;
+ }
+
+ size_t fetchvar_batch_size() {
+ if (check_fetchvar_batch_align()) {
+ return fetchvar_batch_size(0);
+ }
+ return 0;
+ }
+
void notify_tasks() {
- if (_batch_out.size() != _batch_in.size()) {
- LOG(ERROR) << "batch size not consistency: " << _batch_out.size()
- << " != " << _batch_in.size();
+ if (_taskmeta_vector.size() <= 0) {
+ LOG(ERROR) << "_taskmeta_vector.size() <=0, error.";
+ return;
+ }
+ if (_realNumber_batch_in[0] != fetchvar_batch_size()) {
+ LOG(ERROR) << "_batch_out`s batch != _batch_in`s batch, error.";
return;
}
- for (size_t ti = 0, bi = 0; ti < _tasks.size(); ++ti) {
- TaskT* task = _tasks[ti].task;
- size_t begin = _tasks[ti].begin;
- size_t end = _tasks[ti].end;
+ int fetchvar_num = _batch_out.size();
+ if (_batch_out_offset.size() == 0) {
+ _batch_out_offset.resize(fetchvar_num, 0);
+ }
+
+ for (size_t ti = 0; ti < _taskmeta_vector.size(); ++ti) {
+ TaskT* task = _taskmeta_vector[ti].task;
+ size_t begin = _taskmeta_vector[ti].begin;
+ size_t end = _taskmeta_vector[ti].end;
size_t add = end - begin;
- for (size_t oi = begin; oi < end; ++oi, ++bi) {
- if (bi >= _batch_in.size()) {
- LOG(ERROR) << "batch index overflow: " << bi << " > "
- << _batch_in.size();
+ for (int index = 0; index < fetchvar_num; ++index) {
+ // the task->outVectorT_ptr is null before core->run().
+ // first time we should copy from _batch_out
+ // so we need init.
+ int fetchvar_bytesize_index = fetchvar_bytesize(index);
+ if (task->outVectorT_ptr->size() <= index) {
+ paddle::PaddleTensor tensor_out;
+ tensor_out.name = _batch_out[index].name;
+ tensor_out.dtype = paddle::PaddleDType(_batch_out[index].dtype);
+ tensor_out.shape = _batch_out[index].shape;
+ tensor_out.shape[0] = task->batch_size();
+ tensor_out.lod = _batch_out[index].lod;
+ // resize all batch memory at one time
+ size_t databuf_size = task->batch_size() * fetchvar_bytesize_index;
+ tensor_out.data.Resize(databuf_size);
+ task->outVectorT_ptr->push_back(tensor_out);
+ }
+
+ paddle::PaddleTensor& fetchVarTensor = (*task->outVectorT_ptr)[index];
+
+ void* dst_ptr =
+ fetchVarTensor.data.data() + fetchvar_bytesize_index * begin;
+ int length = fetchvar_bytesize_index * add;
+ if (_batch_out_offset[index] + length >
+ fetchvar_batch_size() * fetchvar_bytesize(index)) {
+ LOG(ERROR) << "_batch_out is less than taskmeta, error.";
return;
}
- (*task->out)[oi] = _batch_out[bi];
+ void* source_ptr =
+ _batch_out[index].data.data() + _batch_out_offset[index];
+
+ memcpy(dst_ptr, source_ptr, length);
+ _batch_out_offset[index] += length;
}
size_t index = task->index.fetch_add(add);
- if ((index + add) >= task->in->size()) {
+ if ((index + add) >= task->batch_size()) {
char c = 0;
while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
}
@@ -167,22 +498,33 @@ class BatchTasks {
}
}
- const typename TaskT::InArrayT& in() const { return _batch_in; }
+ const typename TaskT::InVectorT& in() const { return _batch_in; }
- typename TaskT::OutArrayT& out() { return _batch_out; }
+ typename TaskT::OutVectorT& out() { return _batch_out; }
- size_t task_size() { return _tasks.size(); }
+ size_t task_size() { return _taskmeta_vector.size(); }
private:
- std::vector _tasks;
- typename TaskT::InArrayT _batch_in;
- typename TaskT::OutArrayT _batch_out;
+ std::vector _taskmeta_vector;
+ typename TaskT::InVectorT _batch_in;
+ std::vector _batch_in_offset;
+ std::vector _realNumber_batch_in;
+ typename TaskT::OutVectorT _batch_out;
+ std::vector _batch_out_offset;
+ std::vector _realNumber_batch_out;
size_t _rem_size;
size_t _batch_size;
bool _batch_align;
};
// BSF task handle
+// TaskHandler is the handle of Task.
+// `read_fd` is used for receive signal in brpc Thread.
+// 'write_fd' is used for write signal in bsf Thread.
+// when TaskMeta is done, bsf Thread will write to 'write_fd'.
+// brpc Thread is keeping reading 'read_fd' in a while loop.
+// brpc Thread will receive signal when TaskMeta is done.
+// so `read_fd` and 'write_fd' is used for communicate in different Thread.
template
struct TaskHandler {
int read_fd;
@@ -205,12 +547,11 @@ struct TaskHandler {
}
};
+// TaskExecutor is a Thread pool.
template
class TaskExecutor;
-template
-class TaskManager;
-
+// ThreadContext is used for start a bsf Thread.
template
struct ThreadContext {
TaskExecutor* executor;
@@ -231,14 +572,24 @@ struct ThreadContext {
}
};
+// TaskExecutor is a Thread pool.
+// Each Model corresponding to a Model.
+// TaskT is actually a Request preprocessed by ReaderOp.
+// TaskT will be divided as TaskMeta which will be
+// put into _task_queue in brpc-Thread by schedule().
+// TaskHander will be returned to brpc-Thread.
+// start() function will create `thread_num` bsf Threads.
+// every bsf Thread check the _task_queue and take TaskMeta from it.
+// when a Task`s all TaskMeta is done, TaskHander will be noticed.
template
class TaskExecutor {
public:
typedef typename TaskT::InType InType;
typedef typename TaskT::OutType OutType;
- typedef typename TaskT::InArrayT InArrayT;
- typedef typename TaskT::OutArrayT OutArrayT;
+ typedef typename TaskT::InVectorT InVectorT;
+ typedef typename TaskT::OutVectorT OutVectorT;
typedef std::vector TaskArrayT;
+ typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
TaskExecutor()
: _stop(false),
@@ -258,9 +609,11 @@ class TaskExecutor {
THREAD_COND_DESTROY(&_cond);
}
- static TaskExecutor* instance() {
- static TaskExecutor singleton;
- return &singleton;
+ // cause vector.resize will use copy or move construct.
+ TaskExecutor(TaskExecutor&& other) noexcept {
+ if (this != &other) {
+ TaskExecutor();
+ }
}
void set_batch_size(size_t batch_size) { _batch_size = batch_size; }
@@ -277,8 +630,7 @@ class TaskExecutor {
_thread_reset_fn = reset_fn;
}
- void set_thread_callback_fn(
- boost::function cb) {
+ void set_thread_callback_fn(boost::function cb) {
_fn = cb;
}
@@ -287,15 +639,21 @@ class TaskExecutor {
static void* thread_entry(void* args);
- private:
- TaskExecutor(TaskExecutor const& other);
- TaskExecutor* operator=(TaskExecutor const& other);
-
int work(ThreadContext* context);
- TaskHandler schedule(const InArrayT&, OutArrayT&);
+ TaskHandler schedule(const void*, void*);
- bool fetch_batch(BatchTasks& batch); // NOLINT
+ bool move_task_to_batch(BatchTasks& batch); // NOLINT
+
+ private:
+ TaskExecutor(TaskExecutor const& other) = delete;
+
+ TaskExecutor& operator=(TaskExecutor const& other) = delete;
+ /*
+ TaskExecutor(TaskExecutor && other) = delete;
+
+ TaskExecutor& operator=(TaskExecutor && other) = delete;
+ */
bool _stop;
@@ -303,43 +661,76 @@ class TaskExecutor {
THREAD_MUTEX_T _mut;
THREAD_COND_T _cond;
- std::deque _task_queue;
+ std::list _task_queue;
boost::function _thread_init_fn;
boost::function _thread_reset_fn;
void** _user_thread_contexts;
std::vector*> _thread_contexts;
- friend class TaskManager;
size_t _batch_size;
bool _batch_align;
- boost::function _fn;
+ boost::function _fn;
};
+// TaskExecutorVector is a SingleTon class.
+// Each Model corresponding to a TaskExecutor.
+// So we need several TaskExecutor when there are more than 1 Model.
+template
+class TaskExecutorVector {
+ public:
+ static TaskExecutorVector& instance() {
+ static TaskExecutorVector singleton;
+ return singleton;
+ }
+
+ void resize(int size) { _vector_executor.resize(size); }
+
+ TaskExecutor& operator[](int index) {
+ if (_vector_executor.size() <= index || index <= -1) {
+ LOG(ERROR) << "_vector_executor.size() <= index or <= -1";
+ throw "_vector_executor.size() <= index or <= -1";
+ }
+ return _vector_executor[index];
+ }
+
+ private:
+ TaskExecutorVector() = default;
+ TaskExecutorVector(const TaskExecutorVector& other) = delete;
+ TaskExecutorVector& operator=(const TaskExecutorVector& other) =
+ delete;
+ TaskExecutorVector(TaskExecutorVector&& other) = delete;
+ TaskExecutorVector& operator=(TaskExecutorVector&& other) = delete;
+ std::vector> _vector_executor;
+};
+
+// TaskManager is actually a wrapper of Request in bsf.
+// TaskManager`s schedule() change Request to be TaskT.
+// and divided TaskT into several TaskMeta to put into the TaskExecutor`s
+// task_queue.
+// wait() is a while loop to receive signal when a whole Task is done.
template
class TaskManager {
public:
typedef Task TaskT;
- typedef typename TaskT::InArrayT InArrayT;
- typedef typename TaskT::OutArrayT OutArrayT;
-
- explicit TaskManager(TaskExecutor& exe, size_t batch_size) // NOLINT
- : _executor(exe) {}
+ typedef typename TaskT::InVectorT InVectorT;
+ typedef typename TaskT::OutVectorT OutVectorT;
- TaskManager() : _executor(*TaskExecutor::instance()) {}
+ explicit TaskManager(uint32_t index) // NOLINT
+ : _model_index(index) {}
~TaskManager() { wait(); }
- bool schedule(const InArrayT& in, OutArrayT& out); // NOLINT
+ bool schedule(const void* in, void* out); // NOLINT
void wait();
inline void clear() { wait(); }
private:
- TaskExecutor& _executor;
TaskHandler _task_owned;
+ uint32_t _model_index;
}; // class TaskManager
class AutoMutex {
@@ -357,5 +748,5 @@ class AutoMutex {
} // namespace bsf
} // namespace im
-#include "core/predictor/framework/bsf-inl-tensor.h"
+// #include "core/predictor/framework/bsf-inl-tensor.h"
#include "core/predictor/framework/bsf-inl.h"
diff --git a/core/predictor/framework/infer.cpp b/core/predictor/framework/infer.cpp
index e11861426..fd80ed639 100644
--- a/core/predictor/framework/infer.cpp
+++ b/core/predictor/framework/infer.cpp
@@ -56,15 +56,23 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
}
// init bsf framework
- im::bsf::TaskExecutor::instance()->set_thread_init_fn(
- boost::bind(&InferEngine::thrd_initialize_impl, this));
- im::bsf::TaskExecutor::instance()->set_thread_reset_fn(
- boost::bind(&InferEngine::thrd_clear_impl, this));
- im::bsf::TaskExecutor::instance()->set_thread_callback_fn(
- boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
- im::bsf::TaskExecutor::instance()->set_batch_size(_infer_batch_size);
- im::bsf::TaskExecutor::instance()->set_batch_align(_infer_batch_align);
- if (im::bsf::TaskExecutor::instance()->start(_infer_thread_num) != 0) {
+ im::bsf::TaskExecutorVector::instance()[_model_index]
+ .set_thread_init_fn(
+ boost::bind(&InferEngine::thrd_initialize_impl, this));
+ im::bsf::TaskExecutorVector::instance()[_model_index]
+ .set_thread_init_fn(
+ boost::bind(&InferEngine::thrd_initialize_impl, this));
+ im::bsf::TaskExecutorVector::instance()[_model_index]
+ .set_thread_reset_fn(boost::bind(&InferEngine::thrd_clear_impl, this));
+ im::bsf::TaskExecutorVector::instance()[_model_index]
+ .set_thread_callback_fn(
+ boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
+ im::bsf::TaskExecutorVector::instance()[_model_index].set_batch_size(
+ _infer_batch_size);
+ im::bsf::TaskExecutorVector::instance()[_model_index].set_batch_align(
+ _infer_batch_align);
+ if (im::bsf::TaskExecutorVector::instance()[_model_index].start(
+ _infer_thread_num) != 0) {
LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
return -1;
}
@@ -75,6 +83,11 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
return 0;
}
+// Multiple threads will enter this method of the same object
+// One Model corresponds to One ReloadableInferEngine object.
+// ReloadableInferEngine object is Process object.
+// One ReloadableInferEngine object can have several ModelData
+// ModelData is Thread object.
int ReloadableInferEngine::infer(const void* in,
void* out,
uint32_t batch_size) {
@@ -82,9 +95,10 @@ int ReloadableInferEngine::infer(const void* in,
return infer_impl(in, out, batch_size);
}
- im::bsf::TaskManager task_manager;
- task_manager.schedule(*(reinterpret_cast(in)),
- *(reinterpret_cast(out)));
+ im::bsf::TaskManager task_manager(
+ _model_index);
+
+ task_manager.schedule(in, out);
task_manager.wait();
return 0;
}
@@ -110,7 +124,7 @@ int ReloadableInferEngine::proc_finalize() {
}
if (_infer_thread_num > 0) {
- im::bsf::TaskExecutor::instance()->stop();
+ im::bsf::TaskExecutorVector::instance()[_model_index].stop();
}
return 0;
}
@@ -191,6 +205,7 @@ int VersionedInferEngine::proc_initialize(const configure::EngineDesc& conf,
std::string engine_type = conf.type();
InferEngine* engine =
StaticInferFactory::instance().generate_object(engine_type);
+ engine->set_model_index(_model_index);
if (!engine) {
LOG(ERROR) << "Failed generate engine with type:" << engine_type;
return -1;
@@ -362,8 +377,8 @@ int VersionedInferEngine::infer_impl(const void* in,
uint32_t batch_size) {
return -1;
}
-int VersionedInferEngine::task_infer_impl(const BatchTensor& in,
- BatchTensor& out) { // NOLINT
+int VersionedInferEngine::task_infer_impl(const void* in,
+ void* out) { // NOLINT
return -1;
}
@@ -373,12 +388,14 @@ int InferManager::proc_initialize(const char* path, const char* file) {
LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
return -1;
}
- size_t engine_num = model_toolkit_conf.engines_size();
- for (size_t ei = 0; ei < engine_num; ++ei) {
+ uint32_t engine_num = model_toolkit_conf.engines_size();
+ im::bsf::TaskExecutorVector::instance().resize(engine_num);
+ for (uint32_t ei = 0; ei < engine_num; ++ei) {
LOG(INFO) << "model_toolkit_conf.engines(" << ei
<< ").name: " << model_toolkit_conf.engines(ei).name();
std::string engine_name = model_toolkit_conf.engines(ei).name();
VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
+ engine->set_model_index(ei);
if (!engine) {
LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
return -1;
diff --git a/core/predictor/framework/infer.h b/core/predictor/framework/infer.h
old mode 100755
new mode 100644
index 6113dc8ef..3cdef9dc9
--- a/core/predictor/framework/infer.h
+++ b/core/predictor/framework/infer.h
@@ -17,6 +17,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -25,6 +26,7 @@
#include "core/predictor/framework/bsf.h"
#include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h"
+#include "core/predictor/framework/memory.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
namespace paddle_serving {
@@ -71,7 +73,7 @@ class InferEngine {
virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
return infer_impl(in, out, batch_size);
}
-
+ virtual void set_model_index(uint32_t index) { _model_index = index; }
virtual int reload() = 0;
virtual uint64_t version() const = 0;
@@ -86,12 +88,13 @@ class InferEngine {
virtual int infer_impl(const void* in,
void* out,
uint32_t batch_size = -1) = 0;
- virtual int task_infer_impl(const BatchTensor& in,
- BatchTensor& out) = 0; // NOLINT
+ virtual int task_infer_impl(const void* in, void* out) = 0; // NOLINT
+ protected:
+ uint32_t _model_index;
// end: framework inner call
};
-
+typedef im::bsf::Task TaskT;
class ReloadableInferEngine : public InferEngine {
public:
virtual ~ReloadableInferEngine() {}
@@ -104,7 +107,6 @@ class ReloadableInferEngine : public InferEngine {
};
virtual int load(const configure::EngineDesc& conf) = 0;
- typedef im::bsf::Task TaskT;
int proc_initialize_impl(const configure::EngineDesc& conf, bool version);
@@ -179,6 +181,8 @@ struct ModelData {
delete cores[1];
}
+ void* get() { return cores[current_idx]->get(); }
+
EngineCore* cores[2];
uint32_t current_idx;
};
@@ -191,14 +195,20 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
int proc_initialize(const configure::EngineDesc& conf, bool version) {
THREAD_KEY_CREATE(&_skey, NULL);
THREAD_MUTEX_INIT(&_mutex, NULL);
+ gpu_index = 0;
return ReloadableInferEngine::proc_initialize(conf, version);
}
+ // 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
+ // 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
+ // 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
+ // 每个模型只有一个CloneDBReloadableInferEngine对象。
+ // 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
virtual int load(const configure::EngineDesc& conf) {
if (_reload_vec.empty()) {
return 0;
}
-
+ gpu_index = 0;
for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
if (load_data(_reload_vec[ti], conf) != 0) {
LOG(ERROR) << "Failed reload engine model: " << ti;
@@ -210,7 +220,8 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
return 0;
}
- int load_data(ModelData* md, const configure::EngineDesc& conf) {
+ virtual int load_data(ModelData* md,
+ const configure::EngineDesc& conf) {
uint32_t next_idx = (md->current_idx + 1) % 2;
if (md->cores[next_idx]) {
delete md->cores[next_idx];
@@ -219,28 +230,29 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
md->cores[next_idx] = new (std::nothrow) EngineCore;
// params.dump();
- if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) {
+ size_t gpu_ids_num = conf.gpu_ids_size();
+ im::bsf::AutoMutex lock(_mutex);
+ int gpu_id = -1;
+ if (gpu_ids_num > 0) {
+ gpu_id = conf.gpu_ids(gpu_index % gpu_ids_num);
+ }
+ if (!md->cores[next_idx] ||
+ md->cores[next_idx]->create(conf, gpu_id) != 0) {
LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1;
}
+ gpu_index++;
md->current_idx = next_idx;
return 0;
}
virtual int thrd_initialize_impl() {
- // memory pool to be inited in non-serving-threads
- if (MempoolWrapper::instance().thread_initialize() != 0) {
- LOG(ERROR) << "Failed thread initialize mempool";
- return -1;
- }
-
ModelData* md = new (std::nothrow) ModelData;
if (!md || load_data(md, _conf) != 0) {
LOG(ERROR) << "Failed create thread data from " << _conf.model_dir();
return -1;
}
- LOG(ERROR) << "THREAD_SETSPECIFIC _skey = md";
THREAD_SETSPECIFIC(_skey, md);
im::bsf::AutoMutex lock(_mutex);
_reload_vec.push_back(md);
@@ -248,11 +260,33 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
}
int thrd_clear_impl() {
- // for non-serving-threads
- if (MempoolWrapper::instance().thread_clear() != 0) {
- LOG(ERROR) << "Failed thread clear mempool";
- return -1;
- }
+ // actually, there are 2 kinds of multi-thread.
+ // 1. brpc thread 2. bsf Task thread
+ // each request is in 1-single brpc thread.
+ // IF (bsf Task thread is not used)
+ // every single brpc thread corresponds to all the DBReloadableInferEngines.
+ // each request runs all models in 1-single brpc thread.
+ // every single brpc thread will create or clone N predictor.
+ // N = the number of Model.
+ // so if there are 2 models, and --thread 10.
+ // each brpc thread will create predictor of Model-1 and Model-2.
+ // there are totally 10 predictors of Model-1 and 10 predictors of Model-2
+ // cause there are 10 brpc threads.
+
+ // IF bsf Task thread is used。
+ // there will be a ThreadPool called bsf TaskExecutor.
+ // TaskExecutorVector is the vector of TaskExecutor.
+ // the number of TaskExecutor equals to the number of Model.
+ // 1 TaskExecutor corresponding to 1 Model.
+ // 1 TaskExecutor have N bsf threads.
+ // 1 bsf thread corresponds to 1 predictor of
+ // the Model corresponding to the TaskExecutor.
+ // brpc thread only put the data into the task_queue(which is in
+ // TaskExecutor)
+ // EngineCore->infer() is running in bsf Task thread.
+
+ // MempoolWrapper::instance() is actually a Thread-Local Mempool.
+ // so it belongs to a single Thread.
return 0;
}
@@ -278,6 +312,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
THREAD_KEY_T _skey;
THREAD_MUTEX_T _mutex;
std::vector*> _reload_vec;
+ int gpu_index = 0;
};
// 多个EngineCore共用同一份模型数据
@@ -287,88 +322,72 @@ class CloneDBReloadableInferEngine
public:
virtual ~CloneDBReloadableInferEngine() {}
- virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
- _pd = new (std::nothrow) ModelData;
- if (!_pd) {
- LOG(ERROR) << "Failed to allocate for ProcData";
- return -1;
- }
- return DBReloadableInferEngine::proc_initialize(conf, version);
- }
+ // 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
+ // 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
+ // 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
+ // 每个模型只有一个CloneDBReloadableInferEngine对象。
+ // 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
- virtual int load(const configure::EngineDesc& conf) {
- // 加载进程级模型数据
- if (!_pd ||
- DBReloadableInferEngine::load_data(_pd, conf) != 0) {
- LOG(ERROR) << "Failed to create common model from [" << conf.model_dir()
- << "].";
- return -1;
+ virtual int load_data(ModelData* md,
+ const configure::EngineDesc& conf) {
+ uint32_t next_idx = (md->current_idx + 1) % 2;
+ if (md->cores[next_idx]) {
+ delete md->cores[next_idx];
}
- LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
- << "], path[" << conf.model_dir() << "].";
+ md->cores[next_idx] = new (std::nothrow) EngineCore;
- if (DBReloadableInferEngine::_reload_vec.empty()) {
- return 0;
+ // params.dump();
+ size_t gpu_ids_num = conf.gpu_ids_size();
+ im::bsf::AutoMutex lock(DBReloadableInferEngine::_mutex);
+ int gpu_id = -1;
+ if (gpu_ids_num > 0) {
+ gpu_id = conf.gpu_ids(DBReloadableInferEngine::gpu_index %
+ gpu_ids_num);
}
-
- for (uint32_t ti = 0;
- ti < DBReloadableInferEngine::_reload_vec.size();
- ++ti) {
- if (load_data(DBReloadableInferEngine::_reload_vec[ti],
- _pd->cores[_pd->current_idx]) != 0) {
- LOG(ERROR) << "Failed reload engine model: " << ti;
+ // gpu_index will be set to be 0, when load() or proc_initial() is called.
+ // gpu_index < gpu_ids_num, means there are predictors still not create
+ // on some GPU card.
+ // so we need to create the predictor.
+ // gpu_index >= gpu_ids_num, means each GPU card has already create one.
+ // so we need to clone the predictor.
+ if (DBReloadableInferEngine::gpu_index < gpu_ids_num) {
+ if (!md->cores[next_idx] ||
+ md->cores[next_idx]->create(conf, gpu_id) != 0) {
+ LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1;
}
+ DBReloadableInferEngine::gpu_index++;
+ md->current_idx = next_idx;
+ if (_cloneTemplate.size() <
+ DBReloadableInferEngine::gpu_index) {
+ _cloneTemplate.push_back(md);
+ } else {
+ _cloneTemplate[DBReloadableInferEngine::gpu_index - 1] = md;
+ }
+ } else {
+ // when gpu_id = -1, means we use cpu, but the index should be 0.
+ // _cloneTemplate[-1] will occur error.
+ // actually, when gpu_id = -1, there is only 1 predictor in
+ // _cloneTemplate.
+ // so the index should always be 0 when gpu_id = -1.
+ if (gpu_id == -1) gpu_id = 0;
+ if (!md->cores[next_idx] ||
+ md->cores[next_idx]->clone(_cloneTemplate[gpu_id]->get()) != 0) {
+ LOG(ERROR) << "Failed clone model from core";
+ return -1;
+ }
+ DBReloadableInferEngine::gpu_index++;
+ md->current_idx = next_idx;
+ LOG(WARNING) << "core clone model succ, cur_idx[" << md->current_idx
+ << "].";
}
- LOG(WARNING) << "Succ load clone model, path[" << conf.model_dir() << "]";
- return 0;
- }
-
- // 加载线程级对象,多个线程级对象共用pd_core的模型数据
- int load_data(ModelData* td, EngineCore* pd_core) {
- uint32_t next_idx = (td->current_idx + 1) % 2;
- if (td->cores[next_idx]) {
- delete td->cores[next_idx];
- }
-
- td->cores[next_idx] = new (std::nothrow) EngineCore;
- if (!td->cores[next_idx] ||
- td->cores[next_idx]->clone(pd_core->get()) != 0) {
- LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
- << next_idx << "]";
- return -1;
- }
- td->current_idx = next_idx;
- LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
- << "] clone model from pd_core[" << pd_core
- << "] succ, cur_idx[" << td->current_idx << "].";
- return 0;
- }
-
- virtual int thrd_initialize_impl() {
- // memory pool to be inited in non-serving-threads
- if (MempoolWrapper::instance().thread_initialize() != 0) {
- LOG(ERROR) << "Failed thread initialize mempool";
- return -1;
- }
-
- ModelData* md = new (std::nothrow) ModelData;
- if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
- LOG(ERROR) << "Failed clone thread data, origin_core["
- << _pd->cores[_pd->current_idx] << "].";
- return -1;
- }
-
- THREAD_SETSPECIFIC(DBReloadableInferEngine::_skey, md);
- im::bsf::AutoMutex lock(DBReloadableInferEngine::_mutex);
- DBReloadableInferEngine::_reload_vec.push_back(md);
return 0;
}
protected:
- ModelData*
- _pd; // 进程级EngineCore,多个线程级EngineCore共用该对象的模型数据
+ // 模板EngineCore,如果已创建,则多个线程级EngineCore共用该对象的模型数据
+ std::vector*> _cloneTemplate;
};
template
@@ -505,8 +524,8 @@ class FluidInferEngine : public CloneDBReloadableInferEngine {
return 0;
}
- int task_infer_impl(const BatchTensor& in, BatchTensor& out) { // NOLINT
- return infer_impl(&in, &out);
+ int task_infer_impl(const void* in, void* out) { // NOLINT
+ return infer_impl(in, out);
}
};
@@ -559,7 +578,7 @@ class VersionedInferEngine : public InferEngine {
int infer_impl(const void* in, void* out, uint32_t batch_size = -1);
- int task_infer_impl(const BatchTensor& in, BatchTensor& out);
+ int task_infer_impl(const void* in, void* out);
private:
boost::unordered_map _versions;
diff --git a/core/predictor/framework/server.cpp b/core/predictor/framework/server.cpp
old mode 100644
new mode 100755
index 25b407950..8ced6f1e9
--- a/core/predictor/framework/server.cpp
+++ b/core/predictor/framework/server.cpp
@@ -91,6 +91,7 @@ int ServerManager::start_and_wait() {
}
}
+ // rpc multi-thread start from here.
if (_server.Start(FLAGS_port, &_options) != 0) {
LOG(ERROR) << "Failed to start Paddle Inference Server";
return -1;
diff --git a/core/predictor/mempool/mempool.cpp b/core/predictor/mempool/mempool.cpp
old mode 100644
new mode 100755
index 88936687e..0deab0226
--- a/core/predictor/mempool/mempool.cpp
+++ b/core/predictor/mempool/mempool.cpp
@@ -24,7 +24,7 @@ namespace fugue {
namespace memory {
void Region::init() {
- _big_mem_capacity = 64 * 1024 * 1024; // 64MB
+ _big_mem_capacity = 128 * 1024 * 1024; // 128MB
_big_mem_start = new char[_big_mem_capacity];
}
diff --git a/core/predictor/mempool/mempool.h b/core/predictor/mempool/mempool.h
index a10e8f97a..a4143d4b5 100644
--- a/core/predictor/mempool/mempool.h
+++ b/core/predictor/mempool/mempool.h
@@ -129,7 +129,7 @@ class FreeList {
to get the class Pointer
for example
T is the member of class Node, T data, 'data' is the name.
- T* value is the member(pointer type) class Node
+ T* value is the member(pointer type) of class Node
so we can get the Node* by calling container_of(value, Node, data)
*/
Node* node = container_of(value, Node, data);
@@ -261,7 +261,11 @@ struct BlockReference {
// because BlockFreeList is a threal-safe Singleton.
// so we don`t release Block, it is global memory.
-// total number is 32*1024
+// total number is 256*1024.
+// the MAX_BLOCK_COUNT of Region(one thread one Region) is 1024.
+// so BlockFreeList allow 256 Region(means 256 thread).
+// the memory used by BlockFreeListType is sizeof(void*)*256*1024.
+// Block(2MB) memory is created only when get() is called.
class BlockFreeList {
public:
static const int MAX_BLOCK_COUNT = 256 * 1024;
@@ -341,9 +345,10 @@ class Region {
2 * 1024 *
1024; // 2MB,means when you need less than 2M, get memory from Block.
- // 64MB,means when you need less than 64MB, get memory from BigMemory instead
+ // 128MB,means when you need less than 128MB, get memory from BigMemory
+ // instead
// of BigNode
- static const int BIGNODE_MEM_THRESHOLD = (64 * 1024 * 1024 + 1);
+ static const int BIGNODE_MEM_THRESHOLD = (128 * 1024 * 1024 + 1);
static const int COUNTER_SIZE =
BIGNODE_MEM_THRESHOLD / BIG_MEM_THRESHOLD + 1; // this is not used
@@ -374,7 +379,8 @@ class Mempool {
void* malloc(size_t size) {
size = _align(size);
// It does not enter the if statement the first time.
- // Because the block has not been used up, it will enter.
+ // The if statement may enter after the block is created.
+ // If the block has not been used up, it will enter.
if (size <= _free_size) {
void* p = _free_cursor;
_free_size -= size;
@@ -392,7 +398,7 @@ class Mempool {
return;
}
- // memory in Block,update the pointer.
+ // memory in _block,update the pointer.
if (_free_cursor - size == static_cast(p)) {
// for example, you need to release -(8+1)bytes
// you can only release -8bytes,cause -(8+2)byte is used by other.
@@ -424,9 +430,8 @@ class Mempool {
}
// 可能返回的是单独Region中malloc的内存。
- // 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处,old_size
- // =
- // 0.5M
+ // 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处
+ // old_size = 0.5M
// 此时,_free_size = 0.3M,new_size<2M,但是required = 1-0.5 >0.3
// 分配出来的就是Block,但是该Block没有并很完美的利用完全。
void* p = this->malloc_from_region(new_size);
diff --git a/core/predictor/src/pdserving.cpp b/core/predictor/src/pdserving.cpp
old mode 100755
new mode 100644
index e88d9b3b2..6fbf01c8b
--- a/core/predictor/src/pdserving.cpp
+++ b/core/predictor/src/pdserving.cpp
@@ -68,13 +68,14 @@ static bvar::PassiveStatus s_predictor_revision(
DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path");
DECLARE_string(flagfile);
-
+/*
namespace bthread {
extern pthread_mutex_t g_task_control_mutex;
}
pthread_mutex_t g_worker_start_fn_mutex = PTHREAD_MUTEX_INITIALIZER;
-
+*/
void pthread_worker_start_fn() {
+ /*
while (pthread_mutex_lock(&g_worker_start_fn_mutex) != 0) {
}
@@ -83,15 +84,18 @@ void pthread_worker_start_fn() {
if (lock_status == EBUSY || lock_status == EAGAIN) {
pthread_mutex_unlock(&bthread::g_task_control_mutex);
}
+ */
Resource::instance().thread_initialize();
// Try to avoid deadlock in bthread
+ /*
if (lock_status == EBUSY || lock_status == EAGAIN) {
while (pthread_mutex_lock(&bthread::g_task_control_mutex) != 0) {
}
}
pthread_mutex_unlock(&g_worker_start_fn_mutex);
+ */
}
static void g_change_server_port() {
@@ -126,7 +130,7 @@ int main(int argc, char** argv) {
return 0;
}
- //google::ParseCommandLineFlags(&argc, &argv, true);
+ // google::ParseCommandLineFlags(&argc, &argv, true);
g_change_server_port();
@@ -202,7 +206,7 @@ int main(int argc, char** argv) {
}
VLOG(2) << "Succ call pthread worker start function";
- //this is not used by any code segment,which can be cancelled.
+ // this is not used by any code segment,which can be cancelled.
if (Resource::instance().general_model_initialize(FLAGS_resource_path,
FLAGS_resource_file) != 0) {
LOG(ERROR) << "Failed to initialize general model conf: "
diff --git a/paddle_inference/paddle/include/paddle_engine.h b/paddle_inference/paddle/include/paddle_engine.h
index 262a0378b..d2027ed28 100755
--- a/paddle_inference/paddle/include/paddle_engine.h
+++ b/paddle_inference/paddle/include/paddle_engine.h
@@ -19,6 +19,7 @@
#include