From 9378769d762c353c2721779b6af4e6ccde9a47b3 Mon Sep 17 00:00:00 2001 From: JiabinYang <360788950@qq.com> Date: Sun, 26 Sep 2021 16:06:15 +0000 Subject: [PATCH 1/3] fix dygraph double grad dtype error when calling for high differential senario --- paddle/fluid/framework/operator.cc | 17 +++++++++-------- paddle/fluid/imperative/partial_grad_engine.cc | 10 +++++++++- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 670cb36dcc3ab..2a543d48791a3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1589,14 +1589,15 @@ void OperatorWithKernel::ParseInputDataType( "not initialized.", Type(), name, ctx.InputNames(name).at(i))); proto::VarType::Type tmp = t->type(); - PADDLE_ENFORCE( - tmp == *data_type || *data_type == default_data_type, - platform::errors::InvalidArgument( - "The DataType of %s Op's duplicable Variable %s must be " - "consistent. The current variable type is (%s), but the " - "previous variable type is (%s).", - Type(), name, DataTypeToString(tmp), - DataTypeToString(*data_type))); + PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type, + platform::errors::InvalidArgument( + "The DataType of %s Op's duplicable or different " + "slot Variable %s must be " + "consistent or reigster GetExpectedKernelType. The " + "current variable type is (%s), but the " + "previous variable type is (%s).", + Type(), name, DataTypeToString(tmp), + DataTypeToString(*data_type))); *data_type = tmp; } } diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index c1ec675a55707..45756083c9047 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -307,7 +307,15 @@ static void FillConstantLike(const VariableWrapper &ref_var, auto *dst_tensor = dst_var->MutableVar()->GetMutable(); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); dst_tensor->Resize(ref_tensor.dims()); - dst_tensor->mutable_data(place, ref_var.DataType()); + // TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in + // grad mission + // we can't get data_type_ directly. We need to check if we can only use + // default data_type for now. + if (ref_var.ForwardDataType() != -1) { + dst_tensor->mutable_data(place, ref_var.ForwardDataType()); + } else { + dst_tensor->mutable_data(place, ref_var.DataType()); + } operators::math::set_constant(*dev_ctx, dst_tensor, value); } From 034011d5ba8d03749b0544f99ef4f18fbdccf968 Mon Sep 17 00:00:00 2001 From: JiabinYang <360788950@qq.com> Date: Sun, 26 Sep 2021 16:16:29 +0000 Subject: [PATCH 2/3] reinvoke ci --- paddle/fluid/imperative/variable_wrapper.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index 5fa8b89a396d9..758e8e62718e7 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -162,6 +162,7 @@ class VariableWrapper { return tensor->type(); } else { VLOG(6) << "The tensor of variable " << name_ << " is not initialized"; + return data_type_; } } From f478f18ce061877def4526f14ce4363eb9bbfe0f Mon Sep 17 00:00:00 2001 From: JiabinYang <360788950@qq.com> Date: Mon, 27 Sep 2021 06:35:05 +0000 Subject: [PATCH 3/3] add test for partial_engine.cc --- python/paddle/fluid/tests/unittests/autograd/test_jacobian.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py index 640292a47114a..2722d2c83b130 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py @@ -215,10 +215,6 @@ def setUpClass(self): self.x = paddle.rand(shape=self.shape, dtype=self.dtype) self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - # NOTE(levi): skip this test case temporaryly. - def test_create_graph_true(self): - pass - if __name__ == "__main__": unittest.main()