From 91893efbe0525e7f1379b01b191d65a85d7bbefa Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 20 Mar 2023 12:13:31 +0000 Subject: [PATCH] subgraph support device param copy --- .../ir_params_sync_among_devices_pass.cc | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index 687e3581c5e47..503ea531f171a 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -227,19 +227,23 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) { platform::CPUPlace cpu_place; platform::Place xpu_place = platform::XPUPlace(argument->xpu_device_id()); auto *scope = argument->scope_ptr(); - framework::ir::Graph &graph = argument->main_graph(); - - for (auto *node : graph.Nodes()) { - if (!node->IsVar() || !node->Var()->Persistable()) continue; - auto *var = scope->FindVar(node->Name()); - if (!var->IsType()) continue; - auto *tensor = var->GetMutable(); - - phi::DenseTensor temp_tensor; - temp_tensor.Resize(tensor->dims()); - paddle::framework::TensorCopySync(*tensor, cpu_place, &temp_tensor); - tensor->clear(); - paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor); + framework::ir::Graph &main_graph = argument->main_graph(); + + for (size_t i = 0; i < main_graph.SubGraphsSize(); i++) { + auto *graph = main_graph.GetSubGraph(i); + for (auto *node : graph->Nodes()) { + if (!node->IsVar() || !node->Var()->Persistable()) continue; + auto *var = scope->FindVar(node->Name()); + if (!var->IsType()) continue; + auto *tensor = var->GetMutable(); + if (tensor->place().GetType() == phi::AllocationType::XPU) continue; + + phi::DenseTensor temp_tensor; + temp_tensor.Resize(tensor->dims()); + paddle::framework::TensorCopySync(*tensor, cpu_place, &temp_tensor); + tensor->clear(); + paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor); + } } } #endif