Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_align" #64049

Merged 2 commits on May 7, 2024
133 changes: 133 additions & 0 deletions paddle/fluid/operators/ctc_align_op.cc
@@ -0,0 +1,133 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/ctc_align_op.h"

namespace paddle {
namespace operators {

class CTCAlignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ctc_align");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ctc_align");

auto input_dims = ctx->GetInputDim("Input");

// TODO(wanghaoshuang): the dims set here are a placeholder; the kernel
// resizes Output to the merged length at runtime.
ctx->SetOutputDim("Output", input_dims);
if (ctx->HasInput("InputLength")) {
ctx->SetOutputDim("OutputLength", {input_dims[0], 1});
}
}

protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context().GetPlace());
}
};

class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"2-D Tensor or LodTensor with shape "
"[Lp, 1], where Lp is the sum of all input sequences' length.");
AddInput("InputLength",
"2-D Tensor with shape [batch_size, 1], "
" When Input is padding mode, InputLength is length of every "
"sequence in Input.")
.AsDispensable();
AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
AddOutput("OutputLength",
"2-D Tensor with shape [batch_size, 1], "
"When Input is padding mode, OutputLength is length of every "
"sequence in Output.")
.AsDispensable();
AddAttr<int>("blank",
"(int, default: 0), the blank label set in Connectionist "
"Temporal Classification (CTC) op.")
.SetDefault(0);
AddAttr<bool>("merge_repeated",
"(bool, default: true), whether to "
"merge repeated elements between two blanks. ")
.SetDefault(true);
// padding value attribute for (non-LoD) tensor input
AddAttr<int>("padding_value",
"(int, default: 0), the value used to pad "
"the output tensor. ")
.SetDefault(0);
AddComment(R"DOC(
CTCAlign op merges repeated elements between two blanks
and then deletes all blanks in the sequence.
Given:
Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
6, 0, 0, 7, 7, 7, 0]
Input.dims = {18, 1}
Input.LoD = [[0, 11, 18]]
And:
blank = 0
merge_repeated = True
Then:
Output.data = [1, 2, 4, 4, 5, 6,
6, 7]
Output.dims = {8, 1}
Output.LoD = [[0, 6, 8]]
or Given:
Input.data = [[0, 1, 2, 2, 0, 4],
[0, 4, 5, 0, 6, 0],
[0, 7, 7, 7, 0, 0]]
InputLength.data = [[6],
[5],
[4]],
Input.dims = {3, 6},
Input.LoD = []
And:
blank = 0
merge_repeated = True
padding_value = 0
Then:
Output.data = [[1, 2, 4, 0, 0, 0],
[4, 5, 6, 0, 0, 0],
[7, 0, 0, 0, 0, 0]],
OutputLength.data = [[3],
[3],
[1]],
Output.dims = {3, 6},
Output.LoD = []
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
ctc_align,
ops::CTCAlignOp,
ops::CTCAlignOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

PD_REGISTER_STRUCT_KERNEL(
ctc_align, CPU, ALL_LAYOUT, ops::CTCAlignKernel, int, int64_t) {}
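
For reference, the following is a minimal, self-contained C++ sketch of the LoD-mode algorithm restored here (merge repeated tokens within each sequence, then drop blanks), checked against the first example in the DOC string above. The function and variable names are illustrative only, not Paddle APIs.

#include <cstddef>
#include <iostream>
#include <vector>

// tokens: flat [Lp, 1] data; lod0: absolute sequence offsets, e.g. {0, 11, 18}.
// Returns the merged tokens and fills out_lod0 with the new offsets.
std::vector<int> CtcAlign(const std::vector<int>& tokens,
                          const std::vector<size_t>& lod0,
                          int blank,
                          bool merge_repeated,
                          std::vector<size_t>* out_lod0) {
  std::vector<int> output;
  out_lod0->assign(1, 0);
  for (size_t i = 0; i + 1 < lod0.size(); ++i) {
    int pre_token = -1;  // reset per sequence, as in the kernels above
    for (size_t j = lod0[i]; j < lod0[i + 1]; ++j) {
      if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
        output.push_back(tokens[j]);
      }
      pre_token = tokens[j];
    }
    out_lod0->push_back(output.size());
  }
  return output;
}

int main() {
  // The LoD example from the DOC string: blank = 0, merge_repeated = true.
  std::vector<int> tokens = {0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
                             6, 0, 0, 7, 7, 7, 0};
  std::vector<size_t> lod0 = {0, 11, 18};
  std::vector<size_t> out_lod0;
  std::vector<int> out =
      CtcAlign(tokens, lod0, /*blank=*/0, /*merge_repeated=*/true, &out_lod0);
  // Expected: out = {1, 2, 4, 4, 5, 6, 6, 7}, out_lod0 = {0, 6, 8}.
  for (int t : out) std::cout << t << ' ';
  std::cout << '\n';
  return 0;
}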
171 changes: 171 additions & 0 deletions paddle/fluid/operators/ctc_align_op.cu
@@ -0,0 +1,171 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <vector>

#include "paddle/fluid/operators/ctc_align_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void MergeAndDelCudaKernel(const int64_t num_token,
const T* tokens,
const size_t num_seq,
size_t* lod0,
const int blank,
const int merge_repeated,
size_t* out_lod0,
T* output) {
int output_idx = 0;
out_lod0[0] = 0;

for (int i = 0; i < num_seq; ++i) {
T pre_token = -1;
for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
output[output_idx] = tokens[j];
++output_idx;
}
pre_token = tokens[j];
}
out_lod0[i + 1] = output_idx;
}
}

template <typename T>
__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
const T* tokens,
const T* tokens_length,
const int blank,
const int merge_repeated,
const int padding_value,
const int64_t batch_size,
T* output,
T* output_length) {
int ind = blockIdx.x * blockDim.x + threadIdx.x;
if (ind >= batch_size) return;
int output_idx = ind * num_token;
T prev_token = -1;
for (int i = ind * num_token; i < ind * num_token + tokens_length[ind]; i++) {
if ((unsigned)tokens[i] != blank &&
!(merge_repeated && tokens[i] == prev_token)) {
output[output_idx] = tokens[i];
++output_idx;
}
prev_token = tokens[i];
}
output_length[ind] = output_idx - ind * num_token;
for (int i = output_idx; i < ind * num_token + num_token; i++) {
output[i] = padding_value;
}
}

template <typename T, typename DeviceContext>
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()),
true,
phi::errors::InvalidArgument(
"CTCAlign operator CUDA kernel must use CUDAPlace "
"rather than CPUPlace."));
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* output = ctx.Output<phi::DenseTensor>("Output");
const int blank = ctx.Attr<int>("blank");
const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated"));
const T* tokens = input->data<T>();
auto stream = ctx.cuda_device_context().stream();

// tensor input (padding mode), which has no LoD
if (input->lod().empty()) {
const int padding_value = ctx.Attr<int>("padding_value");
auto input_dims = input->dims();
T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
ctx.GetPlace());
auto* input_length = ctx.Input<phi::DenseTensor>("InputLength");
const T* input_length_data = input_length->data<T>();
auto* output_length = ctx.Output<phi::DenseTensor>("OutputLength");
T* output_length_data =
output_length->mutable_data<T>({input_dims[0], 1}, ctx.GetPlace());
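// 32 blocks of ceil(batch_size / 32) threads cover every row; the kernel
// returns early for ind >= batch_size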
PaddingMergeAndDelCudaKernel<T>
<<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
input_dims[1],
tokens,
input_length_data,
blank,
merge_repeated,
padding_value,
input_dims[0],
output_data,
output_length_data);
} else {
const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod());

const int64_t num_tokens = input->dims()[0];
const size_t num_seq = input_lod[level].size() - 1;

// prepare a lod to record lod information while merging elements
thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());

// merge repeated elements and delete blanks
T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());

phi::MixVector<size_t> mixv_input_lod(&input_lod[level]);
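// launched with a single thread: one pass over all sequences builds the
// merged output and its offsets in order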
MergeAndDelCudaKernel<T>
<<<1, 1, 0, stream>>>(num_tokens,
tokens,
num_seq,
mixv_input_lod.CUDAMutableData(ctx.GetPlace()),
blank,
merge_repeated,
dev_out_lod0_ptr,
output_data);
mixv_input_lod.CopyToCPU();

// set output lod
std::vector<size_t> host_out_lod0(dev_out_lod0.begin(),
dev_out_lod0.end());
framework::LoD out_lod;
out_lod.push_back(host_out_lod0);
output->set_lod(out_lod);

// resize output dims
output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});

if (host_out_lod0.back() == 0) {
output->Resize({1, 1});
output->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, T> set_constant;
set_constant(
ctx.template device_context<phi::GPUContext>(), output, -1);
}
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

PD_REGISTER_STRUCT_KERNEL(
ctc_align, GPU, ALL_LAYOUT, ops::CTCAlignOpCUDAKernel, int, int64_t) {}
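
And a host-side C++ reference for the padded-tensor path: each loop iteration below plays the role of one CUDA thread per batch row in PaddingMergeAndDelCudaKernel. It is checked against the padded example in the DOC string; names are illustrative only, not Paddle APIs.

#include <iostream>
#include <vector>

// tokens: row-major [batch_size, num_token]; lengths: valid length per row.
void PaddedCtcAlign(const std::vector<int>& tokens,
                    const std::vector<int>& lengths,
                    int batch_size,
                    int num_token,
                    int blank,
                    bool merge_repeated,
                    int padding_value,
                    std::vector<int>* output,
                    std::vector<int>* out_lengths) {
  // Pre-fill with padding_value; the CUDA kernel instead writes padding
  // after the merged tokens, which is equivalent.
  output->assign(batch_size * num_token, padding_value);
  out_lengths->assign(batch_size, 0);
  for (int b = 0; b < batch_size; ++b) {  // stands in for one CUDA thread
    int out_idx = b * num_token;
    int prev_token = -1;
    for (int i = b * num_token; i < b * num_token + lengths[b]; ++i) {
      if (tokens[i] != blank && !(merge_repeated && tokens[i] == prev_token)) {
        (*output)[out_idx++] = tokens[i];
      }
      prev_token = tokens[i];
    }
    (*out_lengths)[b] = out_idx - b * num_token;
  }
}

int main() {
  // The padded example from the DOC string: blank = 0, padding_value = 0.
  std::vector<int> tokens = {0, 1, 2, 2, 0, 4,
                             0, 4, 5, 0, 6, 0,
                             0, 7, 7, 7, 0, 0};
  std::vector<int> lengths = {6, 5, 4};
  std::vector<int> out, out_len;
  PaddedCtcAlign(tokens, lengths, 3, 6, /*blank=*/0, /*merge_repeated=*/true,
                 /*padding_value=*/0, &out, &out_len);
  // Expected rows: {1,2,4,0,0,0}, {4,5,6,0,0,0}, {7,0,0,0,0,0}; lengths {3,3,1}.
  for (int b = 0; b < 3; ++b) {
    for (int i = 0; i < 6; ++i) std::cout << out[b * 6 + i] << ' ';
    std::cout << "| len=" << out_len[b] << '\n';
  }
  return 0;
}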