Skip to content

Commit

Permalink
[CINN] Add constraint for while op (PaddlePaddle#64385)
Browse files Browse the repository at this point in the history
* add constraint for while of

* fix bug
  • Loading branch information
zyfncg authored and co63oc committed May 18, 2024
1 parent 70535a9 commit 1a13d37
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,11 +712,22 @@ bool WhileOp::InferSymbolicShape(
input_arg_shape[j]);
continue;
}
if (original_input_shape.size() == yield_value_shape.size() &&
original_input_shape[j] == yield_value_shape[j]) {
infer_context->AddEqualCstr(original_input_shape[j],
input_arg_shape[j]);
continue;
if (original_input_shape.size() == yield_value_shape.size()) {
if (original_input_shape[j] == yield_value_shape[j]) {
infer_context->AddEqualCstr(original_input_shape[j],
input_arg_shape[j]);
continue;
}
symbol::DimExprBuilder builder;
if (yield_value_shape[j] ==
builder.Broadcast(input_arg_shape[j],
original_input_shape[j]) ||
yield_value_shape[j] == builder.Broadcast(original_input_shape[j],
input_arg_shape[j])) {
infer_context->AddEqualCstr(original_input_shape[j],
input_arg_shape[j]);
continue;
}
}
}
}
Expand Down

0 comments on commit 1a13d37

Please sign in to comment.