
support LSTM #144

Closed
allenling opened this issue Nov 7, 2019 · 24 comments
allenling commented Nov 7, 2019

I use PyTorch to build my LSTM network; it looks like this:

    import torch
    from torch import nn

    class TestNet(torch.nn.Module):
        def __init__(self):
            super(TestNet, self).__init__()
            self.lstm = nn.LSTM(256, 128, 2,
                                batch_first=True, bidirectional=True)

        def forward(self, x):
            self.lstm.flatten_parameters()
            res = self.lstm(x)
            return res

input_tensor = torch.randn(30, 61, 256)

and my LSTM converter:

op = trt.RNNOperation.LSTM
ctx.network.add_rnn_v2(input_tensor._trt, layer_count, hidden_size, max_seq_length, op)

Should max_seq_length be input_tensor.shape[0]?

I got an error:

[TensorRT] ERROR: Parameter check failed at: ../builder/Network.cpp::addRNNCommon::397, condition: input.getDimensions().d[di.seqLen()] == maxSeqLen

Also, how can I set the reverse weights?

Any help would be appreciated.

@allenling (Author)

It seems that max_seq_length should be the sequence-length dimension, which is input_tensor.shape[1] here because I use batch_first=True.

In PyTorch, without batch_first the input shape would be:

input of shape (seq_len, batch, input_size)
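To make the rule concrete, here is a plain-Python sketch (the helper name is mine, not part of torch2trt):

```python
def lstm_seq_dim(input_shape, batch_first):
    """Return the sequence-length dimension of an nn.LSTM input shape.

    PyTorch nn.LSTM expects (batch, seq_len, input_size) when batch_first=True,
    and (seq_len, batch, input_size) otherwise.
    """
    return input_shape[1] if batch_first else input_shape[0]

# For input_tensor of shape (30, 61, 256) with batch_first=True:
print(lstm_seq_dim((30, 61, 256), batch_first=True))   # → 61
```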

@bfortuner

@allenling Were you able to get this working? Would love to hear your approach


allenling commented Nov 13, 2019

import tensorrt as trt
import torch
from torch import nn 
from torch2trt.torch2trt import *


@tensorrt_converter('torch.nn.LSTM.forward')
def convert_lstm(ctx):
    module = ctx.method_args[0]
    input_tensor = ctx.method_args[1]
    output = ctx.method_return[0]
    layer_count = module.num_layers
    hidden_size = module.hidden_size
    max_seq_length = input_tensor.shape[1] if module.batch_first else input_tensor.shape[0]
    op = trt.RNNOperation.LSTM
    layer = ctx.network.add_rnn_v2(input_tensor._trt, layer_count, hidden_size, max_seq_length, op)
    if module.bidirectional is True:
        layer.direction = trt.RNNDirection.BIDIRECTION
    for i in range(layer_count):
        iw = getattr(module, "weight_ih_l%s" % i).detach().cpu().numpy()
        hw = getattr(module, "weight_hh_l%s" % i).detach().cpu().numpy()
        
        rela_index = 2*i if module.bidirectional is True else i

        layer.set_weights_for_gate(rela_index, trt.RNNGateType.INPUT, True, iw[:hidden_size,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.FORGET, True, iw[hidden_size:hidden_size * 2,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.CELL, True, iw[hidden_size * 2: hidden_size * 3,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.OUTPUT, True, iw[hidden_size * 3:hidden_size * 4,:].copy())

        layer.set_weights_for_gate(rela_index, trt.RNNGateType.INPUT, False, hw[:hidden_size,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.FORGET, False, hw[hidden_size:hidden_size * 2,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.CELL, False, hw[hidden_size * 2: hidden_size * 3,:].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.OUTPUT, False, hw[hidden_size * 3:hidden_size * 4,:].copy())

        ib = getattr(module, "bias_ih_l%s" % i).detach().cpu().numpy()
        hb = getattr(module, "bias_hh_l%s" % i).detach().cpu().numpy()
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.INPUT, True, ib[:hidden_size].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.FORGET, True, ib[hidden_size:hidden_size * 2].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.CELL, True, ib[hidden_size * 2: hidden_size * 3].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.OUTPUT, True, ib[hidden_size * 3:hidden_size * 4].copy())

        layer.set_bias_for_gate(rela_index, trt.RNNGateType.INPUT, False, hb[:hidden_size].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.FORGET, False, hb[hidden_size:hidden_size * 2].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.CELL, False, hb[hidden_size * 2: hidden_size * 3].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.OUTPUT, False, hb[hidden_size * 3:hidden_size * 4].copy())

        if module.bidirectional is True:
            # ================reverse=====================
            iw_r = getattr(module, "weight_ih_l%s_reverse" % i).detach().cpu().numpy()
            hw_r = getattr(module, "weight_hh_l%s_reverse" % i).detach().cpu().numpy()
    
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.INPUT, True, iw_r[:hidden_size,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.FORGET, True, iw_r[hidden_size:hidden_size * 2,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.CELL, True, iw_r[hidden_size * 2: hidden_size * 3,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.OUTPUT, True, iw_r[hidden_size * 3:hidden_size * 4,:].copy())
    
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.INPUT, False, hw_r[:hidden_size,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.FORGET, False, hw_r[hidden_size:hidden_size * 2,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.CELL, False, hw_r[hidden_size * 2: hidden_size * 3,:].copy())
            layer.set_weights_for_gate(2*i+1, trt.RNNGateType.OUTPUT, False, hw_r[hidden_size * 3:hidden_size * 4,:].copy())
    
            ib_r = getattr(module, "bias_ih_l%s_reverse" % i).detach().cpu().numpy()
            hb_r = getattr(module, "bias_hh_l%s_reverse" % i).detach().cpu().numpy()
    
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.INPUT, True, ib_r[:hidden_size].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.FORGET, True, ib_r[hidden_size:hidden_size * 2].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.CELL, True, ib_r[hidden_size * 2: hidden_size * 3].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.OUTPUT, True, ib_r[hidden_size * 3:hidden_size * 4].copy())
    
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.INPUT, False, hb_r[:hidden_size].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.FORGET, False, hb_r[hidden_size:hidden_size * 2].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.CELL, False, hb_r[hidden_size * 2: hidden_size * 3].copy())
            layer.set_bias_for_gate(2*i+1, trt.RNNGateType.OUTPUT, False, hb_r[hidden_size * 3:hidden_size * 4].copy())
    lstm_output = layer.get_output(0)
    output._trt = lstm_output
    return


def main():
    class TestNet(torch.nn.Module):
        def __init__(self):
            super(TestNet, self).__init__()
            self.lstm = nn.LSTM(256,
                                256 // 2, 2,
                                batch_first=True, bidirectional=True)
            return
        def forward(self, x):
            self.lstm.flatten_parameters()
            res = self.lstm(x)
            return res
    fp16 = True
    print("fp16" if fp16 else "fp32")
    net = TestNet()
    if fp16:
        net.half()
    net.cuda()
    x = torch.randn(30, 61, 256)
    if fp16:
        x = x.to(torch.float16)
    x = x.to("cuda:0")
    inputs = [x.clone()]
    trt_net = torch2trt(net, inputs, fp16_mode=fp16, max_workspace_size = 1<<25)
    res = net(x)
    return
    
if __name__ == "__main__":
    main()
    
    

Not the best solution, but it works for me.

It only handles the first output (the per-step output tensor, not the hidden/cell states).

Hope this helps.
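For reference, the slicing in the converter above follows PyTorch's packed parameter layout: weight_ih_l{k} and bias_ih_l{k} stack the four gates as [input | forget | cell | output] along dim 0, hidden_size rows each, while TensorRT wants them one gate at a time. A plain-Python sketch of the split (the helper name is mine):

```python
def split_gates(packed, hidden_size):
    """Split a packed PyTorch LSTM parameter (4 * hidden_size rows, stacked
    as [input | forget | cell | output]) into per-gate chunks, the order the
    set_weights_for_gate / set_bias_for_gate calls above expect."""
    names = ("input", "forget", "cell", "output")
    return {name: packed[i * hidden_size:(i + 1) * hidden_size]
            for i, name in enumerate(names)}

# A toy bias vector for hidden_size=2: rows 0-1 input gate, ..., rows 6-7 output gate.
gates = split_gates(list(range(8)), hidden_size=2)
print(gates["forget"])   # → [2, 3]
```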
@bfortuner

@bfortuner

Thanks! Will give it a try

@bfortuner

Hey! Which versions of PyTorch and TensorRT are you using? I'm getting some warnings and an error, which I think means it failed to convert something:

Warning: Encountered known unsupported method torch.Tensor.__hash__ (×2)
Warning: Encountered known unsupported method torch.Tensor.type
Warning: Encountered known unsupported method torch.Tensor.data_ptr (×16)
Warning: Encountered known unsupported method torch.Tensor.get_device
Warning: Encountered known unsupported method torch.is_grad_enabled (×2)

  File "/home/brendan.fortuner/workplace/fusion_seq/.venvpy3/lib/python3.6/site-packages/torch2trt/torch2trt.py", line 281, in mark_outputs
    trt_tensor = torch_output._trt
AttributeError: 'Tensor' object has no attribute '_trt'

I'm on Python 3.6, PyTorch 1.3, CUDA 10, Ubuntu 18, TensorRT 6.0.1, and the latest commit from this repo.

I'm able to convert all the image models in the repo (alexnet, resnet, etc.).

@bfortuner

A small update to forward fixed it for me:

From this

    def forward(self, x):
        self.lstm.flatten_parameters()
        res = self.lstm(x)
        return res

To this:

    def forward(self, x):
        self.lstm.flatten_parameters()
        out, (h0, c0) = self.lstm(x)
        return out
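The reason this helps: nn.LSTM returns a tuple (output, (h_n, c_n)), and torch2trt needs a _trt attribute on every marked output; the hidden/cell state tensors here don't get one, so returning just output avoids the '_trt' AttributeError in mark_outputs. A torch-free shape sketch (the helper name is mine):

```python
def lstm_output_shapes(batch, seq_len, hidden_size, num_layers,
                       bidirectional, batch_first=True):
    """Shapes of the tuple nn.LSTM returns: (output, (h_n, c_n))."""
    d = 2 if bidirectional else 1
    output = ((batch, seq_len, d * hidden_size) if batch_first
              else (seq_len, batch, d * hidden_size))
    h_n = (num_layers * d, batch, hidden_size)
    c_n = (num_layers * d, batch, hidden_size)
    return output, (h_n, c_n)

# TestNet above: nn.LSTM(256, 128, 2, batch_first=True, bidirectional=True)
out, (h_n, c_n) = lstm_output_shapes(30, 61, 128, 2, True)
print(out)   # → (30, 61, 256)
```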


xieydd commented Nov 23, 2019

When I use the LSTM converter above, I hit a problem:

[TensorRT] ERROR: (Unnamed Layer* 17) [Shuffle]: uninferred dimensions are not an exact divisor of input dimensions, so inferred dimension cannot be calculated python: ./RNNUtils.h:122: nvinfer1::RNNStats::DimIndices::DimIndices(nvinfer1::RNNDimOrder, int): Assertion `nbInputDims >= 2' failed.

This is my code:

class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        #ks = [3, 3, 3, 3, 3, 3, 2]
        #ps = [1, 1, 1, 1, 1, 1, 0]
        #ss = [1, 1, 1, 1, 1, 1, 1]
        ks = [3, 3, 3, 3, 3]
        ps = [1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1]

        #nm = [64, 128, 256, 256, 512, 512, 512]
        nm = [64, 128, 128, 256, 256]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        # convRelu(1)
        # cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(1, True)
        convRelu(2)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((3, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(3, True)
        # convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((3, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(4, True)  # 512x1x16

        self.cnn = cnn
        #self.rnn = BidirectionalLSTM(256, nh, nclass)
    self.rnn = nn.LSTM(256, nh, 2, bidirectional=True, batch_first=True)
    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        #assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        print(conv.size())
        # conv = conv.permute(2, 0, 1)  # [w, b, c]
        conv = conv.permute(0, 2, 1)  # [w, b, c]

        print(conv.size())
        # rnn features
        output = self.rnn(conv)

        return output

def main():
    fp16 = False
    print("fp16" if fp16 else "fp32")
    net = CRNN(32, 1, 37, 128)
    x = torch.randn(1, 1, 32, 100)
    print(net)
    if fp16:
        net.half()
    net.cuda()
    if fp16:
        x = x.to(torch.float16)
    x = x.to("cuda:0")
    inputs = [x.clone()]
    trt_net = torch2trt(net, inputs, fp16_mode=fp16,
                        max_workspace_size=1 << 25)
    res = net(x)
    print(res)
    return
if __name__ == '__main__':
    main()

Can you help me, @allenling @bfortuner?

@allenling (Author)

You lost a dim after the permute:

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        #assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        print(conv.size())
        # conv = conv.permute(2, 0, 1)  # [w, b, c]
        # ============>here, we lost dim!
        conv = conv.permute(0, 2, 1)  # [w, b, c]
        # ==========>check dims here, conv._trt.shape is (0)
        print(conv.size(), conv._trt.shape)
        # rnn features
        output = self.rnn(conv)

        return output

It seems to be caused by the unsupported method squeeze.
I got this warning:

Warning: Encountered known unsupported method torch.Tensor.squeeze

@xieydd


xieydd commented Nov 25, 2019

Yes, this warning also appears in my test.
@allenling Do you mean I need to implement and register a squeeze converter for TensorRT?

When I test the CNN or the RNN alone, each converts successfully, but when I run the CNN and use its output as the LSTM input, it errors.


xieydd commented Nov 25, 2019

@allenling Do you know a way to replace squeeze with an op TensorRT supports?


xieydd commented Nov 25, 2019

@allenling I notice the TensorRT ONNX parser supports the squeeze op: supported op list


xieydd commented Nov 25, 2019

@allenling I changed squeeze to view, and it works:

def forward(self, input):
    # conv features
    cnn = self.cnn(input)
    print(cnn.size(), cnn._trt.shape)
    b, c, h, w = cnn.size()
    conv = cnn.view(b, c, w)
    print(conv.size(), conv.dtype)
    conv = conv.permute(0, 2, 1)  # [b, w, c]
    self.rnn.flatten_parameters()
    rnn, (h0, c0) = self.rnn(conv)
    print(rnn.size(), rnn._trt.shape)  # passed, torch.Size([1, 50, 256]) (50, 256)

    b, T, h = rnn.size()
    t_rec = rnn.view(T * b, h)  # error: [TensorRT] ERROR: (Unnamed Layer* 19) [Shuffle]: uninferred dimensions are not an exact divisor of input dimensions, so inferred dimension cannot be calculated.  torch.Size([50, 256]) torch.float32 (0)
    print(t_rec.size(), t_rec.dtype, t_rec._trt.shape)

    e = self.embedding(t_rec)  # [T * b, nOut]
    output = e.view(T, b, -1)

    return output
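The Shuffle error message comes from reshape inference: any unknown dimension must divide the element count exactly. A plain-Python sketch of the rule (the function name is mine):

```python
def infer_reshape(numel, dims):
    """Resolve a single -1 entry in a reshape, the way torch.Tensor.view and
    TensorRT's shuffle layer do: the known dims must divide numel exactly."""
    known = 1
    for d in dims:
        if d != -1:
            known *= d
    if numel % known != 0:
        raise ValueError("uninferred dimension is not an exact divisor "
                         "of input dimensions")
    return tuple(numel // known if d == -1 else d for d in dims)

print(infer_reshape(50 * 256, (50, -1)))   # → (50, 256)
# In the CRNN above, the divisor check fails on the TensorRT side because
# torch2trt treats dim 0 as the batch and drops it, as discussed later in
# this thread, so the element counts no longer match.
```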


xieydd commented Nov 26, 2019

@allenling But I still have a problem; you can see the error above.

@allenling (Author)

If you want to change the shape of rnn, which is (1, 50, 256), to (1, 256, 50), you should do:

t_rec = rnn.view(b, h, T)  # which is rnn.view(1, 256, 50)

If you omit b, the converter will cut h, and output._trt's shape will be (T,), not (h, T).

def convert_view(ctx):
    input = ctx.method_args[0]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return
    layer = ctx.network.add_shuffle(input_trt)
    # ======> here, if we call rnn.view(h, T), then output.shape is (256, 50)
    # ======> and layer.reshape_dims is output.shape[1:], that is (50,)
    # ======> so we lose the 256
    # ======> we have to include the batch size so the converter does not cut another dim
    layer.reshape_dims = tuple(output.shape[1:])
    output._trt = layer.get_output(0)

So, be careful about the batch size: the converter cuts dim 0 (shape[0]), treating it as the batch.
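As a plain-Python sketch of that cut (the helper name is mine; the behavior mirrors the reshape_dims = output.shape[1:] line above):

```python
def trt_reshape_dims(view_shape):
    """What torch2trt's view converter hands to TensorRT: dim 0 is assumed
    to be the implicit batch and is dropped (reshape_dims = shape[1:])."""
    return tuple(view_shape[1:])

print(trt_reshape_dims((256, 50)))      # → (50,)  the 256 is eaten as "batch"
print(trt_reshape_dims((1, 256, 50)))   # → (256, 50)  keep a leading 1 instead
```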

@xieydd


xieydd commented Nov 26, 2019

@allenling Many thanks.
I changed t_rec = rnn.view(T * b, h) to t_rec = rnn.view(1, T * b, h), and it works.
I got it. Sadly, the result of the linear layer errors again:

rec = rnn.view(1, b*T, h)  # rec._trt.shape is (50, 256)
e = self.embedding(rec)  # error: e.shape is torch.Size([1, 50, 37]) but e._trt's shape is 0; very confusing, since the Linear op is supported
output = e.view(1, T, b, -1)  # I also kept the batch size

Update:
I find the Linear input should be 2-dimensional.
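If the Linear converter only handles 2-D activations, the usual bookkeeping (as in the original CRNN code) is flatten, project, restore. A plain-Python shape sketch (the helper name is mine):

```python
def linear_shape_plan(T, b, h, n_out):
    """Flatten (T, b, h) to 2-D for a Linear layer, then restore the time axis."""
    flat = (T * b, h)           # t_rec = rnn.view(T * b, h)
    projected = (T * b, n_out)  # e = self.embedding(t_rec)
    restored = (T, b, n_out)    # output = e.view(T, b, -1)
    return flat, projected, restored

print(linear_shape_plan(50, 1, 256, 37))   # → ((50, 256), (50, 37), (50, 1, 37))
```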


xieydd commented Nov 26, 2019

@allenling If I only want to get a 2-D tensor from view, how can I do that?

@allenling (Author)

Maybe make sure the converter can reshape correctly by hacking the converter?


xieydd commented Nov 26, 2019

@allenling Can I add you on WeChat? Something weird is going on.


xieydd commented Dec 3, 2019

@allenling Is there a way to change the LSTM input shape from (1, 50, 256) to (50, 1, 256) without using batch_first=False?


PKQ1688 commented Dec 4, 2019

I encountered a similar problem, is there any better solution now?


bfortuner commented Dec 4, 2019

If this is your first time converting an LSTM, I found it helpful to use the raw TensorRT API instead to make sure you understand what's happening underneath the surface. The torch2trt repo is good for basic stuff, but for anything non-trivial, you'll have to use the TensorRT API directly. I would start there.

https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/index.html

It will take more time initially, but it will save you tons of time in the future.

@iAlexKai

Guys, what if I want to pass both input_tensor and the initial hidden state h_0 into a GRU or LSTM?
I found that ctx.network.add_rnn_v2 only accepts input_tensor, not h_0.
How can I solve this problem?

@iAlexKai

I found the answer!
After defining the layer with:

layer = ctx.network.add_rnn_v2(input_tensor._trt, layer_count, hidden_size, max_seq_length, op)

set the hidden state directly:

layer.hidden_state = init_state_tensor._trt

reference: https://www.programmersought.com/article/22907561559/

@jaybdub jaybdub closed this as completed Jul 18, 2022

TAOSHss commented Nov 30, 2022

(quoting @bfortuner's forward fix and warning report above)

Warning: Encountered known unsupported method torch.Tensor.data_ptr
Warning: Encountered known unsupported method torch.Tensor.data_ptr
Warning: Encountered known unsupported method torch.Tensor.get_device
Warning: Encountered known unsupported method torch.is_grad_enabled
Warning: Encountered known unsupported method torch.is_grad_enabled
Warning: Encountered known unsupported method torch.is_grad_enabled

Have you solved these warnings?

I also encountered these warnings. Although the model was exported, inference failed.
