Skip to content

Commit

Permalink
Add unittest, backward of array read/write op (#5409)
Browse files Browse the repository at this point in the history
* Use stable_sort in lod_rank_table

It is easy to debug and test when use `stable_sort`and the time
complexity is not changed.

* Add LoDTensorArray

* Stash

* Better debug message for IsInitialized

* Stash

* Better debug message for IsInitialized

* Complete array read/write op unittests

* Add unittest, Gradient of array read/write

* Follow comments
  • Loading branch information
reyoung authored Nov 7, 2017
1 parent b25804c commit 6cde889
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 24 deletions.
11 changes: 10 additions & 1 deletion paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
out);
in_var->SetLoDLevel(out_var->GetLodLevel());
}
bool IsRuntime() const override;

protected:
VarDesc::VarType GetVarType(const std::string &name) const override;

private:
DDim GetDim(const std::string &name) const override;

void SetDim(const std::string &name, const DDim &dim) override;
Expand Down Expand Up @@ -451,6 +454,12 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
}

} // namespace framework
} // namespace paddle
12 changes: 11 additions & 1 deletion paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -365,7 +367,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_lod(in_tensor.lod());
}

private:
bool IsRuntime() const override { return true; }

protected:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
Expand All @@ -388,6 +392,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}

VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}

private:
const OperatorBase& op_;
const Scope& scope_;
};
Expand Down
17 changes: 17 additions & 0 deletions paddle/framework/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,23 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
SetDim(names[i], dims[i]);
}
}
std::vector<VarDesc::VarType> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<VarDesc::VarType> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}

} // namespace framework
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/framework/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {
Expand All @@ -26,6 +27,10 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;

std::vector<VarDesc::VarType> GetInputsVarType(const std::string &name) const;
std::vector<VarDesc::VarType> GetOutputsVarType(
const std::string &name) const;

virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;

Expand All @@ -46,6 +51,8 @@ class InferShapeContext {
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;

virtual bool IsRuntime() const = 0;

protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
Expand All @@ -55,6 +62,11 @@ class InferShapeContext {

void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims);

std::vector<VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const;

virtual VarDesc::VarType GetVarType(const std::string &name) const = 0;
};

} // namespace framework
Expand Down
36 changes: 36 additions & 0 deletions paddle/framework/var_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"

namespace paddle {
namespace framework {
inline VarDesc::VarType ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return VarDesc_VarType_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
}

} // namespace framework
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/framework/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ class Variable {

void Clear() { holder_.reset(); }

std::type_index Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
return holder_->Type();
}

private:
struct Placeholder {
virtual ~Placeholder() {}
Expand Down
52 changes: 47 additions & 5 deletions paddle/operators/sum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@ class SumOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
auto x_dims = ctx->GetInputsDim("X");

PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::VarDesc::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}

auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");

Expand All @@ -39,6 +45,28 @@ class SumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}

protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::LoDTensor>().type());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
for (auto& each : array) {
if (each.numel() != 0) {
return framework::ToDataType(each.type());
}
}
}
PADDLE_THROW("Unexpected branch. Input type is %s",
x_vars[0]->Type().name());
}
};

class SumOpMaker : public framework::OpProtoAndCheckerMaker {
Expand All @@ -63,18 +91,32 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto& inputs = op_desc.Input("X");
auto default_var_type = framework::VarDesc::SELECTED_ROWS;
auto var_type = framework::VarDesc::SELECTED_ROWS;

bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->Var(name)->GetType() == framework::VarDesc::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = framework::VarDesc::LOD_TENSOR;

auto is_tensor_array = [block](const std::string& name) {
return block->Var(name)->GetType() ==
framework::VarDesc::LOD_TENSOR_ARRAY;
};

bool any_input_is_tensor_array =
std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
bool all_inputs_are_tensor_array =
std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

if (any_input_is_tensor_array) {
PADDLE_ENFORCE(all_inputs_are_tensor_array);
var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) {
var_type = framework::VarDesc::LOD_TENSOR;
}

auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
block->Var(out_var_name)->SetType(var_type);
}
};

Expand Down
42 changes: 35 additions & 7 deletions paddle/operators/sum_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
Expand All @@ -28,15 +29,15 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext &context) const override {
auto in_vars = context.MultiInputVar("X");
int N = in_vars.size();
auto out_var = context.OutputVar("Out");

bool in_place = out_var == in_vars[0];

if (out_var->IsType<framework::LoDTensor>()) {
auto* out = context.Output<Tensor>("Out");
auto *out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());

auto result = EigenVector<T>::Flatten(*out);
Expand All @@ -51,20 +52,20 @@ class SumKernel : public framework::OpKernel<T> {
// If in_place, just skip the first tensor
for (int i = in_place ? 1 : 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto& in_t = in_vars[i]->Get<framework::LoDTensor>();
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
} else if (in_vars[i]->IsType<framework::SelectedRows>()) {
auto& in_t = in_vars[i]->Get<framework::SelectedRows>();
auto &in_t = in_vars[i]->Get<framework::SelectedRows>();
functor(context.device_context(), in_t, out);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
} else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto* out = context.Output<SelectedRows>("Out");
auto* out_value = out->mutable_value();
auto *out = context.Output<SelectedRows>("Out");
auto *out_value = out->mutable_value();

// Runtime InferShape
size_t first_dim = 0;
Expand All @@ -88,9 +89,36 @@ class SumKernel : public framework::OpKernel<T> {
offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel();
}
} else if (out_var->IsType<framework::LoDTensorArray>()) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
"Only support all inputs are TensorArray");
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();

for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].numel() != 0) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
out_array[i].CopyFrom(in_array[i], in_array[i].place(),
context.device_context());
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(context.GetEigenDevice<Place>()) = result + in;
}
}
}
}
} else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
out_var->Type().name());
}
}
};

} // namespace operators
} // namespace paddle
1 change: 0 additions & 1 deletion paddle/operators/tensor_array_read_write_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/v2/framework/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,12 +801,13 @@ def zeros(shape, dtype, main_program=None):

def increment(x, value=1.0, main_program=None):
helper = LayerHelper("increment", **locals())
tmp = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [x]},
outputs={'Out': [tmp]},
attrs={'step': value})
return x
return tmp


def array_write(x, i, array=None, main_program=None):
Expand Down
Loading

0 comments on commit 6cde889

Please sign in to comment.