Add unittest, backward of array read/write op #5409
Changes from 12 commits
```diff
@@ -24,10 +24,16 @@ class SumOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
-    auto x_dims = ctx->GetInputsDim("X");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SumOp should not be null.");
+    if (ctx->IsRuntime() &&
+        ctx->GetOutputsVarType("Out")[0] ==
+            framework::VarDesc::LOD_TENSOR_ARRAY) {
+      return;  // skip runtime infershape when is tensor array;
+    }
+
+    auto x_dims = ctx->GetInputsDim("X");
     size_t N = x_dims.size();
     PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
```
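The guard above bails out of runtime shape inference when the output is a tensor array. A minimal standalone model of that check, with a plain enum standing in for `framework::VarDesc`; the rationale in the comment is an assumption inferred from the diff, not text from the PR:

```cpp
#include <cassert>

enum class VarType { LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY };

// Standalone model of the early return in SumOp::InferShape: tensor-array
// elements are typically written one step at a time (e.g., inside a while
// loop) and may differ in shape, so no single output dim can be inferred
// up front; the kernel handles each element's shape when it actually runs.
// (Rationale is an assumption, inferred from the diff.)
bool SkipRuntimeInferShape(bool is_runtime, VarType out_type) {
  return is_runtime && out_type == VarType::LOD_TENSOR_ARRAY;
}

int main() {
  assert(SkipRuntimeInferShape(true, VarType::LOD_TENSOR_ARRAY));
  assert(!SkipRuntimeInferShape(false, VarType::LOD_TENSOR_ARRAY));
  assert(!SkipRuntimeInferShape(true, VarType::LOD_TENSOR));
}
```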
```diff
@@ -39,6 +45,27 @@ class SumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", in_dim);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    auto x_vars = ctx.MultiInputVar("X");
+    if (x_vars[0]->IsType<framework::LoDTensor>()) {
+      return framework::ToDataType(
+          x_vars[0]->Get<framework::LoDTensor>().type());
+    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
+      return framework::ToDataType(
+          x_vars[0]->Get<framework::SelectedRows>().value().type());
+    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
+      auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
+      for (auto& each : array) {
+        if (each.numel() != 0) {
+          return framework::ToDataType(each.type());
+        }
+      }
+    }
+    PADDLE_THROW("Unexpected branch");
+  }
 };

 class SumOpMaker : public framework::OpProtoAndCheckerMaker {
```

Review comment (on the `PADDLE_THROW("Unexpected branch")` line): More helpful message, something like …

Reply: Done.
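The review asks for a more descriptive message than "Unexpected branch". A sketch of the kind of message that could be raised, written in plain C++ rather than Paddle's `PADDLE_THROW` macro; the helper names and the type-id mapping are hypothetical, for illustration only:

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical helper: map a variable's runtime type to a readable name.
std::string VarTypeName(int type_id) {
  switch (type_id) {
    case 0: return "LoDTensor";
    case 1: return "SelectedRows";
    case 2: return "LoDTensorArray";
    default: return "unknown";
  }
}

// Instead of a bare "Unexpected branch", report what was actually found.
[[noreturn]] void ThrowUnexpectedType(int type_id) {
  std::ostringstream msg;
  msg << "Unexpected branch: input variable type " << VarTypeName(type_id)
      << " is not supported by SumOp's IndicateDataType";
  throw std::runtime_error(msg.str());
}
```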
```diff
@@ -63,18 +90,32 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDescBind& op_desc,
                   framework::BlockDescBind* block) const override {
     auto& inputs = op_desc.Input("X");
-    auto default_var_type = framework::VarDesc::SELECTED_ROWS;
+    auto var_type = framework::VarDesc::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string& name) {
           return block->Var(name)->GetType() == framework::VarDesc::LOD_TENSOR;
         });
-    if (any_input_is_lod_tensor) {
-      default_var_type = framework::VarDesc::LOD_TENSOR;
-    }
+
+    auto is_tensor_array = [block](const std::string& name) {
+      return block->Var(name)->GetType() ==
+             framework::VarDesc::LOD_TENSOR_ARRAY;
+    };
+
+    bool any_input_is_tensor_array =
+        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
+    bool all_inputs_are_tensor_array =
+        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
+
+    if (any_input_is_tensor_array) {
+      PADDLE_ENFORCE(all_inputs_are_tensor_array);
+      var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
+    } else if (any_input_is_lod_tensor) {
+      var_type = framework::VarDesc::LOD_TENSOR;
+    }

     auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    block->Var(out_var_name)->SetType(var_type);
   }
 };
```
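The rule this hunk implements: if any input is a tensor array, all inputs must be, and the output is a tensor array; otherwise one `LoDTensor` input promotes the output to `LoDTensor`, and only all-`SelectedRows` inputs keep a `SelectedRows` output. A self-contained model of the decision, with plain enums standing in for Paddle's `VarDesc` so it compiles on its own:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

enum class VarType { LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY };

// Standalone model of SumOpVarTypeInference's decision rule; the real
// implementation reads types from a BlockDescBind rather than a vector.
VarType InferSumOutputType(const std::vector<VarType>& inputs) {
  auto is_array = [](VarType t) { return t == VarType::LOD_TENSOR_ARRAY; };
  bool any_array = std::any_of(inputs.begin(), inputs.end(), is_array);
  bool all_array = std::all_of(inputs.begin(), inputs.end(), is_array);
  if (any_array) {
    // Mixing tensor arrays with other types is rejected outright.
    assert(all_array);
    return VarType::LOD_TENSOR_ARRAY;
  }
  bool any_lod = std::any_of(inputs.begin(), inputs.end(), [](VarType t) {
    return t == VarType::LOD_TENSOR;
  });
  return any_lod ? VarType::LOD_TENSOR : VarType::SELECTED_ROWS;
}

int main() {
  assert(InferSumOutputType({VarType::SELECTED_ROWS, VarType::SELECTED_ROWS}) ==
         VarType::SELECTED_ROWS);
  assert(InferSumOutputType({VarType::SELECTED_ROWS, VarType::LOD_TENSOR}) ==
         VarType::LOD_TENSOR);
  assert(InferSumOutputType({VarType::LOD_TENSOR_ARRAY,
                             VarType::LOD_TENSOR_ARRAY}) ==
         VarType::LOD_TENSOR_ARRAY);
}
```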
```diff
@@ -11,6 +11,7 @@ limitations under the License. */

 #pragma once
 #include "paddle/framework/eigen.h"
+#include "paddle/framework/lod_tensor_array.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/selected_rows_functor.h"
```
```diff
@@ -88,6 +89,33 @@ class SumKernel : public framework::OpKernel<T> {
                                                  offset, out);
         offset += in_vars[i]->Get<SelectedRows>().value().numel();
       }
+    } else if (out_var->IsType<framework::LoDTensorArray>()) {
+      auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
+      for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
+        PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
+                       "Only support all inputs are TensorArray");
+        auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();
+
+        for (size_t i = 0; i < in_array.size(); ++i) {
+          if (in_array[i].numel() != 0) {
+            if (i >= out_array.size()) {
+              out_array.resize(i + 1);
+            }
+            if (out_array[i].numel() == 0) {
+              out_array[i].CopyFrom(in_array[i], in_array[i].place(),
+                                    context.device_context());
+              out_array[i].set_lod(in_array[i].lod());
+            } else {
+              PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
+              auto in = EigenVector<T>::Flatten(in_array[i]);
+              auto result = EigenVector<T>::Flatten(out_array[i]);
+              result.device(context.GetEigenDevice<Place>()) = result + in;
+            }
+          }
+        }
+      }
+    } else {
+      PADDLE_THROW("Unexpected branch");
     }
   }
 };
```

Review comment (on the `PADDLE_THROW("Unexpected branch")` line): More helpful message, something like …

Reply: Done.
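The array branch merges element-wise: the output array grows to the longest input, the first value written into a slot is copied, and later values are accumulated; empty elements are skipped. A self-contained model over `std::vector<std::vector<float>>` (plain vectors stand in for `LoDTensorArray`; the LoD check and device placement are omitted):

```cpp
#include <cassert>
#include <vector>

using TensorArray = std::vector<std::vector<float>>;  // stand-in for LoDTensorArray

// Standalone model of the LoDTensorArray branch of SumKernel:
// empty slots are skipped, the first writer copies, later writers add.
TensorArray SumTensorArrays(const std::vector<TensorArray>& inputs) {
  TensorArray out;
  for (const auto& in_array : inputs) {
    for (size_t i = 0; i < in_array.size(); ++i) {
      if (in_array[i].empty()) continue;       // skip empty elements
      if (i >= out.size()) out.resize(i + 1);  // grow to the longest input
      if (out[i].empty()) {
        out[i] = in_array[i];                  // first write: copy
      } else {
        assert(out[i].size() == in_array[i].size());
        for (size_t j = 0; j < in_array[i].size(); ++j) {
          out[i][j] += in_array[i][j];         // later writes: accumulate
        }
      }
    }
  }
  return out;
}

int main() {
  TensorArray a = {{1, 2}, {}};
  TensorArray b = {{3, 4}, {5}};
  TensorArray s = SumTensorArrays({a, b});
  assert(s[0][0] == 4 && s[0][1] == 6 && s[1][0] == 5);
}
```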
Review comment: This logic should be under `Variable`, something like `return var->GetType()`.

Reply: Done.
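The suggestion is to replace the external `IsType<T>()` chain with a query on the variable itself. A sketch of the idea on a simplified variant-style class; this is not Paddle's actual `Variable` API, just an illustration of moving the type dispatch inside the holder:

```cpp
#include <cassert>
#include <typeindex>

enum class VarType { LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY, UNKNOWN };

// Simplified stand-ins for the framework types.
struct LoDTensor {};
struct SelectedRows {};
struct LoDTensorArray {};

// Sketch of a Variable that reports its own type, so callers can write
// var.GetType() instead of chaining var.IsType<T>() checks.
class Variable {
 public:
  template <typename T>
  void Reset(T* ptr) {
    holder_ = ptr;
    type_ = std::type_index(typeid(T));
  }

  VarType GetType() const {
    if (type_ == std::type_index(typeid(LoDTensor))) return VarType::LOD_TENSOR;
    if (type_ == std::type_index(typeid(SelectedRows)))
      return VarType::SELECTED_ROWS;
    if (type_ == std::type_index(typeid(LoDTensorArray)))
      return VarType::LOD_TENSOR_ARRAY;
    return VarType::UNKNOWN;
  }

 private:
  void* holder_ = nullptr;
  std::type_index type_ = std::type_index(typeid(void));
};

int main() {
  LoDTensorArray array;
  Variable var;
  var.Reset(&array);
  assert(var.GetType() == VarType::LOD_TENSOR_ARRAY);
}
```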