diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 036b8f50d6..787487fc37 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -2,6 +2,7 @@ #include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/freeze_module.h" #include "torch/csrc/jit/passes/lower_graph.h" +#include "torch/csrc/jit/passes/lower_tuples.h" #include "torch/csrc/jit/passes/quantization.h" #include "torch/csrc/jit/passes/guard_elimination.h" @@ -23,6 +24,7 @@ void LowerGraph(std::shared_ptr& g) { torch::jit::EliminateRedundantGuards(g); passes::EliminateExceptionOrPassPattern(g); torch::jit::FuseLinear(g); + torch::jit::LowerAllTuples(g); passes::RemoveDropout(g); passes::FuseFlattenLinear(g); passes::Conv2DToConvolution(g);