diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index 9228c85c13011c..7fb835dd01c908 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -56,9 +56,29 @@ pir::Value parameter(const std::string& name) { } void set_parameter(const pir::Value& parameter, const std::string& name) { - std::unique_ptr param( + pir::Parameter* param = ApiBuilder::Instance().GetParameter(name); + if (param) { + PADDLE_ENFORCE_EQ(param->type(), + parameter.type(), + phi::errors::InvalidArgument( + "Duplicate parameter %s with different type.", name)); + } else { + std::unique_ptr param_new( + new pir::Parameter(nullptr, 0, parameter.type())); + ApiBuilder::Instance().SetParameter(name, std::move(param_new)); + ApiBuilder::Instance().GetBuilder()->Build(parameter, + name); + } +} + +void updata_parameter(const pir::Value& parameter, const std::string& name) { + pir::Parameter* param = ApiBuilder::Instance().GetParameter(name); + PADDLE_ENFORCE_NOT_NULL(param, + phi::errors::InvalidArgument( + "Parameter %s does not exist, can not update.", name)); + std::unique_ptr param_new( new pir::Parameter(nullptr, 0, parameter.type())); - ApiBuilder::Instance().SetParameter(name, std::move(param)); + ApiBuilder::Instance().SetParameter(name, std::move(param_new)); ApiBuilder::Instance().GetBuilder()->Build(parameter, name); } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index b3812215de5c86..86d9b9a8245cc1 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -36,6 +36,8 @@ pir::Value parameter(const std::string& name); void set_parameter(const pir::Value& parameter, const std::string& name); +void updata_parameter(const pir::Value& parameter, const std::string& name); + void shadow_output(const pir::Value& persist_value, const
std::string& name); pir::Value embedding_grad(const pir::Value& x, diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index c1fbd7f6ff49aa..872be599d9a76e 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -81,6 +81,32 @@ static PyObject *static_api_set_parameter(PyObject *self, } } +static PyObject *static_api_updata_parameter(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add updata_parameter op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *parameter_obj = PyTuple_GET_ITEM(args, 0); + auto parameter = CastPyArg2Value(parameter_obj, "parameter", 0); + + // Parse Attributes + PyObject *name_obj = PyTuple_GET_ITEM(args, 1); + std::string name = CastPyArg2String(name_obj, "name", 1); + // Call ir static api + CallStackRecorder callstack_recorder("updata_parameter"); + callstack_recorder.Record(); + paddle::dialect::updata_parameter(parameter, name); + callstack_recorder.AttachToOps(); + Py_RETURN_NONE; + } catch (...)
{ + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyObject *static_api_set_persistable_value(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -949,6 +975,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))static_api_set_parameter, METH_VARARGS | METH_KEYWORDS, "C++ interface function for set_parameter."}, + {"updata_parameter", + (PyCFunction)(void (*)(void))static_api_updata_parameter, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for updata_parameter."}, {"set_persistable_value", (PyCFunction)(void (*)(void))static_api_set_persistable_value, METH_VARARGS | METH_KEYWORDS, diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index afb5a916db58ce..9ae60e5185ee0c 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -251,7 +251,7 @@ def _pir_transform(t, dtype): param = op.operand(0).source() cast_param = paddle.cast(param, dtype) cast_param.persistable = True - paddle._pir_ops.set_parameter(cast_param, t.name) + paddle._pir_ops.updata_parameter(cast_param, t.name) block.remove_op(op) break main.set_parameters_from(startup) diff --git a/python/paddle/base/layer_helper_base.py b/python/paddle/base/layer_helper_base.py index 197782813ad608..66379ac8bc21d7 100644 --- a/python/paddle/base/layer_helper_base.py +++ b/python/paddle/base/layer_helper_base.py @@ -26,6 +26,7 @@ default_main_program, default_startup_program, in_dygraph_mode, + in_dynamic_or_pir_mode, in_pir_mode, ) from .initializer import _global_bias_initializer, _global_weight_initializer @@ -377,7 +378,7 @@ def create_parameter( else default_initializer ) if attr.name is None: - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): attr.name = unique_name.generate(".".join([self.name, suffix])) else: attr.name = self.main_program._name_generator.generate( diff --git a/test/auto_parallel/pir/test_ir_dist_attr.py b/test/auto_parallel/pir/test_ir_dist_attr.py index 
b43c2bee091d44..d47602da2b9f4b 100644 --- a/test/auto_parallel/pir/test_ir_dist_attr.py +++ b/test/auto_parallel/pir/test_ir_dist_attr.py @@ -30,7 +30,8 @@ class TestBuildFakeProgram(unittest.TestCase): def test_build_api(self): with paddle.pir_utils.IrGuard(): main_program = paddle.base.Program() - with paddle.base.program_guard(main_program): + start_program = paddle.base.Program() + with paddle.base.program_guard(main_program, start_program): mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) input = paddle.static.data( name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE] @@ -55,7 +56,8 @@ def test_build_api(self): def test_build_replicated_program(self): with paddle.pir_utils.IrGuard(): main_program = paddle.base.Program() - with paddle.base.program_guard(main_program): + start_program = paddle.base.Program() + with paddle.base.program_guard(main_program, start_program): mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) input = paddle.static.data( name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE] @@ -122,7 +124,8 @@ def test_build_replicated_program(self): def test_build_col_parallel_program(self): with paddle.pir_utils.IrGuard(): main_program = paddle.base.Program() - with paddle.base.program_guard(main_program): + start_program = paddle.base.Program() + with paddle.base.program_guard(main_program, start_program): mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) input = paddle.static.data( name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE] @@ -171,7 +174,8 @@ def test_build_col_parallel_program(self): def test_build_row_parallel_program(self): with paddle.pir_utils.IrGuard(): main_program = paddle.base.Program() - with paddle.base.program_guard(main_program): + start_program = paddle.base.Program() + with paddle.base.program_guard(main_program, start_program): mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) input = paddle.static.data( name='input', @@ -223,7 +227,8 @@ def test_build_row_parallel_program(self): def 
test_build_with_shard_tensor(self): with paddle.pir_utils.IrGuard(): main_program = paddle.base.Program() - with paddle.base.program_guard(main_program): + start_program = paddle.base.Program() + with paddle.base.program_guard(main_program, start_program): mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) input = paddle.static.data( name='input', diff --git a/test/deprecated/legacy_test/test_imperative_gan.py b/test/legacy_test/test_imperative_gan.py similarity index 99% rename from test/deprecated/legacy_test/test_imperative_gan.py rename to test/legacy_test/test_imperative_gan.py index 5b294e93e105e5..08bcad82b9d1d5 100644 --- a/test/deprecated/legacy_test/test_imperative_gan.py +++ b/test/legacy_test/test_imperative_gan.py @@ -270,4 +270,5 @@ def test_gan_float32(self): if __name__ == '__main__': + paddle.enable_static() unittest.main()