Skip to content

Commit

Permalink
fix runtime error
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 committed Jan 10, 2022
1 parent cbcc966 commit 06c8e26
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 2 additions & 5 deletions paddle/fluid/platform/device/ipu/ipu_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,14 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
std::map<popart::TensorId, PaddleIArray> input_wrappers;
for (size_t i = 0; i < inputs.size(); i++) {
auto tensor_id = one_builder_->inputs[i];
framework::Tensor *tensor = nullptr;
tensor->ShareDataWith(*inputs[i]);
input_wrappers.emplace(tensor_id, PaddleIArray(tensor));
input_wrappers.emplace(tensor_id, PaddleIArray(inputs[i]));
popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
}
// anchors
std::map<popart::TensorId, popart::IArray &> popart_anchors;
std::map<popart::TensorId, PaddleIArray> anchor_wrappers;
for (size_t i = 0; i < outputs.size(); i++) {
auto tensor_id = one_builder_->outputs[i];
framework::Tensor *tensor = nullptr;
tensor->ShareDataWith(*outputs[i]);
// get dims & dtype from session
auto fetch_info = session_->getInfo(tensor_id);
auto output_shape = fetch_info.shape();
Expand All @@ -112,6 +108,7 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
ipu_strategy_->popart_options.replicatedGraphCount);
}

auto *tensor = outputs[i];
tensor->Resize(framework::make_ddim(output_shape));
auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/platform/device/ipu/ipu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ enum ONNXDataType : int {

class PaddleIArray final : public popart::IArray {
public:
explicit PaddleIArray(Tensor* tensor) : tensor_(tensor) {
explicit PaddleIArray(const Tensor* tensor) : tensor_(tensor) {
for (int i = 0; i < tensor->dims().size(); ++i) {
shape_.push_back(tensor->dims().at(i));
}
Expand Down Expand Up @@ -96,7 +96,7 @@ std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(const Tensor& tensor) {
popart::TensorInfo tensor_info(dtype, shape);

return std::make_unique<popart::NDArrayWrapper<T>>(
reinterpret_cast<T *>(tensor.data()), tensor_info);
reinterpret_cast<T*>(tensor.data()), tensor_info);
}

template <typename T>
Expand Down

0 comments on commit 06c8e26

Please sign in to comment.