shape check fix (PaddlePaddle#39)
xcpher authored Jul 6, 2022
1 parent cdd3beb commit 7d7fb8e
Showing 2 changed files with 24 additions and 87 deletions.
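For orientation: this commit deletes a commented-out, ad-hoc shape check from operator.cc and lands a working, gated version in ps_gpu_worker.cc. Below is a minimal sketch of the underlying pattern — snapshot the output dims that InferShape produced before the op runs, re-read them afterwards, and report the op type on any mismatch. SketchOp and the DDim alias are hypothetical stand-ins for illustration only, not Paddle's real RuntimeInferShapeContext API:

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

using DDim = std::vector<std::int64_t>;  // stand-in for paddle::framework::DDim

struct SketchOp {
  std::string type;
  std::vector<DDim> output_dims;  // what InferShape decided
  void Run() { /* kernel runs here; it must not resize outputs */ }
};

// Returns true when Run() left the inferred output shapes untouched.
bool RunAndShapeCheck(SketchOp& op) {
  const std::vector<DDim> pre_dims = op.output_dims;  // snapshot before Run()
  op.Run();
  const std::vector<DDim>& after_dims = op.output_dims;
  if (pre_dims.size() != after_dims.size()) return false;
  for (size_t i = 0; i < pre_dims.size(); ++i) {
    if (pre_dims[i] != after_dims[i]) {
      std::fprintf(stderr, "dims error, op name: %s\n", op.type.c_str());
      return false;
    }
  }
  return true;
}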
61 changes: 0 additions & 61 deletions paddle/fluid/framework/operator.cc
@@ -1456,18 +1456,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);


-// infershape check
-// RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
-// std::vector<std::vector<DDim>> pre_dims;
-// std::vector<std::vector<LoD>> pre_lod;
-// auto outnames = Outputs();
-// for (auto& var_name_item : outnames) {
-// pre_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
-// pre_lod.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
-// }

if (run_phi_kernel_) {
phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext
@@ -1480,55 +1468,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
}

-// if (all_kernels_must_compute_runtime_shape_) {
-// std::vector<std::vector<DDim>> after_dims;
-// std::vector<std::vector<LoD>> after_lod;
-// for (auto& var_name_item : outnames) {
-// after_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
-// after_lod.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
-// }
-// if (pre_dims.size() != after_dims.size()) {
-// CHECK(false) << "dims error: " << Info().Proto().type();
-// }
-// for (size_t i = 0; i < pre_dims.size(); i++) {
-// if (pre_dims[i].size() != after_dims[i].size()) {
-// CHECK(false) << "dims error: " << Info().Proto().type();
-// }
-// for (size_t j = 0; j < pre_dims[i].size(); j++) {
-// if (pre_dims[i][j] != after_dims[i][j]) {
-// CHECK(false) << "dims error: " << Info().Proto().type();
-// }
-// }
-// }
-// if (pre_lod.size() != after_lod.size()) {
-// CHECK(false) << "lods error: " << Info().Proto().type();
-// }
-// for (size_t i = 0; i < pre_lod.size(); i++) {
-// if (pre_lod[i].size() != after_lod[i].size()) {
-// CHECK(false) << "lods error: " << Info().Proto().type();
-// }
-// for (size_t j = 0; j < pre_lod[i].size(); j++) {
-// auto& a = pre_lod[i][j];
-// auto& b = after_lod[i][j];
-// if (a.size() != b.size()) {
-// CHECK(false) << "lods error: " << Info().Proto().type();
-// }
-// for (size_t i = 0; i < a.size(); i++) {
-// const auto &a_level = a[i];
-// const auto &b_level = b[i];
-// if (a_level.size() != b_level.size()) {
-// CHECK(false) << "lods error: " << Info().Proto().type();
-// }
-// for (size_t j = 0; j < a_level.size(); j++) {
-// if (a_level[j] != b_level[j]) {
-// CHECK(false) << "lods error: " << Info().Proto().type();
-// }
-// }
-// }
-// }
-// }
-// }
}

if (!transfered_inplace_vars.empty()) {
50 changes: 24 additions & 26 deletions paddle/fluid/framework/ps_gpu_worker.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <mutex>

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
@@ -35,7 +36,7 @@ namespace paddle {
namespace framework {

std::atomic<int> PSGPUWorker::shape_check_count_(16);
-std::atomic<bool> PSGPUWorker::shape_check_flag_(false);
+std::atomic<bool> PSGPUWorker::shape_check_flag_(true);

void PSGPUWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
this->HogwildWorker::CreateDeviceResource(main_prog);
@@ -256,7 +257,7 @@ int PSGPUWorker::OpRunAndShapeCheck(OperatorBase& op,
VLOG(0) << "Begin OpRunAndShapeCheck... "
<< shape_check_count_.load();
if (shape_check_count_.fetch_sub(1) <= 0) {
-// shape_check_flag_ = false;
+shape_check_flag_ = false;
}
// before op run
InferShapeCheckData check_data;
@@ -280,40 +281,42 @@
after_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
after_lods.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
}
-// auto& op_name = op.Info().Proto().type();
-CHECK(pre_dims.size() == after_dims.size())
-<< "dims error, op name:" << op.Info().Proto().type();
+std::string op_name = "unknow_op";
+if (op.Info().HasOpProtoAndChecker()) {
+op_name = op.Info().Proto().type();
+}
+
+#define SHAPE_CHECK_EQ(__VAL0, __VAL1) \
+PADDLE_ENFORCE_EQ(__VAL0, __VAL1, platform::errors::Fatal( \
+"Shape check dims/lods error, op name: %s .", op_name))
+
+SHAPE_CHECK_EQ(pre_dims.size(), after_dims.size());
for (size_t i = 0; i < pre_dims.size(); i++) {
-CHECK(pre_dims[i].size() == after_dims[i].size())
-<< "dims error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(pre_dims[i].size(), after_dims[i].size());
for (size_t j = 0; j < pre_dims[i].size(); j++) {
-CHECK(pre_dims[i][j] == after_dims[i][j])
-<< "dims error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(pre_dims[i][j], after_dims[i][j]);
}
}

-CHECK(pre_lods.size() == after_lods.size())
-<< "lods error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(pre_lods.size(), after_lods.size());
for (size_t i = 0; i < pre_lods.size(); i++) {
-CHECK(pre_lods[i].size() == after_lods[i].size())
-<< "lods error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(pre_lods[i].size(), after_lods[i].size());
for (size_t j = 0; j < pre_lods[i].size(); j++) {
auto& x = pre_lods[i][j];
auto& y = after_lods[i][j];
-CHECK(x.size() == y.size())
-<< "lods error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(x.size(), y.size());
for (size_t i = 0; i < x.size(); i++) {
const auto &x_level = x[i];
const auto &y_level = y[i];
-CHECK(x_level.size() == y_level.size())
-<< "lods error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(x_level.size(), y_level.size());
for (size_t j = 0; j < x_level.size(); j++) {
-CHECK(x_level[j] == y_level[j])
-<< "lods error, op name:" << op.Info().Proto().type();
+SHAPE_CHECK_EQ(x_level[j], y_level[j]);
}
}
}
}
+#undef SHAPE_CHECK_EQ
} else {
op.Run(scope, place);
}
@@ -346,7 +349,6 @@ void PSGPUWorker::TrainFiles() {
task.scope = thread_scope_vec_[i];
free_task_queue_.Push(task);
}
-// std::atomic<int>* thread_run = new std::atomic<int>(task_threads_);
thread_count_.store(task_threads_num_);
task_threads_.reserve(task_threads_num_);
for (int i = 0; i < task_threads_num_; i++) {
@@ -364,8 +366,8 @@
task.pack = pack;
task.ins_num = pack->ins_num();
device_reader_->PackToScope(task.pack, task.scope);
-for (size_t ii = 0; ii < ops_.size(); ii++) {
-auto& op = ops_[ii];
+for (size_t i = 0; i < ops_.size(); i++) {
+auto& op = ops_[i];
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
@@ -426,7 +428,6 @@
}
if (!need_skip) {
OpRunAndShapeCheck(*op, *thread_scope, place_);
-// op->Run(*thread_scope, place_);
}
}
graph_batch_size = cur_batch;
@@ -443,7 +444,6 @@
}
if (!need_skip) {
OpRunAndShapeCheck(*op, *thread_scope, place_);
-// op->Run(*thread_scope, place_);
}
}
} else {
@@ -456,7 +456,6 @@
platform::BeginCUDAGraphCapture(place_, cudaStreamCaptureModeThreadLocal);
for (auto& op : op_or_cuda_graph.ops) {
OpRunAndShapeCheck(*op, *thread_scope, place_);
-// op->Run(*thread_scope, place_);
}
op_or_cuda_graph.cudagraph = platform::EndCUDAGraphCapture();
}
@@ -467,7 +466,6 @@
} else {
for (auto& op : op_or_cuda_graph.ops) {
OpRunAndShapeCheck(*op, *thread_scope, place_);
-// op->Run(*thread_scope, place_);
}
}
}
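Two details of the new check are worth noting: failures now go through PADDLE_ENFORCE_EQ, which raises a Paddle exception carrying the op name, rather than a bare glog CHECK, which aborts the whole process; and an atomic countdown switches the check off after the first 16 gated runs, so steady-state training does not pay for it. A self-contained sketch of both mechanisms, with a plain std::runtime_error standing in for platform::errors::Fatal and all names hypothetical:

#include <atomic>
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for PADDLE_ENFORCE_EQ: throw a catchable error naming the op,
// rather than aborting the process the way a bare CHECK(false) would.
#define SHAPE_CHECK_EQ(val0, val1, op_name)                            \
  do {                                                                 \
    if ((val0) != (val1)) {                                            \
      std::ostringstream oss;                                          \
      oss << "Shape check dims/lods error, op name: " << (op_name);    \
      throw std::runtime_error(oss.str());                             \
    }                                                                  \
  } while (0)

// Countdown gate, mirroring shape_check_count_/shape_check_flag_: only the
// first 16 checked runs pay for the comparison, then the flag flips off.
static std::atomic<int> shape_check_count(16);
static std::atomic<bool> shape_check_flag(true);

void OpRunAndShapeCheck(size_t pre_size, size_t after_size,
                        const std::string& op_name) {
  if (!shape_check_flag.load()) {
    return;  // steady state: skip the check entirely
  }
  if (shape_check_count.fetch_sub(1) <= 0) {
    shape_check_flag = false;  // later calls take the fast path above
  }
  SHAPE_CHECK_EQ(pre_size, after_size, op_name);
}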
