[RELAY][PASS] Make Fusor to handle split gracefully #2890

Closed · ajtulloch opened this issue Mar 25, 2019 · 13 comments

ajtulloch (Contributor) commented Mar 25, 2019

Hi folks,

A common pattern in LSTM/GRU-style cells is a structure like (for simplicity):

from tvm import relay

rnn_dim = 10
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])

Normally when implementing this in Relay, we'd expect that graph_fuse would fuse this entire sequence (matmul + split + sigmoid/tanh/exp/add/mul) into a single function, as that's an entirely reasonable expectation and generates the highest-performance code. That is, we expect:

fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = nn.dense(%X, %y, units=None)
  %1 = split(%0, indices_or_sections=int64(3), axis=1)
  %2 = %1.0
  %3 = sigmoid(%2)
  %4 = %1.1
  %5 = tanh(%4)
  %6 = %1.2
  %7 = exp(%6)
  %8 = multiply(%5, %7)
  %9 = add(%3, %8)
  %9
}

Instead, Relay generates something like:

fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = fn(%p0: Tensor[(1, 10), float32],
          %p1: Tensor[(30, 10), float32])
          -> Tensor[(1, 30), float32] {
    %1 = nn.dense(%p0, %p1, units=None) # ty=Tensor[(1, 30), float32]
    %1
  }
  %2 = %0(%X, %y) # ty=Tensor[(1, 30), float32]
  %3 = fn(%p01: Tensor[(1, 30), float32])
          -> Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]] {
    %4 = split(%p01, indices_or_sections=int64(3), axis=1) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
    %4
  }
  %5 = %3(%2) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
  %6 = %5.0
  %7 = %5.1
  %8 = %5.2
  %9 = fn(%p02: Tensor[(1, 10), float32],
          %p11: Tensor[(1, 10), float32],
          %p2: Tensor[(1, 10), float32])
          -> Tensor[(1, 10), float32] {
    %10 = sigmoid(%p02) # ty=Tensor[(1, 10), float32]
    %11 = tanh(%p11) # ty=Tensor[(1, 10), float32]
    %12 = exp(%p2) # ty=Tensor[(1, 10), float32]
    %13 = multiply(%11, %12) # ty=Tensor[(1, 10), float32]
    %14 = add(%10, %13) # ty=Tensor[(1, 10), float32]
    %14
  }
  %15 = %9(%6, %7, %8) # ty=Tensor[(1, 10), float32]
  %15
}

Of course it would be possible to implement a "GateComputation" op or similar that is internally just (split + pointwise functions), but it would be quite elegant to avoid that if possible.
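For concreteness, this is the computation such a hypothetical "GateComputation" op would encapsulate (a sketch only; an actual workaround would register this as a single Relay operator backed by one TOPI compute, since a plain Python helper like this just constructs the same graph):

from tvm import relay

def gate_computation(gates):
    # the split + pointwise structure a fused "GateComputation" op would hide
    s = relay.split(gates, indices_or_sections=3, axis=1)
    return relay.sigmoid(s[0]) + relay.tanh(s[1]) * relay.exp(s[2])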

I'm not fluent in the Relay GraphFuser code, but I was hoping someone (@jroesch?) knows off the top of their head what needs to be modified inside the fuser, so that I or someone else can do the implementation work.

cc @jroesch, @yidawang, @tqchen

tqchen (Member) commented Mar 25, 2019

The problem has to do with tuple handling, in particular with respect to split, which we do not handle too well so far. I agree that this should be handled more gracefully.

tqchen changed the title from "Generalize fusion rules for RNN gate cases (Dense + Split + Pointwise)?" to "[RELAY][PASS] Make Fusor Handle Split Gracefully" Mar 25, 2019
tqchen (Member) commented Mar 25, 2019

cc @masahi @kazum who might be interested in this as well.

ajtulloch (Contributor, Author) commented Mar 25, 2019

Thanks @tqchen. For reference, something like this:

from tvm import relay


def test_dense_split():
    def before():
        rnn_dim = 10
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function(relay.ir_pass.free_vars(out), out)

    def expected():
        rnn_dim = 10
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        outf = relay.Function([X, W], out)

        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
        y = relay.Call(outf, [X, W])
        return relay.Function([X, W], y)

    z = before()
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected())
    print(zz.astext())
    print(after.astext())
    assert relay.ir_pass.alpha_equal(zz, after)

Expected IR:


fn (%X: Tensor[(1, 10), float32],
    %W: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = fn(%X1: Tensor[(1, 10), float32],
          %W1: Tensor[(30, 10), float32])
          -> Tensor[(1, 10), float32] {
    %1 = nn.dense(%X1, %W1, units=None) # ty=Tensor[(1, 30), float32]
    %2 = split(%1, indices_or_sections=int64(3), axis=1) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
    %3 = %2.0
    %4 = sigmoid(%3) # ty=Tensor[(1, 10), float32]
    %5 = %2.1
    %6 = tanh(%5) # ty=Tensor[(1, 10), float32]
    %7 = %2.2
    %8 = exp(%7) # ty=Tensor[(1, 10), float32]
    %9 = multiply(%6, %8) # ty=Tensor[(1, 10), float32]
    %10 = add(%4, %9) # ty=Tensor[(1, 10), float32]
    %10
  }
  %11 = %0(%X, %W) # ty=Tensor[(1, 10), float32]
  %11
}

Actual IR:

fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = fn(%p0: Tensor[(1, 10), float32],
          %p1: Tensor[(30, 10), float32])
          -> Tensor[(1, 30), float32] {
    %1 = nn.dense(%p0, %p1, units=None) # ty=Tensor[(1, 30), float32]
    %1
  }
  %2 = %0(%X, %y) # ty=Tensor[(1, 30), float32]
  %3 = fn(%p01: Tensor[(1, 30), float32])
          -> Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]] {
    %4 = split(%p01, indices_or_sections=int64(3), axis=1) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
    %4
  }
  %5 = %3(%2) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
  %6 = %5.0
  %7 = %5.1
  %8 = %5.2
  %9 = fn(%p02: Tensor[(1, 10), float32],
          %p11: Tensor[(1, 10), float32],
          %p2: Tensor[(1, 10), float32])
          -> Tensor[(1, 10), float32] {
    %10 = sigmoid(%p02) # ty=Tensor[(1, 10), float32]
    %11 = tanh(%p11) # ty=Tensor[(1, 10), float32]
    %12 = exp(%p2) # ty=Tensor[(1, 10), float32]
    %13 = multiply(%11, %12) # ty=Tensor[(1, 10), float32]
    %14 = add(%10, %13) # ty=Tensor[(1, 10), float32]
    %14
  }
  %15 = %9(%6, %7, %8) # ty=Tensor[(1, 10), float32]
  %15
}

is a test case that I believe should pass if this improvement is implemented.

tqchen (Member) commented Mar 25, 2019

For reference: concat was previously not properly handled and was fixed by PR #2187; likely we need a similar thing for split.

tqchen changed the title from "[RELAY][PASS] Make Fusor Handle Split Gracefully" to "[RELAY][PASS] Make Fusor to handle split gracefully" Mar 25, 2019
masahi (Member) commented Mar 25, 2019

Ok, this is my favorite topic. I can work on this.

masahi self-assigned this Mar 25, 2019
tqchen (Member) commented Mar 25, 2019

We might also want to confirm if the current low-level topi layer already allows the generation of the corresponding fused kernel.

masahi (Member) commented Mar 25, 2019

I don't expect we can fuse dense with split in our implementation.

tqchen (Member) commented Mar 25, 2019

The split can at least be fused into the later stages of element-wise ops. To fuse split into dense, we might want to think a bit deeper about changing some of the fusion rules; in the meantime, we should work to enable that at the TOPI level.
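For the running example, the subgraph below is the part that should at least become a single fused function even if dense stays separate (a sketch reusing the ops from the issue):

from tvm import relay

# split followed only by elementwise ops: this region alone should fuse
x = relay.var("x", shape=(1, 30))
s = relay.split(x, indices_or_sections=3, axis=1)
out = relay.sigmoid(s[0]) + relay.tanh(s[1]) * relay.exp(s[2])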

masahi (Member) commented Mar 25, 2019

@tqchen I'm puzzled by the usage of TupleGetItemNode. In src/relay/pass/fuse_ops.cc, we have a visitor which takes a TupleGetItemNode and continues visiting the supposedly inner tuple node.

void VisitExpr_(const TupleGetItemNode* op) final {
  CHECK(graph_.node_map.count(op));
  Node* node = graph_.node_map.at(op);
  // Marks the edge to the tuple being indexed as opaque,
  // which cuts the fusion graph at this point.
  this->Update(op->tuple, node, kOpaque);
  ExprVisitor::VisitExpr_(op);
  this->AddNode(op);
}

But it looks like the visitor does not visit the tuple node after visiting the TupleGetItemNode above (this function is never called). In fact, op->tuple in the above snippet seems to be a CallNode: the output of the split op.

Is this behavior expected? I'm wondering why we are applying TupleGetItem on a CallNode and how it is supposed to work without creating a tuple value explicitly.

ajtulloch (Contributor, Author) commented Mar 25, 2019

FWIW I realized this also introduces some interesting new ideas for scheduling. Normally in our existing frameworks we do stuff like batching the entire gate computation into a single GEMM against the stacked (kD, D) weight matrix (k = 3 for GRU, 4 for LSTM, etc.).

This has the problem (once you can fuse with split) that computing the gate values means reading at D-strided locations in the output, so in practice you end up fully realizing that intermediate computation. It would probably be more natural to instead reorder the single GEMM into vector-width-blocked variants (i.e. instead of a (kD, D) matrix you'd reorder it into a ((D // V) x k x V, D) one) and then realize it at (k, V)-sized locations. Most existing frameworks I know of don't take advantage of this, but conceptually it has the potential to be quite useful.
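As a rough NumPy sketch of the reordering described above (the names and the vector width V are purely illustrative; assumes D % V == 0):

import numpy as np

k, D, V = 3, 10, 2  # gates, hidden dim, vector width
W = np.random.randn(k * D, D).astype("float32")  # standard (kD, D) gate weights

# Reorder rows so each contiguous group of k * V rows holds one V-wide slice
# of every gate: (kD, D) -> (k, D//V, V, D) -> (D//V, k, V, D) -> (kD, D).
W_blocked = (W.reshape(k, D // V, V, D)
              .transpose(1, 0, 2, 3)
              .reshape(k * D, D))

# Column block b of (X @ W_blocked.T) now holds the b-th V-wide slice of all
# k gates contiguously, so gate values can be realized in (k, V)-sized tiles.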

tqchen (Member) commented Mar 26, 2019

@masahi in the case of TupleGetItem, the Call will return a value with TupleType. In our current case, it updates the tuple with opaque, which cuts off the graph (this was the old behavior). We might be able to generalize this in the way we support concat.
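For concreteness, a minimal sketch of the behavior masahi observed (using the Relay Python API of the time: relay.split returns a TupleWrapper whose indexing produces TupleGetItem nodes applied directly to the underlying Call):

from tvm import relay

x = relay.var("x", shape=(1, 30))
s = relay.split(x, indices_or_sections=3, axis=1)  # TupleWrapper around a single Call
call = s.astuple()  # the underlying CallNode; its checked type is a TupleType
piece = s[0]        # TupleGetItem(call, 0) indexes the Call directly;
                    # no explicit Tuple node is ever constructed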

jroesch (Member) commented Mar 26, 2019

@masahi thanks for taking point on this 👍

tqchen (Member) commented Apr 17, 2019

#3039

tqchen closed this as completed Apr 17, 2019