Skip to content

Commit

Permalink
Move at::chunk into the graph fuser (pytorch#10178)
Browse files Browse the repository at this point in the history
Summary:
... to avoid slow at::chunk (it is slow due to tensor initialization). Picking up from pytorch#10026

This is done through the following:

1) Absorb starting chunks into FusionGroup as a part of the graph fuser
pass.
2) When compiling a kernel, emit a `std::vector<ConcatDesc>` that describes if an input (of the original graph) will be chunked.
3) When launching a kernel, `use std::vector<ConcatDesc>` to chunk an
input tensor on the CPU. This chunk directly takes in an at::Tensor and creates
four TensorInfo structs in-place in the argument list, bypassing the creation of intermediate Tensors.

- Expect test and correctness test to see if a single chunk is fused
  by the graph fuser
- Correctness test for a variety of chunks (dimension = beginning,
  middle, end) and tensors (contiguous, non-contiguous, edge case
  (splitSize = 1) for both CPU/CUDA
- Expect test for multiple chunks fused into the same kernel and
  correctness test.

cc zdevito apaszke

LSTM forward pass, 1 layer, 512 hidden size and input size, 100 seq length, requires_grad=False on all inputs and weights.

After changes:
```
thnn    cudnn   jit
8.8468  6.5797  9.3470
```

Before changes:
```
thnn    cudnn   jit
9.9221  6.6539  11.2550
```
Pull Request resolved: pytorch#10178

Differential Revision: D9382661

Pulled By: zou3519

fbshipit-source-id: 1f8a749208fbdd45559775ce98cf4eb9558448f8
  • Loading branch information
zou3519 authored and facebook-github-bot committed Aug 18, 2018
1 parent d87b4e9 commit f1420ad
Show file tree
Hide file tree
Showing 12 changed files with 517 additions and 184 deletions.
20 changes: 8 additions & 12 deletions test/expect/TestJit.test_fusion_distribute.expect
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
graph(%0 : Float(4, 4)
%1 : Float(4, 4)) {
%2 : int = prim::Constant[value=1]()
%3 : int = prim::Constant[value=2]()
%4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk(%0, %3, %2)
%6 : Float(4!, 2), %7 : Float(4!, 2) = aten::chunk(%1, %3, %2)
%8 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %6, %5, %7)
return (%8);
%2 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%3 : Float(4!, 2)
%4 : Float(4!, 2)
%7 : Float(4!, 2)
%8 : Float(4!, 2)) {
with prim::FusionGroup_0 = graph(%11 : Float(4, 4)
%14 : Float(4, 4)) {
%15 : Dynamic, %16 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%14)
%12 : Dynamic, %13 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%11)
%9 : int = prim::Constant[value=1]()
%10 : Float(4, 2) = aten::add(%7, %8, %9)
%10 : Float(4, 2) = aten::add(%13, %16, %9)
%5 : int = prim::Constant[value=1]()
%6 : Float(4, 2) = aten::add(%3, %4, %5)
%6 : Float(4, 2) = aten::add(%12, %15, %5)
%2 : Float(4, 2) = aten::mul(%6, %10)
return (%2);
}
27 changes: 10 additions & 17 deletions test/expect/TestJit.test_lstm_fusion_concat.expect
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,22 @@ graph(%0 : Float(3, 10)
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(20!, 80!) = aten::t(%4)
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
%12 : int = prim::Constant[value=4]()
%13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8)
%17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8)
%21 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17)
return (%21);
%12 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
return (%12);
}
with prim::FusionGroup_0 = graph(%15 : Float(3, 20)
%25 : Float(3!, 20)
%26 : Float(3!, 20)
%29 : Float(3!, 20)
%30 : Float(3!, 20)
%33 : Float(3!, 20)
%34 : Float(3!, 20)
%37 : Float(3!, 20)
%38 : Float(3!, 20)) {
%41 : Float(3, 80)
%46 : Float(3, 80)) {
%47 : Dynamic, %48 : Dynamic, %49 : Dynamic, %50 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%46)
%42 : Dynamic, %43 : Dynamic, %44 : Dynamic, %45 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%41)
%39 : int = prim::Constant[value=1]()
%40 : Float(3, 20) = aten::add(%37, %38, %39)
%40 : Float(3, 20) = aten::add(%42, %47, %39)
%35 : int = prim::Constant[value=1]()
%36 : Float(3, 20) = aten::add(%33, %34, %35)
%36 : Float(3, 20) = aten::add(%43, %48, %35)
%31 : int = prim::Constant[value=1]()
%32 : Float(3, 20) = aten::add(%29, %30, %31)
%32 : Float(3, 20) = aten::add(%44, %49, %31)
%27 : int = prim::Constant[value=1]()
%28 : Float(3, 20) = aten::add(%25, %26, %27)
%28 : Float(3, 20) = aten::add(%45, %50, %27)
%24 : Float(3, 20) = aten::sigmoid(%40)
%22 : Float(3, 20) = aten::sigmoid(%36)
%20 : Float(3, 20) = aten::tanh(%32)
Expand Down
27 changes: 10 additions & 17 deletions test/expect/TestJit.test_lstm_fusion_cuda.expect
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,22 @@ graph(%0 : Float(3, 10)
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(20!, 80!) = aten::t(%4)
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
%12 : int = prim::Constant[value=4]()
%13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8)
%17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8)
%21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17)
return (%21, %22);
%12 : Float(3, 20), %13 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
return (%12, %13);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%23 : Float(3!, 20)
%24 : Float(3!, 20)
%27 : Float(3!, 20)
%28 : Float(3!, 20)
%31 : Float(3!, 20)
%32 : Float(3!, 20)
%35 : Float(3!, 20)
%36 : Float(3!, 20)) {
%39 : Float(3, 80)
%44 : Float(3, 80)) {
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%35, %36, %37)
%38 : Float(3, 20) = aten::add(%40, %45, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%31, %32, %33)
%34 : Float(3, 20) = aten::add(%41, %46, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%27, %28, %29)
%30 : Float(3, 20) = aten::add(%42, %47, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%23, %24, %25)
%26 : Float(3, 20) = aten::add(%43, %48, %25)
%22 : Float(3, 20) = aten::sigmoid(%38)
%20 : Float(3, 20) = aten::sigmoid(%34)
%18 : Float(3, 20) = aten::tanh(%30)
Expand Down
11 changes: 11 additions & 0 deletions test/expect/TestScript.test_chunk_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
graph(%x : Float(10, 6)) {
%1 : Float(10, 2) = prim::FusionGroup_0[device=0](%x)
return (%1);
}
with prim::FusionGroup_0 = graph(%7 : Float(10, 6)) {
%8 : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%7)
%6 : Float(10, 2) = aten::mul(%8, %9)
%2 : int = prim::Constant[value=1]()
%3 : Float(10, 2) = aten::add(%6, %10, %2)
return (%3);
}
30 changes: 30 additions & 0 deletions test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
graph(%s : Float(5, 2, 3)
%x : Float(5, 6, 3)
%y : Float(10, 2, 3)
%z : Float(5, 2, 6)) {
%4 : Float(5, 2, 3) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
return (%4);
}
with prim::FusionGroup_0 = graph(%24 : Float(5, 2, 3)
%28 : Float(10, 2, 3)
%31 : Float(5, 6, 3)
%35 : Float(5, 2, 6)) {
%36 : Dynamic, %37 : Dynamic = prim::FusedChunk[chunks=2, dim=2](%35)
%32 : Dynamic, %33 : Dynamic, %34 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%31)
%29 : Dynamic, %30 : Dynamic = prim::FusedChunk[chunks=2, dim=0](%28)
%26 : int = prim::Constant[value=1]()
%27 : Float(5, 2, 3) = aten::add(%24, %32, %26)
%22 : int = prim::Constant[value=1]()
%23 : Float(5, 2, 3) = aten::add(%27, %33, %22)
%18 : int = prim::Constant[value=1]()
%19 : Float(5, 2, 3) = aten::add(%23, %34, %18)
%14 : int = prim::Constant[value=1]()
%15 : Float(5, 2, 3) = aten::add(%19, %29, %14)
%10 : int = prim::Constant[value=1]()
%11 : Float(5, 2, 3) = aten::add(%15, %30, %10)
%6 : int = prim::Constant[value=1]()
%7 : Float(5, 2, 3) = aten::add(%11, %36, %6)
%2 : int = prim::Constant[value=1]()
%3 : Float(5, 2, 3) = aten::add(%7, %37, %2)
return (%3);
}
55 changes: 22 additions & 33 deletions test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,38 @@ graph(%x.1 : Float(3, 10)
%7 : Float(10!, 80!) = aten::t(%w_ih)
%8 : Float(20!, 80!) = aten::t(%w_hh)
%9 : Float(3, 80) = aten::mm(%hx.1, %8)
%10 : int = prim::Constant[value=1]()
%11 : float = prim::Constant[value=1]()
%12 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %11, %11)
%13 : int[] = prim::Constant[value=[3, 80]]()
%14 : int = prim::Constant[value=0]()
%15 : Float(3!, 80) = aten::expand(%b_ih, %13, %14)
%16 : Float(3!, 80) = aten::expand(%b_hh, %13, %14)
%17 : int = prim::Constant[value=4]()
%18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20), %21 : Float(3!, 20) = aten::chunk(%12, %17, %10)
%22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20), %25 : Float(3!, 20) = aten::chunk(%15, %17, %10)
%26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20), %29 : Float(3!, 20) = aten::chunk(%16, %17, %10)
%hy : Float(3, 20), %31 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %29, %28, %27, %26, %21, %25, %20, %24, %19, %23, %18, %22)
return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %31);
%10 : float = prim::Constant[value=1]()
%11 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %10, %10)
%12 : int[] = prim::Constant[value=[3, 80]]()
%13 : int = prim::Constant[value=0]()
%14 : Float(3!, 80) = aten::expand(%b_ih, %12, %13)
%15 : Float(3!, 80) = aten::expand(%b_hh, %12, %13)
%hy : Float(3, 20), %17 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %15, %11, %14)
return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %17);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%24 : Float(3!, 20)
%28 : Float(3!, 20)
%32 : Float(3!, 20)
%36 : Float(3!, 20)
%39 : Float(3!, 20)
%40 : Float(3!, 20)
%43 : Float(3!, 20)
%44 : Float(3!, 20)
%47 : Float(3!, 20)
%48 : Float(3!, 20)
%51 : Float(3!, 20)
%52 : Float(3!, 20)) {
%55 : Float(3!, 80)
%60 : Float(3, 80)
%65 : Float(3!, 80)) {
%66 : Dynamic, %67 : Dynamic, %68 : Dynamic, %69 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%65)
%61 : Dynamic, %62 : Dynamic, %63 : Dynamic, %64 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%60)
%56 : Dynamic, %57 : Dynamic, %58 : Dynamic, %59 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%55)
%53 : int = prim::Constant[value=1]()
%54 : Float(3, 20) = aten::add(%51, %52, %53)
%54 : Float(3, 20) = aten::add(%61, %66, %53)
%49 : int = prim::Constant[value=1]()
%50 : Float(3, 20) = aten::add(%47, %48, %49)
%50 : Float(3, 20) = aten::add(%62, %67, %49)
%45 : int = prim::Constant[value=1]()
%46 : Float(3, 20) = aten::add(%43, %44, %45)
%46 : Float(3, 20) = aten::add(%63, %68, %45)
%41 : int = prim::Constant[value=1]()
%42 : Float(3, 20) = aten::add(%39, %40, %41)
%42 : Float(3, 20) = aten::add(%64, %69, %41)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%54, %36, %37)
%38 : Float(3, 20) = aten::add(%54, %56, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%50, %32, %33)
%34 : Float(3, 20) = aten::add(%50, %57, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%46, %28, %29)
%30 : Float(3, 20) = aten::add(%46, %58, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%42, %24, %25)
%26 : Float(3, 20) = aten::add(%42, %59, %25)
%ingate.2 : Float(3, 20) = aten::sigmoid(%38)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
%cellgate.2 : Float(3, 20) = aten::tanh(%30)
Expand Down
88 changes: 32 additions & 56 deletions test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -15,66 +15,42 @@ graph(%x.1 : Float(3, 10)
%14 : int[] = prim::Constant[value=[3, 80]]()
%15 : int = prim::Constant[value=0]()
%16 : Float(3!, 80) = aten::expand(%beta_i.1, %14, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(3!, 80) = aten::expand(%beta_h.1, %14, %15)
%19 : Float(3!, 80) = aten::expand(%bias, %14, %15)
%20 : int = prim::Constant[value=4]()
%21 : Float(3!, 20), %22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20) = aten::chunk(%13, %20, %17)
%25 : Float(3!, 20), %26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20) = aten::chunk(%Uz.1, %20, %17)
%29 : Float(3!, 20), %30 : Float(3!, 20), %31 : Float(3!, 20), %32 : Float(3!, 20) = aten::chunk(%16, %20, %17)
%33 : Float(3!, 20), %34 : Float(3!, 20), %35 : Float(3!, 20), %36 : Float(3!, 20) = aten::chunk(%Wx.1, %20, %17)
%37 : Float(3!, 20), %38 : Float(3!, 20), %39 : Float(3!, 20), %40 : Float(3!, 20) = aten::chunk(%18, %20, %17)
%41 : Float(3!, 20), %42 : Float(3!, 20), %43 : Float(3!, 20), %44 : Float(3!, 20) = aten::chunk(%19, %20, %17)
%hy : Float(3, 20), %46 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %44, %43, %42, %41, %40, %28, %39, %27, %37, %25, %38, %26, %31, %35, %30, %34, %22, %26, %29, %33, %21, %25, %23, %27, %32, %36, %24, %28)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %46);
%17 : Float(3!, 80) = aten::expand(%beta_h.1, %14, %15)
%18 : Float(3!, 80) = aten::expand(%bias, %14, %15)
%hy : Float(3, 20), %20 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %Wx.1, %18, %17, %Uz.1, %16, %13)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %20);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%24 : Float(3!, 20)
%28 : Float(3!, 20)
%32 : Float(3!, 20)
%36 : Float(3!, 20)
%59 : Float(3!, 20)
%60 : Float(3!, 20)
%66 : Float(3!, 20)
%67 : Float(3!, 20)
%69 : Float(3!, 20)
%70 : Float(3!, 20)
%76 : Float(3!, 20)
%77 : Float(3!, 20)
%83 : Float(3!, 20)
%84 : Float(3!, 20)
%86 : Float(3!, 20)
%87 : Float(3!, 20)
%89 : Float(3!, 20)
%90 : Float(3!, 20)
%92 : Float(3!, 20)
%93 : Float(3!, 20)
%95 : Float(3!, 20)
%96 : Float(3!, 20)
%98 : Float(3!, 20)
%99 : Float(3!, 20)
%101 : Float(3!, 20)
%102 : Float(3!, 20)
%104 : Float(3!, 20)
%105 : Float(3!, 20)) {
%106 : Float(3, 20) = aten::mul(%104, %105)
%103 : Float(3, 20) = aten::mul(%101, %102)
%100 : Float(3, 20) = aten::mul(%98, %99)
%97 : Float(3, 20) = aten::mul(%95, %96)
%94 : Float(3, 20) = aten::mul(%92, %93)
%91 : Float(3, 20) = aten::mul(%89, %90)
%88 : Float(3, 20) = aten::mul(%86, %87)
%85 : Float(3, 20) = aten::mul(%83, %84)
%107 : Float(3, 80)
%112 : Float(3!, 80)
%117 : Float(3!, 80)
%122 : Float(3, 80)
%127 : Float(3!, 80)
%132 : Float(3, 80)) {
%133 : Dynamic, %134 : Dynamic, %135 : Dynamic, %136 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%132)
%128 : Dynamic, %129 : Dynamic, %130 : Dynamic, %131 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%127)
%123 : Dynamic, %124 : Dynamic, %125 : Dynamic, %126 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%122)
%118 : Dynamic, %119 : Dynamic, %120 : Dynamic, %121 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%117)
%113 : Dynamic, %114 : Dynamic, %115 : Dynamic, %116 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%112)
%108 : Dynamic, %109 : Dynamic, %110 : Dynamic, %111 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%107)
%106 : Float(3, 20) = aten::mul(%136, %126)
%103 : Float(3, 20) = aten::mul(%131, %111)
%100 : Float(3, 20) = aten::mul(%135, %125)
%97 : Float(3, 20) = aten::mul(%133, %123)
%94 : Float(3, 20) = aten::mul(%128, %108)
%91 : Float(3, 20) = aten::mul(%134, %124)
%88 : Float(3, 20) = aten::mul(%129, %109)
%85 : Float(3, 20) = aten::mul(%130, %110)
%81 : int = prim::Constant[value=1]()
%82 : Float(3, 20) = aten::add(%91, %88, %81)
%78 : Float(3, 20) = aten::mul(%76, %77)
%78 : Float(3, 20) = aten::mul(%119, %124)
%74 : int = prim::Constant[value=1]()
%75 : Float(3, 20) = aten::add(%97, %94, %74)
%71 : Float(3, 20) = aten::mul(%69, %70)
%68 : Float(3, 20) = aten::mul(%66, %67)
%71 : Float(3, 20) = aten::mul(%118, %123)
%68 : Float(3, 20) = aten::mul(%120, %125)
%64 : int = prim::Constant[value=1]()
%65 : Float(3, 20) = aten::add(%100, %85, %64)
%61 : Float(3, 20) = aten::mul(%59, %60)
%61 : Float(3, 20) = aten::mul(%121, %126)
%57 : int = prim::Constant[value=1]()
%58 : Float(3, 20) = aten::add(%106, %103, %57)
%53 : int = prim::Constant[value=1]()
Expand All @@ -86,13 +62,13 @@ with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%41 : int = prim::Constant[value=1]()
%42 : Float(3, 20) = aten::add(%58, %61, %41)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%54, %36, %37)
%38 : Float(3, 20) = aten::add(%54, %113, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%50, %32, %33)
%34 : Float(3, 20) = aten::add(%50, %114, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%46, %28, %29)
%30 : Float(3, 20) = aten::add(%46, %115, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%42, %24, %25)
%26 : Float(3, 20) = aten::add(%42, %116, %25)
%ingate.2 : Float(3, 20) = aten::sigmoid(%38)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
%cellgate.2 : Float(3, 20) = aten::tanh(%30)
Expand Down
Loading

0 comments on commit f1420ad

Please sign in to comment.