[RELAY][PASS] Make Fusor handle split gracefully #2890
The problem has something to do with tuple handling, in particular with respect to split, which we do not handle too well so far. I agree that this should be handled more gracefully. |
Thanks @tqchen. For reference, something like this:

```python
from tvm import relay

def test_dense_split():
    def before():
        rnn_dim = 10
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function(relay.ir_pass.free_vars(out), out)

    def expected():
        rnn_dim = 10
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        outf = relay.Function([X, W], out)
        # The fused function is called from fresh variables.
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
        y = relay.Call(outf, [X, W])
        return relay.Function([X, W], y)

    z = before()
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected())
    print(zz.astext())
    print(after.astext())
    assert relay.ir_pass.alpha_equal(zz, after)
```

Expected IR:

Actual IR:

This is a test case that I believe should pass once this improvement is implemented. |
For reference, concat was previously not handled properly and was fixed in PR #2187; we likely need a similar change for split. |
OK, this is my favorite topic. I can work on this. |
We might also want to confirm whether the current low-level TOPI layer already allows generating the corresponding fused kernel. |
I don't expect we can fuse dense with split in our implementation. |
The split can at least be fused into the later stages of element-wise ops. To fuse split into dense, we might want to think a bit more deeply about changing some of the fusion rules; in the meantime, we should work to enable that at the TOPI level. |
@tqchen I'm puzzled by the usage of TupleGetItemNode. In src/relay/pass/fuse_ops.cc, we have a visitor which takes a TupleGetItemNode and continues visiting the supposedly inner tuple node.
But it looks like the visitor never visits the tuple node after visiting the TupleGetItemNode above (this function is never called). In fact, op->tuple in the above snippet seems to be a CallNode, the output of the split op in this line. Is this behavior expected? I'm wondering why we are applying TupleGetItem to a CallNode, and how it is supposed to work without creating a tuple value explicitly. |
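For illustration, a minimal sketch of what is being observed here (assuming the relay.ir_pass-era Python API used in the test above; attribute names may differ across TVM versions): relay.split returns a TupleWrapper around a single Call whose checked type is a TupleType, and indexing the wrapper builds TupleGetItem nodes directly on that Call, so no explicit Tuple node is ever constructed.

```python
from tvm import relay

x = relay.var("x", shape=(1, 30))
splitted = relay.split(x, indices_or_sections=3, axis=1)  # a TupleWrapper

item = splitted[0]             # sugar for relay.TupleGetItem(call, 0)
print(type(item))              # TupleGetItem
print(type(item.tuple_value))  # Call (the split op), not Tuple
```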
FWIW I realized this also introduces some interesting new ideas for scheduling. Normally in our existing frameworks we do things like batching the entire gate computation into a single X * (k * D) GEMM (with k = 3 for GRU, 4 for LSTM, etc.). The problem is that (once you can fuse with split), computing the gate values means reading D-strided locations in the output, so in practice you end up fully realizing that intermediate computation. It may be more natural to instead reorder the single GEMM into vector-width-blocked variants: i.e. instead of a (kD, D) matrix you'd reorder it into a ((D // V) x k x V, D) layout and then realize it at (k, V)-sized locations. Most existing frameworks I know of don't take advantage of this, but conceptually it has the potential to be quite useful. |
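For illustration, a hedged NumPy sketch of the reordering described above (k, D, V, and all variable names are illustrative, not from the issue): the rows of the gate-major (kD, D) weight matrix are interleaved so that each contiguous block of k*V output rows holds V lanes of every gate.

```python
import numpy as np

k, D, V = 3, 10, 2           # gates, hidden dim, vector width (assumes D % V == 0)
W = np.random.randn(k * D, D).astype("float32")  # gate-major layout: gate g in rows [g*D, (g+1)*D)

# (k*D, D) -> (k, D//V, V, D) -> (D//V, k, V, D) -> ((D//V)*k*V, D)
W_blocked = (W.reshape(k, D // V, V, D)
              .transpose(1, 0, 2, 3)
              .reshape((D // V) * k * V, D))

x = np.random.randn(D).astype("float32")
y = W @ x                    # gate-major result: gate g at y[g*D:(g+1)*D]
y_blocked = W_blocked @ x    # block b holds V lanes of all k gates, contiguously

# Check: block b of y_blocked equals the V-lane slice of each gate at offset b*V.
b = 1
ref = np.concatenate([y[g * D + b * V : g * D + (b + 1) * V] for g in range(k)])
assert np.allclose(y_blocked[b * k * V : (b + 1) * k * V], ref)
```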
@masahi in the case of TupleGetItem, the Call will return a value with TupleType. In our current case, it updates the tuple with opaque, which cuts off the graph (this was the old behavior). We might be able to generalize this in the same way we support concat. |
@masahi thanks for taking point on this 👍 |
Hi folks,
A common pattern in LSTM/GRU-style cells is a structure like (for simplicity):
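A minimal sketch of that pattern, mirroring the test case given in the comments above (shapes and names are illustrative):

```python
from tvm import relay

rnn_dim = 10
X = relay.var("X", shape=(1, rnn_dim))            # input activations
W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))  # all gate weights batched together

matmul = relay.nn.dense(X, W)                     # one GEMM for every gate
gates = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(gates[0]) + relay.tanh(gates[1]) * relay.exp(gates[2])
```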
Normally when implementing this in Relay, we'd expect that graph_fuse would fuse this entire sequence (matmul + split + sigmoid/tanh/exp/add/mul) into a single function, as that's an entirely reasonable expectation and generates the highest-performance code. That is, we expect:
Instead, Relay generates something like:
Of course, it would be possible to implement a "GateComputation" op or similar that is internally just (split + pointwise functions), but it would be quite elegant to avoid that if possible.
I'm not fluent in the Relay GraphFuser code, but I was hoping someone (@jroesch?) knows off the top of their head what needs to be modified inside the fuser, and I or someone else can do the implementation work.
cc @jroesch, @yidawang, @tqchen