Design of if else op #3828
doc/design/if_else_op.md
In an if_op, only inputs with condition satisfied will be run. The op could have multiple inputs and multiple outputs.
We should have the following design:
```python
Based on #3827, Python bindings of operators and layers should return Vars. So I think that this usage could be
x1 = Var()
x2 = Var()
y1 = Var()
y2 = Var()
cond = Var()
with paddle.block() as left:
    x1 = if_input(x1)
    x2 = if_input(x2)
    y1 = mul(x1, x2)
    y2 = add(x1, x2)
with paddle.block() as right:
    x1 = if_input(x1)
    x2 = if_input(x2)
    y1 = add(x1, x2)
    y2 = mul(x1, x2)
y1, y2 = layer.if(cond, left, right, output=[y1, y2])
import paddle as pd
x1 = Var()
x2 = Var()
cond = Var()
c1 = pd.IfElseOp(inputs=[x1, x2], output_num=2)
with c1.true_block():
    x1, x2 = c1.inputs()
    y1, y2 = c1.outputs()
    y1 = mul(x1, x2)
    y2 = add(x1, x2)
with c1.false_block():
    x1, x2 = c1.inputs()
    y1, y2 = c1.outputs()
    y1 = add(x1, x2)
    y2 = mul(x1, x2)
out1, out2 = c1(cond)
I believe @Superjom's proposal is better than mine because
- @Superjom's approach defines the input and output of the branch block well, and
- it doesn't introduce variable duplication/copying -- the x1 and x2 in the block are exactly the x1 and x2 defined outside.
The example above breaks the convention that Python operator bindings must return Vars.
However, mine has a severe problem: the two branches would overwrite y1 and y2. A correct solution should keep two copies of y1 and y2, one for the left branch and one for the right, and then merge these two copies into the global variables y1 and y2.
A correct solution can be derived easily from @Superjom's solution: rename paddle.IfElseOp to paddle.create_ifelseop_builder.
import paddle as pd
x1 = var()
x2 = var()
c = pd.create_ifelseop_builder(inputs=[x1, x2], output_num=2)
with c.true_block() as b:
    x1, x2 = b.inputs()
    b.set_output(0, mul(x1, x2))
    b.set_output(1, add(x1, x2))
with c.false_block() as b:
    x1, x2 = b.inputs()
    b.set_output(0, add(x1, x2))
    b.set_output(1, mul(x1, x2))
cond = var()
out1, out2 = c(cond)
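To make the "two copies, then merge" idea concrete, here is a toy sketch in plain Python. It is not PaddlePaddle code: `IfElseBuilder`, `_Block`, `mul`, and `add` are hypothetical stand-ins operating on Python lists, and for simplicity both branches are evaluated eagerly on every instance, whereas the real op would presumably run each branch only on the instances it selects.

```python
from contextlib import contextmanager


def mul(a, b):
    """Element-wise product of two equal-length lists (toy operator)."""
    return [x * y for x, y in zip(a, b)]


def add(a, b):
    """Element-wise sum of two equal-length lists (toy operator)."""
    return [x + y for x, y in zip(a, b)]


class _Block:
    """Hypothetical branch block: exposes the inputs and its own output slots."""

    def __init__(self, inputs, slots):
        self._inputs = inputs
        self._slots = slots

    def inputs(self):
        return self._inputs

    def set_output(self, index, value):
        self._slots[index] = value


class IfElseBuilder:
    """Hypothetical stand-in for the proposed create_ifelseop_builder."""

    def __init__(self, inputs, output_num):
        self._inputs = inputs
        self._output_num = output_num
        # Each branch gets its own output slots, so neither can overwrite the other.
        self._slots = {True: [None] * output_num, False: [None] * output_num}

    @contextmanager
    def _block(self, branch):
        yield _Block(self._inputs, self._slots[branch])

    def true_block(self):
        return self._block(True)

    def false_block(self):
        return self._block(False)

    def __call__(self, cond):
        # Merge step: per instance, take the value from the branch cond selects.
        merged = []
        for k in range(self._output_num):
            t, f = self._slots[True][k], self._slots[False][k]
            merged.append([tv if c else fv for c, tv, fv in zip(cond, t, f)])
        return merged


x1 = [1, 2, 3]
x2 = [4, 5, 6]
c = IfElseBuilder(inputs=[x1, x2], output_num=2)
with c.true_block() as b:
    a1, a2 = b.inputs()
    b.set_output(0, mul(a1, a2))
    b.set_output(1, add(a1, a2))
with c.false_block() as b:
    a1, a2 = b.inputs()
    b.set_output(0, add(a1, a2))
    b.set_output(1, mul(a1, a2))
cond = [True, False, True]
out1, out2 = c(cond)
```

Because `set_output` writes into a per-branch slot rather than rebinding a shared Python name, the overwrite problem with y1 and y2 does not arise in this sketch.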
doc/design/if_else_op.md
In an if_op, only inputs with condition satisfied will be run.
We should have the following design:
```python
IfOp should have only one branch. An IfOp operator takes a cond variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponding to a true element in cond.
import paddle as pd
x = var()
y = var()
cond = var()
b = pd.create_ifop_builder(inputs=[x], output_num=1)
with b.true_block():
    x = b.inputs(0)
    z = operator.add(x, y)
    b.set_output(0, operator.softmax(z))
out = b(cond)
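As a rough illustration of the N-in, M-out semantics described above, here is a plain-Python sketch. `if_op` is a hypothetical helper, not a PaddlePaddle function, and a simple addition stands in for the add-then-softmax body.

```python
def if_op(cond, instances, true_fn):
    """Apply true_fn only to instances whose cond entry is true.

    Returns M results, where M is the number of true entries in cond.
    """
    if len(cond) != len(instances):
        raise ValueError("cond must have one boolean per instance")
    return [true_fn(v) for c, v in zip(cond, instances) if c]


x = [1.0, 2.0, 3.0, 4.0]           # N = 4 instances
y = 10.0
cond = [True, False, False, True]  # M = 2 true entries
out = if_op(cond, x, lambda v: v + y)
```

Here `out` holds only the two instances selected by `cond`, in their original order.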
If we want the output to still have N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
import paddle as pd
x = var()
y = var()
cond = var()
default_value = var()
b = pd.create_ifelseop_builder(inputs=[x], output_num=1, default_value=default_value)
with b.true_block():
    x = b.inputs(0)
    z = operator.add(x, y)
    b.set_output(0, operator.softmax(z))
out = b(cond)
If the IfElseOp has multiple return values, the default_value must be a list of variables with corresponding shapes.
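A minimal plain-Python sketch of that default-value behaviour with two return values follows. `if_else_with_default` is a hypothetical helper, not part of PaddlePaddle, and the branch body is an arbitrary stand-in.

```python
def if_else_with_default(cond, instances, true_fn, default_values):
    """Return one length-N list per output.

    true_fn maps one instance to a tuple with one value per output.
    default_values holds one length-N list per output; its entries are
    kept wherever cond is false.
    """
    n = len(cond)
    if len(instances) != n or any(len(d) != n for d in default_values):
        raise ValueError("cond, instances and every default must have N entries")
    outputs = [list(d) for d in default_values]  # copy so defaults stay untouched
    for i, (c, v) in enumerate(zip(cond, instances)):
        if c:
            for k, value in enumerate(true_fn(v)):
                outputs[k][i] = value
    return outputs


x = [1, 2, 3]
cond = [True, False, True]
defaults = [[0, 0, 0], [-1, -1, -1]]
out1, out2 = if_else_with_default(cond, x, lambda v: (v * 2, v + 100), defaults)
```

Every output keeps N entries; rows where `cond` is false simply carry the corresponding default.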
@@ -0,0 +1,59 @@
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`.
Should we also add a static IfElseOp, like TF's, that runs just one branch?
This design describes a DynamicIfElseOp, am I right?
Let's make this IfElseOp. If you are going to add a static conditional branching structure later, we can name it StaticIfElseOp.
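For readers following the naming discussion, the distinction can be sketched in plain Python. Both functions are hypothetical illustrations, not PaddlePaddle or TensorFlow APIs: the static form takes a single boolean and sends the whole minibatch through one branch, while the form designed here takes one boolean per instance.

```python
def static_if_else(cond, batch, true_fn, false_fn):
    """Scalar cond: the whole batch goes through exactly one branch."""
    branch = true_fn if cond else false_fn
    return [branch(v) for v in batch]


def dynamic_if_else(cond, batch, true_fn, false_fn):
    """Vector cond: each instance goes through the branch its own entry selects."""
    return [true_fn(v) if c else false_fn(v) for c, v in zip(cond, batch)]


def double(v):
    return v * 2


def negate(v):
    return -v


batch = [1, 2, 3]
static_out = static_if_else(True, batch, double, negate)
dynamic_out = dynamic_if_else([True, False, True], batch, double, negate)
```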
doc/design/if_else_op.md
y = var()
cond = var()

b = pd.create_ifop_builder(inputs=[x], output_num=1)
create_ifop_builder ==> create_ifop, as suggested by @Superjom
doc/design/if_else_op.md
y = var()
cond = var()
default_value = var()
b = pd.create_ifelseop_builder(inputs=[x], output_num=1)
create_ifelseop_builder => create_ifelseop
LGTM