From 45aefb8415b97eea1bc0b0954d480e0b95d13e68 Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 7 May 2024 16:40:07 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"=E3=80=90Hackathon=206th=20Fundable?= =?UTF-8?q?=20Projects=203=20No.81=E3=80=91Remove=20fluid=20operators=20ct?= =?UTF-8?q?c=5Fa=E2=80=A6"=20(#64049)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 2134ead2500320b0101169a4809fb9f50c76cc77. --- paddle/fluid/operators/ctc_align_op.cc | 133 ++++++++++++++ paddle/fluid/operators/ctc_align_op.cu | 171 ++++++++++++++++++ paddle/fluid/operators/ctc_align_op.h | 119 +++++++++++++ test/legacy_test/test_ctc_align.py | 232 +++++++++++++++++++++++++ 4 files changed, 655 insertions(+) create mode 100644 paddle/fluid/operators/ctc_align_op.cc create mode 100644 paddle/fluid/operators/ctc_align_op.cu create mode 100644 paddle/fluid/operators/ctc_align_op.h create mode 100644 test/legacy_test/test_ctc_align.py diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc new file mode 100644 index 00000000000000..a40ba846102935 --- /dev/null +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/ctc_align_op.h" + +namespace paddle { +namespace operators { + +class CTCAlignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ctc_align"); + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ctc_align"); + + auto input_dims = ctx->GetInputDim("Input"); + + // TODO(wanghaoshuang): it is tricky to set the wrong dimension here. 
+    ctx->SetOutputDim("Output", input_dims);
+    if (ctx->HasInput("InputLength")) {
+      ctx->SetOutputDim("OutputLength", {input_dims[0], 1});
+    }
+  }
+
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.device_context().GetPlace());
+  }
+};
+
+class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input",
+             "2-D Tensor or LoDTensor with shape "
+             "[Lp, 1], where Lp is the sum of all input sequences' lengths.");
+    AddInput("InputLength",
+             "2-D Tensor with shape [batch_size, 1], "
+             "when Input is in padding mode, InputLength is the length of "
+             "every sequence in Input.")
+        .AsDispensable();
+    AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
+    AddOutput("OutputLength",
+              "2-D Tensor with shape [batch_size, 1], "
+              "when Input is in padding mode, OutputLength is the length of "
+              "every sequence in Output.")
+        .AsDispensable();
+    AddAttr<int>("blank",
+                 "(int, default: 0), the blank label set in Connectionist "
+                 "Temporal Classification (CTC) op.")
+        .SetDefault(0);
+    AddAttr<bool>("merge_repeated",
+                  "(bool, default: true), whether to "
+                  "merge repeated elements between two blanks.")
+        .SetDefault(true);
+    // add attr padding value for tensor input
+    AddAttr<int>("padding_value",
+                 "(int, default: 0), the padding value "
+                 "used to pad the output tensor.")
+        .SetDefault(0);
+    AddComment(R"DOC(
+CTCAlign op is used to merge repeated elements between two blanks
+and then delete all blanks in the sequence.
+
+Given:
+    Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
+                  6, 0, 0, 7, 7, 7, 0]
+    Input.dims = {18, 1}
+    Input.LoD = [[0, 11, 18]]
+
+And:
+    blank = 0
+    merge_repeated = True
+
+Then:
+    Output.data = [1, 2, 4, 4, 5, 6,
+                   6, 7]
+    Output.dims = {8, 1}
+    Output.LoD = [[0, 6, 8]]
+or Given:
+    Input.data = [[0, 1, 2, 2, 0, 4],
+                  [0, 4, 5, 0, 6, 0],
+                  [0, 7, 7, 7, 0, 0]]
+    InputLength.data = [[6],
+                        [5],
+                        [4]],
+    Input.dims = {3, 6},
+    Input.LoD = []
+And:
+    blank = 0
+    merge_repeated = True
+    padding_value = 0
+
+Then:
+    Output.data = [[1, 2, 4, 0, 0, 0],
+                   [4, 5, 6, 0, 0, 0],
+                   [7, 0, 0, 0, 0, 0]],
+    OutputLength.data = [[3],
+                         [3],
+                         [1]],
+    Output.dims = {3, 6},
+    Output.LoD = []
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    ctc_align,
+    ops::CTCAlignOp,
+    ops::CTCAlignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+PD_REGISTER_STRUCT_KERNEL(
+    ctc_align, CPU, ALL_LAYOUT, ops::CTCAlignKernel, int, int64_t) {}
diff --git a/paddle/fluid/operators/ctc_align_op.cu b/paddle/fluid/operators/ctc_align_op.cu
new file mode 100644
index 00000000000000..76466ed12ab88f
--- /dev/null
+++ b/paddle/fluid/operators/ctc_align_op.cu
@@ -0,0 +1,171 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <stdio.h>
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+
+#include <vector>
+
+#include "paddle/fluid/operators/ctc_align_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void MergeAndDelCudaKernel(const int64_t num_token,
+                                      const T* tokens,
+                                      const size_t num_seq,
+                                      size_t* lod0,
+                                      const int blank,
+                                      const int merge_repeated,
+                                      size_t* out_lod0,
+                                      T* output) {
+  int output_idx = 0;
+  out_lod0[0] = 0;
+
+  for (int i = 0; i < num_seq; ++i) {
+    T pre_token = -1;
+    for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
+      if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
+        output[output_idx] = tokens[j];
+        ++output_idx;
+      }
+      pre_token = tokens[j];
+    }
+    out_lod0[i + 1] = output_idx;
+  }
+}
+
+template <typename T>
+__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
+                                             const T* tokens,
+                                             const T* tokens_length,
+                                             const int blank,
+                                             const int merge_repeated,
+                                             const int padding_value,
+                                             const int64_t batch_size,
+                                             T* output,
+                                             T* output_length) {
+  int ind = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ind >= batch_size) return;
+  int output_idx = ind * num_token;
+  T prev_token = -1;
+  for (int i = ind * num_token; i < ind * num_token + tokens_length[ind]; i++) {
+    if ((unsigned)tokens[i] != blank &&
+        !(merge_repeated && tokens[i] == prev_token)) {
+      output[output_idx] = tokens[i];
+      ++output_idx;
+    }
+    prev_token = tokens[i];
+  }
+  output_length[ind] = output_idx - ind * num_token;
+  for (int i = output_idx; i < ind * num_token + num_token; i++) {
+    output[i] = padding_value;
+  }
+}
+
+template <typename T, typename DeviceContext>
+class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "CTCAlign operator CUDA kernel must use CUDAPlace "
+                          "rather than CPUPlace."));
+    auto* input = ctx.Input<phi::DenseTensor>("Input");
+    auto* output = ctx.Output<phi::DenseTensor>("Output");
+    const int blank = ctx.Attr<int>("blank");
+    const int merge_repeated =
+        static_cast<int>(ctx.Attr<bool>("merge_repeated"));
+    const T* tokens = input->data<T>();
+    auto stream = ctx.cuda_device_context().stream();
+
+    // tensor input which has no lod
+    if (input->lod().empty()) {
+      const int padding_value = ctx.Attr<int>("padding_value");
+      auto input_dims = input->dims();
+      T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
+                                               ctx.GetPlace());
+      auto* input_length = ctx.Input<phi::DenseTensor>("InputLength");
+      const T* input_length_data = input_length->data<T>();
+      auto* output_length = ctx.Output<phi::DenseTensor>("OutputLength");
+      T* output_length_data =
+          output_length->mutable_data<T>({input_dims[0], 1}, ctx.GetPlace());
+      PaddingMergeAndDelCudaKernel<T>
+          <<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
+              input_dims[1],
+              tokens,
+              input_length_data,
+              blank,
+              merge_repeated,
+              padding_value,
+              input_dims[0],
+              output_data,
+              output_length_data);
+    } else {
+      const size_t level = 0;
+      auto input_lod = framework::ToAbsOffset(input->lod());
+
+      const int64_t num_tokens = input->dims()[0];
+      const size_t num_seq = input_lod[level].size() - 1;
+
+      // prepare a lod to record lod information while merging elements
+      thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
+      size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
+
+      // merge elements and delete blank
+      T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
+
+      phi::MixVector<size_t> mixv_input_lod(&input_lod[level]);
+      MergeAndDelCudaKernel<T>
+          <<<1, 1, 0, stream>>>(num_tokens,
+                                tokens,
+                                num_seq,
+                                mixv_input_lod.CUDAMutableData(ctx.GetPlace()),
+                                blank,
+                                merge_repeated,
+                                dev_out_lod0_ptr,
+                                output_data);
+      mixv_input_lod.CopyToCPU();
+
+      // set output lod
+      std::vector<size_t> host_out_lod0(dev_out_lod0.begin(),
+                                        dev_out_lod0.end());
+      framework::LoD out_lod;
+      out_lod.push_back(host_out_lod0);
+      output->set_lod(out_lod);
+
+      // resize output dims
+      output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
+
+      if (host_out_lod0.back() == 0) {
+        output->Resize({1, 1});
+        output->mutable_data<T>(ctx.GetPlace());
+        phi::funcs::SetConstant<phi::GPUContext, T> set_constant;
+        set_constant(
+            ctx.template device_context<phi::GPUContext>(), output, -1);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+PD_REGISTER_STRUCT_KERNEL(
+    ctc_align, GPU, ALL_LAYOUT, ops::CTCAlignOpCUDAKernel, int, int64_t) {}
diff --git a/paddle/fluid/operators/ctc_align_op.h b/paddle/fluid/operators/ctc_align_op.h
new file mode 100644
index 00000000000000..9ebfa7196ecc56
--- /dev/null
+++ b/paddle/fluid/operators/ctc_align_op.h
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string.h>
+
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, typename DeviceContext>
+class CTCAlignKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<phi::DenseTensor>("Input");
+    auto* output = ctx.Output<phi::DenseTensor>("Output");
+    size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
+    bool merge_repeated = ctx.Attr<bool>("merge_repeated");
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    auto input_dims = common::vectorize(input->dims());
+    const T* input_data = input->data<T>();
+
+    // support tensor input, no lod information
+    if (input->lod().empty()) {
+      size_t padding_value =
+          static_cast<size_t>(ctx.Attr<int>("padding_value"));
+      auto* input_length = ctx.Input<phi::DenseTensor>("InputLength");
+      const T* input_length_data = input_length->data<T>();
+
+      auto* output_length = ctx.Output<phi::DenseTensor>("OutputLength");
+      T* output_length_data = output_length->mutable_data<T>(ctx.GetPlace());
+
+      for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
+           batch_id++) {
+        T prev_token = -1;
+        size_t output_idx = 0;
+        for (size_t i = 0; i < (unsigned)input_length_data[batch_id]; i++) {
+          size_t input_ind = batch_id * input_dims[1] + i;
+          if ((unsigned)input_data[input_ind] != blank &&
+              !(merge_repeated && input_data[input_ind] == prev_token)) {
+            output_data[batch_id * input_dims[1] + output_idx] =
+                input_data[input_ind];
+            ++output_idx;
+          }
+          prev_token = input_data[input_ind];
+        }
+        output_length_data[batch_id] = output_idx;
+        for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
+          output_data[batch_id * input_dims[1] + j] = padding_value;
+      }
+    } else {
+      const size_t level = 0;
+      auto input_lod = framework::ToAbsOffset(input->lod());
+
+      // check input dims and lod
+      PADDLE_ENFORCE_EQ(
+          input_dims[0],
+          static_cast<int64_t>(input_lod[level].back()),
+          phi::errors::InvalidArgument(
+              "The first dimension %d of CTCAlign operator Input(Input) should "
+              "be equal to "
+              "the sum of all sequences' lengths %d.",
+              input_dims[0],
+              static_cast<int64_t>(input_lod[level].back())));
+
+      const size_t num_sequences = input_lod[level].size() - 1;
+
+      // merge repeated tokens and delete blank
+      size_t output_idx = 0;
+      std::vector<size_t> output_lod0(1, 0);
+      for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
+        T prev_token = -1;
+        for (size_t i = input_lod[level][seq_idx];
+             i < input_lod[level][seq_idx + 1];
+             ++i) {
+          if ((unsigned)input_data[i] != blank &&
+              !(merge_repeated && input_data[i] == prev_token)) {
+            output_data[output_idx] = input_data[i];
+            ++output_idx;
+          }
+          prev_token = input_data[i];
+        }
+        output_lod0.push_back(output_idx);
+      }
+
+      // set output lod
+      framework::LoD output_lod;
+      output_lod.push_back(output_lod0);
+      output->set_lod(output_lod);
+      // resize output dims
+      output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
+      // for empty sequence
+      if (output_lod0.back() == 0) {
+        output->Resize({1, 1});
+        output_data = output->mutable_data<T>(ctx.GetPlace());
+        output_data[0] = -1;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/test/legacy_test/test_ctc_align.py b/test/legacy_test/test_ctc_align.py
new file mode 100644
index 00000000000000..699b176518be18
--- /dev/null
+++ b/test/legacy_test/test_ctc_align.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import OpTest
+
+import paddle
+
+
+def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None):
+    if input_length is None:
+        lod0 = lod[0]
+        result = []
+        cur_offset = 0
+        for i in range(len(lod0)):
+            prev_token = -1
+            for j in range(cur_offset, cur_offset + lod0[i]):
+                token = input[j][0]
+                if (token != blank) and not (
+                    merge_repeated and token == prev_token
+                ):
+                    result.append(token)
+                prev_token = token
+            cur_offset += lod0[i]
+        result = np.array(result).reshape([len(result), 1]).astype("int32")
+        if len(result) == 0:
+            result = np.array([[-1]])
+        return result
+    else:
+        result = [[] for i in range(len(input))]
+        output_length = []
+        for i in range(len(input)):
+            prev_token = -1
+            for j in range(input_length[i][0]):
+                token = input[i][j]
+                if (token != blank) and not (
+                    merge_repeated and token == prev_token
+                ):
+                    result[i].append(token)
+                prev_token = token
+            start = len(result[i])
+            output_length.append([start])
+            for j in range(start, len(input[i])):
+                result[i].append(padding)
+        result = (
+            np.array(result)
+            .reshape([len(input), len(input[0])])
+            .astype("int32")
+        )
+        output_length = (
+            np.array(output_length).reshape([len(input), 1]).astype("int32")
+        )
+
+        return result, output_length
+
+
+class TestCTCAlignOp(OpTest):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.input_lod = [[11, 7]]
+        self.blank = 0
+        self.merge_repeated = False
+        self.input = (
+            np.array([0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0])
+            .reshape([18, 1])
+            .astype("int32")
+        )
+
+    def setUp(self):
+        self.config()
+        output = CTCAlign(
+            self.input, self.input_lod, self.blank, self.merge_repeated
+        )
+
+        self.inputs = {
+            "Input": (self.input, self.input_lod),
+        }
+        self.outputs = {"Output": output}
+        self.attrs = {
+            "blank": self.blank,
+            "merge_repeated": self.merge_repeated,
+        }
+
+    def test_check_output(self):
+        # NOTE(yjjiang11): This op will be deprecated.
+        self.check_output(check_dygraph=False)
+
+
+class TestCTCAlignOpCase1(TestCTCAlignOp):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.input_lod = [[11, 8]]
+        self.blank = 0
+        self.merge_repeated = True
+        self.input = (
+            np.array([0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0])
+            .reshape([19, 1])
+            .astype("int32")
+        )
+
+
+class TestCTCAlignOpCase2(TestCTCAlignOp):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.input_lod = [[4]]
+        self.blank = 0
+        self.merge_repeated = True
+        self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32")
+
+
+class TestCTCAlignPaddingOp(OpTest):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.input_lod = []
+        self.blank = 0
+        self.padding_value = 0
+        self.merge_repeated = True
+        self.input = (
+            np.array(
+                [
+                    [0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0],
+                    [1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0],
+                ]
+            )
+            .reshape([2, 11])
+            .astype("int32")
+        )
+        self.input_length = np.array([[9], [8]]).reshape([2, 1]).astype("int32")
+
+    def setUp(self):
+        self.config()
+        output, output_length = CTCAlign(
+            self.input,
+            self.input_lod,
+            self.blank,
+            self.merge_repeated,
+            self.padding_value,
+            self.input_length,
+        )
+        self.inputs = {
+            "Input": (self.input, self.input_lod),
+            "InputLength": self.input_length,
+        }
+        self.outputs = {"Output": output, "OutputLength": output_length}
+        self.attrs = {
+            "blank": self.blank,
+            "merge_repeated": self.merge_repeated,
+            "padding_value": self.padding_value,
+        }
+
+    def test_check_output(self):
+        # NOTE(yjjiang11): This op will be deprecated.
+        self.check_output(check_dygraph=False)
+
+
+class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = True
+        self.padding_value = 0
+        self.input = (
+            np.array(
+                [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 7, 7, 0, 0]]
+            )
+            .reshape([3, 6])
+            .astype("int32")
+        )
+        self.input_length = (
+            np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32")
+        )
+
+
+class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
+    '''
+    test tensor input with the padding_value attr
+    '''
+
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = False
+        self.padding_value = 0
+        self.input = (
+            np.array(
+                [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 7, 7, 0, 0]]
+            )
+            .reshape([3, 6])
+            .astype("int32")
+        )
+        self.input_length = (
+            np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32")
+        )
+
+
+class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = False
+        self.padding_value = 1
+        self.input = (
+            np.array(
+                [[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0], [0, 7, 1, 7, 0, 0]]
+            )
+            .reshape([3, 6])
+            .astype("int32")
+        )
+        self.input_length = (
+            np.array([[6], [5], [4]]).reshape([3, 1]).astype("int32")
+        )
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
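
Reviewer note (illustrative only, not part of the applied patch): the padded-mode
behavior the revert restores is "merge repeated tokens, drop blanks, right-pad
with padding_value", as exercised by TestCTCAlignPaddingOp above. Below is a
minimal NumPy sketch of those semantics, mirroring the CTCAlign reference helper
in test_ctc_align.py; the name ctc_align_padded and its signature are hypothetical
and not a Paddle API.

    import numpy as np

    def ctc_align_padded(x, lengths, blank=0, merge_repeated=True,
                         padding_value=0):
        # Hypothetical reference helper: output starts fully padded and kept
        # tokens are written left-to-right, row by row.
        out = np.full_like(x, padding_value)
        out_len = np.zeros((x.shape[0], 1), dtype=x.dtype)
        for b in range(x.shape[0]):
            prev, k = -1, 0
            for t in range(lengths[b][0]):
                tok = x[b][t]
                # Drop blanks; drop repeats only when merge_repeated is set.
                if tok != blank and not (merge_repeated and tok == prev):
                    out[b][k] = tok
                    k += 1
                prev = tok
            out_len[b][0] = k
        return out, out_len

    x = np.array([[0, 1, 2, 2, 0, 4],
                  [0, 4, 5, 0, 6, 0],
                  [0, 7, 7, 7, 0, 0]], dtype="int32")
    lengths = np.array([[6], [5], [4]], dtype="int32")
    out, out_len = ctc_align_padded(x, lengths)
    # out     == [[1, 2, 4, 0, 0, 0], [4, 5, 6, 0, 0, 0], [7, 0, 0, 0, 0, 0]]
    # out_len == [[3], [3], [1]]

Running the sketch on the padded example from the op's DOC comment reproduces
the Output.data and OutputLength.data values shown there.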