From af4f018ade3d39f76233456ed2a8abb386afac51 Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Tue, 28 Sep 2021 13:03:36 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Bug=20fix=E3=80=91Fix=20dygraph=20doub?= =?UTF-8?q?le=20grad=20dtype=20error=20(#36125)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix dygraph double grad dtype error when calling for high differential senario * reinvoke ci * add test for partial_engine.cc --- paddle/fluid/framework/operator.cc | 17 +++++++++-------- paddle/fluid/imperative/partial_grad_engine.cc | 10 +++++++++- paddle/fluid/imperative/variable_wrapper.h | 1 + .../tests/unittests/autograd/test_jacobian.py | 4 ---- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 670cb36dcc3ab..2a543d48791a3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1589,14 +1589,15 @@ void OperatorWithKernel::ParseInputDataType( "not initialized.", Type(), name, ctx.InputNames(name).at(i))); proto::VarType::Type tmp = t->type(); - PADDLE_ENFORCE( - tmp == *data_type || *data_type == default_data_type, - platform::errors::InvalidArgument( - "The DataType of %s Op's duplicable Variable %s must be " - "consistent. The current variable type is (%s), but the " - "previous variable type is (%s).", - Type(), name, DataTypeToString(tmp), - DataTypeToString(*data_type))); + PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type, + platform::errors::InvalidArgument( + "The DataType of %s Op's duplicable or different " + "slot Variable %s must be " + "consistent or reigster GetExpectedKernelType. The " + "current variable type is (%s), but the " + "previous variable type is (%s).", + Type(), name, DataTypeToString(tmp), + DataTypeToString(*data_type))); *data_type = tmp; } } diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index c1ec675a55707..45756083c9047 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -307,7 +307,15 @@ static void FillConstantLike(const VariableWrapper &ref_var, auto *dst_tensor = dst_var->MutableVar()->GetMutable(); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); dst_tensor->Resize(ref_tensor.dims()); - dst_tensor->mutable_data(place, ref_var.DataType()); + // TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in + // grad mission + // we can't get data_type_ directly. We need to check if we can only use + // default data_type for now. + if (ref_var.ForwardDataType() != -1) { + dst_tensor->mutable_data(place, ref_var.ForwardDataType()); + } else { + dst_tensor->mutable_data(place, ref_var.DataType()); + } operators::math::set_constant(*dev_ctx, dst_tensor, value); } diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index 5fa8b89a396d9..758e8e62718e7 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -162,6 +162,7 @@ class VariableWrapper { return tensor->type(); } else { VLOG(6) << "The tensor of variable " << name_ << " is not initialized"; + return data_type_; } } diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py index 640292a47114a..2722d2c83b130 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py @@ -215,10 +215,6 @@ def setUpClass(self): self.x = paddle.rand(shape=self.shape, dtype=self.dtype) self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - # NOTE(levi): skip this test case temporaryly. - def test_create_graph_true(self): - pass - if __name__ == "__main__": unittest.main()