BeamSearchDecodeOp #5498
Changes from 10 commits
@@ -0,0 +1,109 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/trieconcat_op.h"

namespace paddle {
namespace operators {

class TrieConcatOp : public framework::OperatorBase {
 public:
  TrieConcatOp(const std::string& type,
               const framework::VariableNameMap& inputs,
               const framework::VariableNameMap& outputs,
               const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    framework::ExecutionContext ctx(*this, scope, dev_ctx);
    const std::vector<LoDTensor>* ids =
        ctx.Input<std::vector<LoDTensor>>("Ids");
    const std::vector<LoDTensor>* probs =
        ctx.Input<std::vector<LoDTensor>>("Scores");
    const size_t step_num = ids->size();
    PADDLE_ENFORCE_GT(step_num, 0UL,
                      "beam search steps should be larger than 0");
    const size_t source_num = ids->at(0).lod().at(0).size() - 1;
    PADDLE_ENFORCE_GT(source_num, 0UL, "source num should be larger than 0");

    for (size_t i = 0; i < step_num; ++i) {
      PADDLE_ENFORCE_EQ(ids->at(i).lod().size(), 2UL,
                        "Level of LoDTensor should be 2");
    }

    // prepare output
    LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
    LoDTensor* sentenceProbs = ctx.Output<LoDTensor>("SentenceScores");

    BeamHelper beam_helper;
    beam_helper.PackAllSteps(*ids, *probs, sentenceIds, sentenceProbs);
  }
};

class TrieConcatOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  TrieConcatOpProtoMaker(framework::OpProto* proto,
                         framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Ids",
             "(vector<LoDTensor>) "
             "ids of the candidate words in each step");
    AddInput("Scores",
             "(vector<LoDTensor>) "
             "scores of the candidate words in each step");
    AddOutput("SentenceIds",
              "(LoDTensor) "
              "All possible result sentences of word ids");
    AddOutput("SentenceScores",
              "(LoDTensor) "
              "All possible result sentences of word scores");
    AddComment(R"DOC(
Pack the results of the beam search op into SentenceIds and SentenceScores.
)DOC");
  }
};

class TrieConcatInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("Ids"),
                   "TrieConcatOp must have input Ids");
    PADDLE_ENFORCE(context->HasInput("Scores"),
                   "TrieConcatOp must have input Scores");
    PADDLE_ENFORCE(context->HasOutput("SentenceIds"),
                   "TrieConcatOp must have output SentenceIds");
    PADDLE_ENFORCE(context->HasOutput("SentenceScores"),
                   "TrieConcatOp must have output SentenceScores");
  }
};

class TrieConcatInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDescBind& op_desc,
                  framework::BlockDescBind* block) const override {
    for (auto& o : op_desc.Output("SentenceIds")) {
      block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
    }
    for (auto& o : op_desc.Output("SentenceScores")) {
      block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
    }
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(trie_concat, paddle::operators::TrieConcatOp,
                  paddle::operators::TrieConcatOpProtoMaker,
                  paddle::operators::TrieConcatInferShape,
                  paddle::operators::TrieConcatInferVarType,
                  paddle::framework::EmptyGradOpMaker);
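To make the layout this operator consumes concrete, here is a small hand-worked example of one step's Ids tensor. The numbers are illustrative only and do not come from the PR:

// Illustration (not part of the PR): one beam-search step stored as a
// LoDTensor with a two-level LoD. Level 0 partitions the candidates by
// source sentence; level 1 partitions them by prefix.
//
//   ids.lod()  = {{0, 2, 4},          // source level: two sources
//                 {0, 2, 4, 5, 6}}    // sentence level: candidates per prefix
//   ids.data() = {2, 3, 5, 7, 9, 4}   // candidate word ids
//
// Source 0 has two prefixes with candidate sets {2, 3} and {5, 7};
// source 1 has two prefixes with candidate sets {9} and {4}.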
@@ -0,0 +1,248 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using value_type = float;
[Review comment] DO NOT ASSUME THE VALUE TYPE IS FLOAT.
[Reply] This is a problem to be fixed after the main process is done.
[Reply] Done; a template is now used to describe the score.
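For context, a minimal sketch of the templated score type the author refers to. This is illustrative only; BeamNodeT is a hypothetical name and the merged code may look different:

// Hypothetical sketch; the diff below still hard-codes float.
template <typename T>
struct BeamNodeT {
  BeamNodeT(int64_t word_id, T score) : word_id_(word_id), score_(score) {}

  int64_t word_id_;
  T score_;  // the score type becomes a template parameter instead of float
};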
using LoDTensor = framework::LoDTensor;

const int64_t kInitLength = 1024;
const int64_t kEndId = 0;

// All the LoDs have two levels: the first is the source level, and
// each source has multiple possible sentences at the second level.
const size_t kSourceLevel = 0;
const size_t kSentenceLevel = 1;

struct BeamNode {
  BeamNode(int64_t word_id, float prob) : word_id_(word_id), prob_(prob) {}
[Review comment] You need a destructor here:

  ~BeamNode() {
    if (parent_) {
      parent_->DropKid(this);
      if (parent_->NumKid() == 0) {
        delete parent_;
      }
    }
  }

[Reply] Currently I use the function RemoveFromEnd to do this, but it needs to be optimized. Thank you for your suggestion.
[Reply] Done.
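Note that the suggested destructor relies on DropKid and NumKid, which do not appear in this diff. A hypothetical sketch of what these BeamNode members would have to do, given the kids_ vector defined below (requires <algorithm>):

// Hypothetical helpers assumed by the suggested destructor (not in the PR):
size_t NumKid() const { return kids_.size(); }

void DropKid(BeamNode* kid) {
  // erase-remove idiom: drop one child pointer without deleting the child
  kids_.erase(std::remove(kids_.begin(), kids_.end(), kid), kids_.end());
}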
  void AppendTo(BeamNode* father) {
    father_ = father;
    father->kids_.push_back(this);
  }

  BeamNode* father_ = nullptr;
[Review comment] father --> parent
[Reply] Done.
  std::vector<BeamNode*> kids_;
  int64_t word_id_;
  float prob_;
};

struct BeamHelper {
  // remove this prefix from the beam tree
  void RemoveFromEnd(BeamNode* end) {
    PADDLE_ENFORCE_EQ(end->kids_.size(), 0UL, "end should not have any kids");
    auto* father = end->father_;
    if (father != nullptr) {
      // take a reference so that erase() modifies the real kids vector
      auto& kids = father->kids_;
      kids.erase(std::remove(kids.begin(), kids.end(), end), kids.end());
      VLOG(3) << "Delete BeamNode with word_id:" << end->word_id_;
      delete end;
      if (father->kids_.size() == 0) {
        RemoveFromEnd(father);
      }
    } else {
      VLOG(3) << "Delete BeamNode root with word_id:" << end->word_id_;
      delete end;
    }
  }

  void AppendBeamNodeToResult(size_t source_idx, BeamNode* node) {
    std::vector<int64_t> sequence_ids;
    std::vector<float> sequence_probs;

    BeamNode* tmp = node;
    while (tmp != nullptr) {
      sequence_ids.emplace_back(tmp->word_id_);
      sequence_probs.emplace_back(tmp->prob_);
      tmp = tmp->father_;
    }

    std::reverse(std::begin(sequence_ids), std::end(sequence_ids));
    std::reverse(std::begin(sequence_probs), std::end(sequence_probs));

    result_id[source_idx].emplace_back(sequence_ids);
    result_prob[source_idx].push_back(sequence_probs);
  }

  std::vector<BeamNode*> PackTwoBeamStepOut(
      size_t source_idx, const std::vector<BeamNode*>& prefixes,
      const LoDTensor& cur_ids, const LoDTensor& cur_probs) {
    std::vector<BeamNode*> result;

    size_t source_start = cur_ids.lod()[kSourceLevel][source_idx];
    size_t source_end = cur_ids.lod()[kSourceLevel][source_idx + 1];
    PADDLE_ENFORCE_EQ(source_end - source_start, prefixes.size(),
                      "prefix and candidate set number should be the same");
    std::vector<size_t> candidate_offset = cur_ids.lod()[kSentenceLevel];
    for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) {
      size_t candidate_start = candidate_offset[source_start + prefix_idx];
      size_t candidate_end = candidate_offset[source_start + prefix_idx + 1];
      auto* prefix = prefixes[prefix_idx];
      PADDLE_ENFORCE_NE(prefix->word_id_, kEndId,
                        "prefix should not contain end id");
      if (candidate_start == candidate_end) {
        VLOG(3) << "this sentence has no more candidates, prune it";
        // remove this sentence from the beam tree.
        RemoveFromEnd(prefix);
      } else {
        // two-level LoD, for example:
        // [0 2 6]       source level
        // [0 1 1 2 3 4] sentence level
        for (size_t candidate_index = candidate_start;
             candidate_index < candidate_end; ++candidate_index) {
          int64_t word_id = cur_ids.data<int64_t>()[candidate_index];
          float prob = cur_probs.data<float>()[candidate_index];
          auto* candidate = new BeamNode(word_id, prob);
          candidate->AppendTo(prefix);
          // if the candidate is the end id, put the finished sentence into
          // the result and remove it from the beam tree.
          if (word_id == kEndId) {
            AppendBeamNodeToResult(source_idx, candidate);
            RemoveFromEnd(candidate);
          } else {
            result.push_back(candidate);
          }
        }
      }
    }
    return result;
  }

  void InitFirstStepBeamNodes(
      const LoDTensor& tensor_id, const LoDTensor& tensor_prob,
      std::unordered_map<size_t, std::vector<BeamNode*>>* batch_beam_nodes) {
    // Init beam_nodes for each source sentence. In the first step, every
    // candidate word is a prefix of its own, for example:
    // [0 3 6]         level 0
    // [0 1 2 3 4 5 6] level 1
    // [0 0 0 0 0 0]   data
    PADDLE_ENFORCE_EQ(tensor_id.lod().at(kSourceLevel).back(),
                      tensor_id.lod().at(kSentenceLevel).back());

    const size_t source_num = tensor_id.lod().at(kSourceLevel).size() - 1;

    for (size_t source_idx = 0; source_idx < source_num; ++source_idx) {
      std::vector<BeamNode*> init_beam_nodes;
      size_t source_start = tensor_id.lod().at(kSourceLevel).at(source_idx);
      size_t source_end = tensor_id.lod().at(kSourceLevel).at(source_idx + 1);

      for (size_t word_id_idx = source_start; word_id_idx < source_end;
           ++word_id_idx) {
        init_beam_nodes.push_back(
            new BeamNode(tensor_id.data<int64_t>()[word_id_idx],
                         tensor_prob.data<float>()[word_id_idx]));
      }
      (*batch_beam_nodes)[source_idx] = init_beam_nodes;
    }
  }

  void ConvertMapToLodTensor(
      const std::unordered_map<size_t, std::vector<std::vector<int64_t>>>&
          result_id,
      const std::unordered_map<size_t, std::vector<std::vector<float>>>&
          result_prob,
      LoDTensor* id_tensor, LoDTensor* prob_tensor) const {
    size_t source_num = result_id.size();

    std::vector<size_t> source_level_lod = {0};
    std::vector<size_t> sentence_level_lod = {0};
    std::vector<int64_t> id_data;
    std::vector<float> prob_data;
    for (size_t source_idx = 0; source_idx < source_num; ++source_idx) {
      auto& all_sentence_ids = result_id.at(source_idx);
      auto& all_sentence_probs = result_prob.at(source_idx);
      for (size_t sentence_idx = 0; sentence_idx < all_sentence_ids.size();
           ++sentence_idx) {
        auto& sentence_ids = all_sentence_ids.at(sentence_idx);
        id_data.insert(id_data.end(), sentence_ids.begin(),
                       sentence_ids.end());
        auto& sentence_probs = all_sentence_probs.at(sentence_idx);
        prob_data.insert(prob_data.end(), sentence_probs.begin(),
                         sentence_probs.end());
        sentence_level_lod.push_back(sentence_level_lod.back() +
                                     sentence_ids.size());
      }
      source_level_lod.push_back(source_level_lod.back() +
                                 all_sentence_ids.size());
    }

    paddle::platform::CPUPlace cpu_place;
    paddle::platform::CPUDeviceContext cpu_ctx(cpu_place);

    framework::LoD lod;
    lod.push_back(source_level_lod);
    lod.push_back(sentence_level_lod);

    id_tensor->set_lod(lod);
    id_tensor->Resize({static_cast<int64_t>(id_data.size())});
    id_tensor->mutable_data<int64_t>(cpu_place);
    id_tensor->CopyFromVector<int64_t>(id_data, cpu_ctx);

    prob_tensor->set_lod(lod);
    prob_tensor->Resize({static_cast<int64_t>(prob_data.size())});
    prob_tensor->mutable_data<float>(cpu_place);
    prob_tensor->CopyFromVector<float>(prob_data, cpu_ctx);
  }

  void PackAllSteps(const std::vector<LoDTensor>& step_ids,
                    const std::vector<LoDTensor>& step_probs,
                    LoDTensor* id_tensor, LoDTensor* prob_tensor) {
    PADDLE_ENFORCE_EQ(step_ids.size(), step_probs.size(),
                      "step_ids and step_probs should have the same size");
    size_t step_num = step_ids.size();
    size_t source_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;

    std::unordered_map<size_t, std::vector<BeamNode*>> batch_beam_nodes;
    InitFirstStepBeamNodes(step_ids.at(0), step_probs.at(0),
                           &batch_beam_nodes);

    // pack all steps for one source first, then the next source
    for (size_t source_idx = 0; source_idx < source_num; ++source_idx) {
      for (size_t step_id = 1; step_id < step_num; ++step_id) {
        auto prefixes = batch_beam_nodes.at(source_idx);
        if (prefixes.size() > 0UL) {
          std::vector<BeamNode*> result =
              PackTwoBeamStepOut(source_idx, prefixes, step_ids.at(step_id),
                                 step_probs.at(step_id));
          batch_beam_nodes[source_idx] = result;
        } else {
          VLOG(3) << "source_idx: " << source_idx << " step_id: " << step_id
                  << " has no more candidates";
        }
      }

      // append the remaining beam_nodes to the result
      for (auto* beam_node : batch_beam_nodes.at(source_idx)) {
        AppendBeamNodeToResult(source_idx, beam_node);
        RemoveFromEnd(beam_node);
      }
    }

    ConvertMapToLodTensor(result_id, result_prob, id_tensor, prob_tensor);
  }

 public:
  std::unordered_map<size_t, std::vector<std::vector<int64_t>>> result_id;
  std::unordered_map<size_t, std::vector<std::vector<float>>> result_prob;
};

}  // namespace operators
}  // namespace paddle
[Review comment] This operator does not need a header file at all. You can put all the code into the .cc file.
[Reply] Using a header file makes testing easier: because this op is a little complex, it needs to be tested carefully in C++.
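A hypothetical unit-test sketch (not part of this PR) illustrating the author's point: with the helpers exposed in the header, the packing logic can be exercised directly from C++. The test body only uses members that appear in the diff above:

#include <vector>

#include "gtest/gtest.h"
#include "paddle/operators/trieconcat_op.h"

TEST(TrieConcatOp, AppendBeamNodeToResult) {
  using paddle::operators::BeamNode;

  paddle::operators::BeamHelper helper;
  // Build a tiny prefix chain: root (word 1) -> leaf (word 2).
  auto* root = new BeamNode(1, 0.5f);
  auto* leaf = new BeamNode(2, 0.4f);
  leaf->AppendTo(root);

  // Walking up from the leaf should recover the sequence {1, 2}.
  helper.AppendBeamNodeToResult(/*source_idx=*/0, leaf);
  EXPECT_EQ(helper.result_id.at(0).front(), (std::vector<int64_t>{1, 2}));

  // RemoveFromEnd frees the leaf and, since root then has no kids, root too.
  helper.RemoveFromEnd(leaf);
}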