Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jun 4, 2024
1 parent c8add67 commit e6c606b
Showing 1 changed file with 78 additions and 3 deletions.
81 changes: 78 additions & 3 deletions paddle/phi/kernels/fusion/cpu/fusion_lstm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,32 @@ void BatchCompute(const Context &dev_ctx,
INIT_BASE_DEFINES;
if (x->lod()[0].size() == 2) {
xx->Resize({x_dims[0], D4});
SeqCompute(ctx);
SeqCompute(dev_ctx,
x_in,
weight_x_in,
weight_h_in,
bias_in,
h0_in,
c0_in,
use_peepholes,
is_reverse,
use_seq,
gate_activation,
cell_activation,
candidate_activation,
scale_data,
shift_data,
scale_weights,
force_fp32_output,
hidden,
cell,
xx,
batched_input,
batched_hidden,
batched_cell,
reordered_h0,
reordered_c0,
checked_cell);
return;
}
INIT_OTHER_DEFINES;
Expand Down Expand Up @@ -352,9 +377,59 @@ void FusionLSTMKernel(const Context &dev_ctx,
DenseTensor *reordered_c0,
DenseTensor *checked_cell) const override {
if (use_seq) {
SeqCompute<T, Context>(dev_ctx);
SeqCompute<T, Context>(dev_ctx,
x_in,
weight_x_in,
weight_h_in,
bias_in,
h0_in,
c0_in,
use_peepholes,
is_reverse,
use_seq,
gate_activation,
cell_activation,
candidate_activation,
scale_data,
shift_data,
scale_weights,
force_fp32_output,
hidden,
cell,
xx,
batched_input,
batched_hidden,
batched_cell,
reordered_h0,
reordered_c0,
checked_cell);
} else {
BatchCompute<T, Context>(dev_ctx);
BatchCompute<T, Context>(dev_ctx,
x_in,
weight_x_in,
weight_h_in,
bias_in,
h0_in,
c0_in,
use_peepholes,
is_reverse,
use_seq,
gate_activation,
cell_activation,
candidate_activation,
scale_data,
shift_data,
scale_weights,
force_fp32_output,
hidden,
cell,
xx,
batched_input,
batched_hidden,
batched_cell,
reordered_h0,
reordered_c0,
checked_cell);
}
}

Expand Down

0 comments on commit e6c606b

Please sign in to comment.