Skip to content

Commit

Permalink
add autogen code support for warpctc op (#52610)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccsuzzh authored Apr 7, 2023
1 parent fa949b1 commit a62de41
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 242 deletions.
12 changes: 1 addition & 11 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ endif()

set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api get_expected_kernel_func)

register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op quantize_linear_op
register_operators(EXCLUDES py_func_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
Expand All @@ -111,20 +111,10 @@ else()
endif()

if (WITH_GPU OR WITH_ROCM)
if(WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc)
# warpctc_op needs cudnn 7 above
elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
op_library(sync_batch_norm_op)
if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.3) )
op_library(sparse_attention_op)
endif()
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()

if (WITH_ASCEND_CL)
Expand Down
170 changes: 0 additions & 170 deletions paddle/fluid/operators/warpctc_op.cc

This file was deleted.

13 changes: 13 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,19 @@
kernel :
func : unstack_grad

- backward_op : warpctc_grad
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false) -> Tensor(loss), Tensor(warpctcgrad)
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
output : Tensor(logits_grad)
infer_meta :
func : UnchangedInferMeta
param : [logits]
kernel :
func : warpctc_grad
data_type : loss_grad
optional : logits_length
no_need_buffer : logits

- backward_op : warprnnt_grad
forward : warprnnt (Tensor input, Tensor label, Tensor input_lengths, Tensor label_lengths, int blank = 0, float fastemit_lambda = 0.0) -> Tensor(loss), Tensor(warprnntgrad)
args : (Tensor input, Tensor input_lengths, Tensor warprnntgrad, Tensor loss_grad, int blank = 0, float fastemit_lambda = 0.0)
Expand Down
12 changes: 0 additions & 12 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1224,18 +1224,6 @@
func : uniform_inplace_grad
inplace : (out_grad -> x_grad)

- backward_op : warpctc_grad
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad)
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
output : Tensor(logits_grad)
infer_meta :
func : UnchangedInferMeta
param : [logits]
kernel :
func : warpctc_grad
optional : logits_length
no_need_buffer : logits

- backward_op : yolo_loss_grad
forward : yolo_loss(Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask)
args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0)
Expand Down
12 changes: 0 additions & 12 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1547,18 +1547,6 @@
data_type: x
backward: unpool3d_grad

- op : warpctc
args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times)
output : Tensor(loss), Tensor(warpctcgrad)
infer_meta :
func : WarpctcInferMeta
kernel :
func : warpctc
data_type: logits
optional: logits_length, labels_length
intermediate: warpctcgrad
backward : warpctc_grad

- op : yolo_box
args : (Tensor x, Tensor img_size, int[] anchors, int class_num, float conf_thresh, int downsample_ratio, bool clip_bbox, float scale_x_y=1.0, bool iou_aware=false, float iou_aware_factor=0.5)
output : Tensor(boxes), Tensor(scores)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,13 @@
outputs :
{scores : Scores, path : Path}

- op : warpctc
backward : warpctc_grad
inputs :
{logits : Logits, label : Label, logits_length : LogitsLength, labels_length : LabelLength}
outputs :
{warpctcgrad : WarpCTCGrad, loss : Loss}

- op : where
backward : where_grad
inputs :
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,18 @@
func : viterbi_decode
data_type : potentials

- op : warpctc
args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false)
output : Tensor(loss), Tensor(warpctcgrad)
infer_meta :
func : WarpctcInferMeta
kernel :
func : warpctc
data_type: logits
optional: logits_length, labels_length
intermediate: warpctcgrad
backward : warpctc_grad

- op : warprnnt
args : (Tensor input, Tensor label, Tensor input_lengths, Tensor label_lengths, int blank = 0, float fastemit_lambda = 0.0)
output : Tensor(loss), Tensor(warprnntgrad)
Expand Down
37 changes: 0 additions & 37 deletions paddle/phi/ops/compat/warpctc_sig.cc

This file was deleted.

0 comments on commit a62de41

Please sign in to comment.