Skip to content

Commit

Permalink
【Bug fix】Fix dygraph double grad dtype error (#36125)
Browse files Browse the repository at this point in the history
* fix dygraph double grad dtype error when calling for the higher-order differentiation scenario

* reinvoke ci

* add test for partial_grad_engine.cc
  • Loading branch information
JiabinYang authored Sep 28, 2021
1 parent 0e07f20 commit af4f018
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
17 changes: 9 additions & 8 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1589,14 +1589,15 @@ void OperatorWithKernel::ParseInputDataType(
"not initialized.",
Type(), name, ctx.InputNames(name).at(i)));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable Variable %s must be "
"consistent. The current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type)));
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable or different "
"slot Variable %s must be "
"consistent or reigster GetExpectedKernelType. The "
"current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type)));
*data_type = tmp;
}
}
Expand Down
10 changes: 9 additions & 1 deletion paddle/fluid/imperative/partial_grad_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,15 @@ static void FillConstantLike(const VariableWrapper &ref_var,
auto *dst_tensor = dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
dst_tensor->Resize(ref_tensor.dims());
dst_tensor->mutable_data(place, ref_var.DataType());
// TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in
// grad mission
// we can't get data_type_ directly. We need to check if we can only use
// default data_type for now.
if (ref_var.ForwardDataType() != -1) {
dst_tensor->mutable_data(place, ref_var.ForwardDataType());
} else {
dst_tensor->mutable_data(place, ref_var.DataType());
}
operators::math::set_constant(*dev_ctx, dst_tensor, value);
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/imperative/variable_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class VariableWrapper {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";

return data_type_;
}
}
Expand Down
4 changes: 0 additions & 4 deletions python/paddle/fluid/tests/unittests/autograd/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ def setUpClass(self):
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

# NOTE(levi): skip this test case temporaryly.
def test_create_graph_true(self):
pass


if __name__ == "__main__":
unittest.main()

0 comments on commit af4f018

Please sign in to comment.