Skip to content

Commit

Permalink
[bin/export_gpu] fix streaming export onnx (#2654)
Browse files Browse the repository at this point in the history
* [bin/export_gpu] fix streaming export onnx

* fix
  • Loading branch information
Mddct authored Nov 8, 2024
1 parent 4c2d2f6 commit d00940f
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions wenet/bin/export_onnx_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
from __future__ import print_function

import argparse
import logging
import os
import sys

import torch
import yaml
import logging

import torch.nn.functional as F
import yaml
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
Expand Down Expand Up @@ -169,15 +168,19 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoder.encoders):
xs, _, new_att_cache, new_cnn_cache = layer(
i_kv_cache = att_cache[i]
size = att_cache.size(-1) // 2
kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
xs, _, new_kv_cache, new_cnn_cache = layer(
xs,
masks,
pos_emb,
att_cache=att_cache[i],
att_cache=kv_cache,
cnn_cache=cnn_cache[i],
)
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
new_att_cache = torch.cat(new_kv_cache, dim=-1)
r_att_cache.append(
new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
if not self.transformer:
Expand Down Expand Up @@ -1241,8 +1244,8 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
if args.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import (
convert_float_to_float16, )
from onnxmltools.utils.float16_converter import \
convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
sys.exit(1)
Expand Down

0 comments on commit d00940f

Please sign in to comment.