Skip to content

Commit

Permalink
Change some op with xpu control (#52067)
Browse files Browse the repository at this point in the history
* change op with xpu

* change range yaml

* fix bug in generate_op.py
  • Loading branch information
heavyrain-lzy authored Mar 30, 2023
1 parent 8baf33a commit 1faa06f
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 385 deletions.
140 changes: 0 additions & 140 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cc

This file was deleted.

22 changes: 18 additions & 4 deletions paddle/fluid/operators/generator/generate_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,15 @@ def parse_get_expected_kerneltype(
for op_comp_map in op_fluid_list:
if 'get_expected_kernel_type' in op_comp_map:
fw_name = op_comp_map['op'].split('(')[0].strip()
# deal the last underline of function name in op_comp_map['get_expected_kernel_type']
new_get_expected_kernel_type_func_map = {}
for (key, value) in op_comp_map['get_expected_kernel_type'].items():
new_get_expected_kernel_type_func_map[
delete_last_underline(key)
] = value
op_comp_map[
'get_expected_kernel_type'
] = new_get_expected_kernel_type_func_map
if fw_name in op_comp_map['get_expected_kernel_type']:
# static_ops.yaml and ops.yaml use the common op_compat.yaml
if fw_name in fw_op_dict:
Expand Down Expand Up @@ -507,10 +516,15 @@ def parse_keep_signature(
for op_comp_map in op_fluid_list:
if 'manual_signature' in op_comp_map:
for op_name in op_comp_map['manual_signature']:
if op_name in fw_op_dict:
fw_op_dict[op_name]["manual_signature"] = True
elif op_name in bw_op_dict:
bw_op_dict[op_name]["manual_signature"] = True
op_name_without_last_underline = delete_last_underline(op_name)
if op_name_without_last_underline in fw_op_dict:
fw_op_dict[op_name_without_last_underline][
"manual_signature"
] = True
elif op_name_without_last_underline in bw_op_dict:
bw_op_dict[op_name_without_last_underline][
"manual_signature"
] = True


def split_ops_list(ops, backward_op_dict, split_num):
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,15 @@ phi::KernelKey GetSgdExpectedKernelType(
return phi::KernelKey(data_type, ctx.GetPlace());
}

phi::KernelKey GetUpdateLossScalingExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = op_ptr->IndicateVarDataType(ctx, "X");
}
return phi::KernelKey(dtype, ctx.GetPlace());
}

} // namespace operators
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ phi::KernelKey GetSgdExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetUpdateLossScalingExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

} // namespace operators
} // namespace paddle
91 changes: 0 additions & 91 deletions paddle/fluid/operators/linspace_op.cc

This file was deleted.

74 changes: 0 additions & 74 deletions paddle/fluid/operators/range_op.cc

This file was deleted.

11 changes: 0 additions & 11 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1658,17 +1658,6 @@
data_type: x
backward: unpool3d_grad

- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
infer_meta :
func : UpdateLossScalingInferMeta
param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps]
kernel :
func : update_loss_scaling
data_type : x
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)

- op : warpctc
args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times)
output : Tensor(loss), Tensor(warpctcgrad)
Expand Down
Loading

0 comments on commit 1faa06f

Please sign in to comment.