diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index bdc8a3838b..34dbdf44dc 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -58,7 +58,9 @@ log "Decode with ONNX models"
     --jit-filename $repo/exp/cpu_jit.pt \
     --onnx-encoder-filename $repo/exp/encoder.onnx \
     --onnx-decoder-filename $repo/exp/decoder.onnx \
-    --onnx-joiner-filename $repo/exp/joiner.onnx
+    --onnx-joiner-filename $repo/exp/joiner.onnx \
+    --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
+    --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
 
   ./pruned_transducer_stateless3/onnx_check_all_in_one.py \
     --jit-filename $repo/exp/cpu_jit.pt \
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index a4687f35d4..11f24244ed 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -62,12 +62,15 @@
     --avg 10 \
     --onnx 1
 
-It will generate the following three files in the given `exp_dir`.
+It will generate the following six files in the given `exp_dir`.
 Check `onnx_check.py` for how to use them.
 
     - encoder.onnx
     - decoder.onnx
     - joiner.onnx
+    - joiner_encoder_proj.onnx
+    - joiner_decoder_proj.onnx
+    - all_in_one.onnx
 
 (4) Export `model.state_dict()`
 
@@ -115,6 +118,7 @@
 import logging
 from pathlib import Path
 
+import onnx_graphsurgeon as gs
 import onnx
 import sentencepiece as spm
 import torch
@@ -218,6 +222,9 @@ def get_parser():
           - encoder.onnx
           - decoder.onnx
          - joiner.onnx
+          - joiner_encoder_proj.onnx
+          - joiner_decoder_proj.onnx
+          - all_in_one.onnx
 
         Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
         """,
@@ -485,14 +492,11 @@ def export_joiner_model_onnx(
         - joiner_out: a tensor of shape (N, vocab_size)
 
-    Note: The argument project_input is fixed to True. A user should not
-    project the encoder_out/decoder_out by himself/herself. The exported joiner
-    will do that for the user.
""" encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32) project_input = True # Note: It uses torch.jit.trace() internally @@ -510,10 +514,63 @@ def export_joiner_model_onnx( "logit": {0: "N"}, }, ) + torch.onnx.export( + joiner_model.encoder_proj, + (encoder_out.squeeze(0).squeeze(0)), + str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"), + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["encoder_proj"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "encoder_proj": {0: "N"}, + }, + ) + torch.onnx.export( + joiner_model.decoder_proj, + (decoder_out.squeeze(0).squeeze(0)), + str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"), + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["decoder_proj"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "decoder_proj": {0: "N"}, + }, + ) logging.info(f"Saved to {joiner_filename}") +def add_variables( + model: nn.Module, combined_model: onnx.ModelProto +) -> onnx.ModelProto: + graph = gs.import_onnx(combined_model) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + node = gs.Node( + op="Identity", + name="constants_lm", + attrs={ + "blank_id": blank_id, + "unk_id": unk_id, + "context_size": context_size, + }, + inputs=[], + outputs=[], + ) + graph.nodes.append(node) + + graph = gs.export_onnx(graph) + return graph + + def export_all_in_one_onnx( + model: nn.Module, encoder_filename: str, decoder_filename: str, joiner_filename: str, @@ -522,10 +579,22 @@ def export_all_in_one_onnx( encoder_onnx = onnx.load(encoder_filename) decoder_onnx = onnx.load(decoder_filename) joiner_onnx = onnx.load(joiner_filename) + joiner_encoder_proj_onnx = onnx.load( + str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + ) + joiner_decoder_proj_onnx = onnx.load( + str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + ) encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/") decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/") joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/") + joiner_encoder_proj_onnx = onnx.compose.add_prefix( + joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/" + ) + joiner_decoder_proj_onnx = onnx.compose.add_prefix( + joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/" + ) combined_model = onnx.compose.merge_models( encoder_onnx, decoder_onnx, io_map={} @@ -533,6 +602,13 @@ def export_all_in_one_onnx( combined_model = onnx.compose.merge_models( combined_model, joiner_onnx, io_map={} ) + combined_model = onnx.compose.merge_models( + combined_model, joiner_encoder_proj_onnx, io_map={} + ) + combined_model = onnx.compose.merge_models( + combined_model, joiner_decoder_proj_onnx, io_map={} + ) + combined_model = add_variables(model, combined_model) onnx.save(combined_model, all_in_one_filename) logging.info(f"Saved to {all_in_one_filename}") @@ -631,6 +707,7 @@ def main(): all_in_one_filename = params.exp_dir / "all_in_one.onnx" export_all_in_one_onnx( + model, encoder_filename, decoder_filename, joiner_filename, diff --git 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
index 3da31b7ceb..a04b408d81 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
@@ -63,6 +63,20 @@ def get_parser():
         help="Path to the onnx joiner model",
     )
 
+    parser.add_argument(
+        "--onnx-joiner-encoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner encoder projection model",
+    )
+
+    parser.add_argument(
+        "--onnx-joiner-decoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner decoder projection model",
+    )
+
     return parser
 
 
@@ -126,17 +140,27 @@ def test_decoder(
 def test_joiner(
     model: torch.jit.ScriptModule,
     joiner_session: ort.InferenceSession,
+    joiner_encoder_proj_session: ort.InferenceSession,
+    joiner_decoder_proj_session: ort.InferenceSession,
 ):
     joiner_inputs = joiner_session.get_inputs()
     assert joiner_inputs[0].name == "encoder_out"
-    assert joiner_inputs[0].shape == ["N", 512]
+    assert joiner_inputs[0].shape == ["N", 1, 1, 512]
     assert joiner_inputs[1].name == "decoder_out"
-    assert joiner_inputs[1].shape == ["N", 512]
+    assert joiner_inputs[1].shape == ["N", 1, 1, 512]
+
+    joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
+    assert joiner_encoder_proj_inputs[0].name == "encoder_out"
+    assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
+
+    joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
+    assert joiner_decoder_proj_inputs[0].name == "decoder_out"
+    assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
 
     for N in [1, 5, 10]:
-        encoder_out = torch.rand(N, 512)
-        decoder_out = torch.rand(N, 512)
+        encoder_out = torch.rand(N, 1, 1, 512)
+        decoder_out = torch.rand(N, 1, 1, 512)
 
         joiner_inputs = {
             "encoder_out": encoder_out.numpy(),
@@ -154,6 +178,44 @@ def test_joiner(
             (joiner_out - torch_joiner_out).abs().max()
         )
 
+        joiner_encoder_proj_inputs = {
+            "encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_encoder_proj_out = joiner_encoder_proj_session.run(
+            ["encoder_proj"], joiner_encoder_proj_inputs
+        )[0]
+        joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
+
+        torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
+            encoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
+
+        joiner_decoder_proj_inputs = {
+            "decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_decoder_proj_out = joiner_decoder_proj_session.run(
+            ["decoder_proj"], joiner_decoder_proj_inputs
+        )[0]
+        joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
+
+        torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
+            decoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
+
 
 @torch.no_grad()
 def main():
@@ -185,7 +247,20 @@ def main():
         args.onnx_joiner_filename,
         sess_options=options,
     )
-    test_joiner(model, joiner_session)
+    joiner_encoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_encoder_proj_filename,
+        sess_options=options,
+    )
+    joiner_decoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_decoder_proj_filename,
+        sess_options=options,
+    )
+    test_joiner(
+        model,
+        joiner_session,
+        joiner_encoder_proj_session,
+        joiner_decoder_proj_session,
+    )
 
     logging.info("Finished checking ONNX models")
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
index ebfae9d5f0..3e4a323aa4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
@@ -27,7 +27,7 @@
 Usage of this script:
 
-./pruned_transducer_stateless3/jit_trace_pretrained.py \
+./pruned_transducer_stateless3/onnx_pretrained.py \
   --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
   --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
   --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
@@ -194,6 +194,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)
 
     offset = 0
     for batch_size in batch_size_list:
@@ -209,11 +210,17 @@ def greedy_search(
         logits = joiner.run(
             [joiner_output_nodes[0].name],
             {
-                joiner_input_nodes[0].name: current_encoder_out.numpy(),
-                joiner_input_nodes[1].name: decoder_out,
+                joiner_input_nodes[0]
+                .name: current_encoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
+                joiner_input_nodes[1]
+                .name: decoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
             },
         )[0]
-        logits = torch.from_numpy(logits)
+        logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
 
         # logits'shape (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
@@ -236,6 +243,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)
 
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
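
Note: after this patch the combined joiner expects unprojected 4-D inputs of
shape (N, 1, 1, 512) and emits logits of shape (N, 1, 1, vocab_size), while
the two standalone projection models keep 2-D (N, 512) inputs, exactly what
the updated `test_joiner` asserts. A minimal onnxruntime sketch under those
assumptions (the file paths are placeholders):

    import numpy as np
    import onnxruntime as ort

    joiner = ort.InferenceSession("exp/joiner.onnx")
    encoder_proj = ort.InferenceSession("exp/joiner_encoder_proj.onnx")

    # The joiner takes 4-D activations and projects them internally.
    encoder_out = np.random.rand(5, 1, 1, 512).astype(np.float32)
    decoder_out = np.random.rand(5, 1, 1, 512).astype(np.float32)
    logits = joiner.run(
        ["logit"],
        {"encoder_out": encoder_out, "decoder_out": decoder_out},
    )[0]
    print(logits.shape)  # (5, 1, 1, vocab_size); squeeze axes 1 and 2

    # The standalone projection runs on the squeezed 2-D form.
    proj = encoder_proj.run(
        ["encoder_proj"],
        {"encoder_out": encoder_out.squeeze(axis=(1, 2))},
    )[0]
    print(proj.shape)  # (5, joiner_dim)
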