Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed May 8, 2024
1 parent 11de092 commit 664a021
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
58 changes: 58 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,64 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
out->share_lod(*x.at(0));
}

void ChunkEvalInferMeta(const MetaTensor& inference,
                        const MetaTensor& label,
                        const MetaTensor& seq_length,
                        int num_chunk_types,
                        const std::string& chunk_scheme,
                        const std::vector<int>& excluded_chunk_types,
                        MetaTensor* precision,
                        MetaTensor* recall,
                        MetaTensor* f1_score,
                        MetaTensor* num_infer_chunks,
                        MetaTensor* num_label_chunks,
                        MetaTensor* num_correct_chunks) {
  // Infer output metas for chunk_eval: validates that predictions and labels
  // agree in shape, then fixes every output as a scalar-shaped tensor.
  const auto& infer_dims = inference.dims();
  const auto& label_dims = label.dims();

  // Predictions and labels must be element-for-element comparable.
  PADDLE_ENFORCE_EQ(
      infer_dims,
      label_dims,
      phi::errors::InvalidArgument(
          "Input(Inference)'s shape must be the same as Input(Label)'s "
          "shape, but received [%s] (Inference) vs [%s] (Label).",
          infer_dims,
          label_dims));

  // An initialized SeqLength input means padded (batched) sequences are used.
  const bool has_seq_length = seq_length.initialized();
  if (has_seq_length) {
    // With padding, predictions are (batch_size, bucket) or
    // (batch_size, bucket, 1).
    const bool rank_ok =
        (infer_dims.size() == 3 && infer_dims[2] == 1) ||
        infer_dims.size() == 2;
    PADDLE_ENFORCE_EQ(rank_ok,
                      true,
                      phi::errors::InvalidArgument(
                          "when Input(SeqLength) is provided, Input(Inference) "
                          "should be of dim 3 (batch_size, bucket, 1) or dim 2 "
                          "(batch_size, bucket), but received [%s].",
                          infer_dims));
    const auto& seq_len_dims = seq_length.get().dims();
    PADDLE_ENFORCE_LE(seq_len_dims.size(),
                      2,
                      phi::errors::InvalidArgument(
                          "Input(SeqLength)'s rank should not be greater "
                          "than 2, but received %d.",
                          seq_len_dims.size()));
  }

  // All six outputs are single-element tensors: three float metrics and
  // three int64 chunk counters.
  precision->set_dims({1});
  precision->set_dtype(phi::DataType::FLOAT32);
  recall->set_dims({1});
  recall->set_dtype(phi::DataType::FLOAT32);
  f1_score->set_dims({1});
  f1_score->set_dtype(phi::DataType::FLOAT32);

  num_infer_chunks->set_dims({1});
  num_infer_chunks->set_dtype(phi::DataType::INT64);
  num_label_chunks->set_dims({1});
  num_label_chunks->set_dtype(phi::DataType::INT64);
  num_correct_chunks->set_dims({1});
  num_correct_chunks->set_dtype(phi::DataType::INT64);
}

void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

// Infers output metas for the chunk_eval op: checks Inference/Label shape
// agreement (and SeqLength rank when provided), then sets all six outputs to
// single-element tensors (FLOAT32 metrics, INT64 chunk counts).
// NOTE(review): num_chunk_types, chunk_scheme and excluded_chunk_types are
// runtime attributes of the kernel; shape inference does not consume them.
void ChunkEvalInferMeta(const MetaTensor& inference,
                        const MetaTensor& label,
                        const MetaTensor& seq_length,
                        int num_chunk_types,
                        const std::string& chunk_scheme,
                        const std::vector<int>& excluded_chunk_types,
                        MetaTensor* precision,
                        MetaTensor* recall,
                        MetaTensor* f1_score,
                        MetaTensor* num_infer_chunks,
                        MetaTensor* num_label_chunks,
                        MetaTensor* num_correct_chunks);

void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
Expand Down

0 comments on commit 664a021

Please sign in to comment.