Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exporting projection layers of joiner separately for onnx #584

Merged
merged 8 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ log "Decode with ONNX models"
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx

./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
Expand Down
89 changes: 83 additions & 6 deletions egs/librispeech/ASR/pruned_transducer_stateless3/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@
--avg 10 \
--onnx 1

It will generate the following three files in the given `exp_dir`.
It will generate the following six files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.

- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- all_in_one.onnx
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change three on line 65 to four or just remove it.



(4) Export `model.state_dict()`
Expand Down Expand Up @@ -115,6 +118,7 @@
import logging
from pathlib import Path

import onnx_graphsurgeon as gs
import onnx
import sentencepiece as spm
import torch
Expand Down Expand Up @@ -218,6 +222,9 @@ def get_parser():
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- all_in_one.onnx

Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
Expand Down Expand Up @@ -485,14 +492,11 @@ def export_joiner_model_onnx(

- joiner_out: a tensor of shape (N, vocab_size)

Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)

project_input = True
# Note: It uses torch.jit.trace() internally
Expand All @@ -510,10 +514,63 @@ def export_joiner_model_onnx(
"logit": {0: "N"},
},
)
torch.onnx.export(
joiner_model.encoder_proj,
(encoder_out.squeeze(0).squeeze(0)),
str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["encoder_proj"],
dynamic_axes={
"encoder_out": {0: "N"},
"encoder_proj": {0: "N"},
},
)
torch.onnx.export(
joiner_model.decoder_proj,
(decoder_out.squeeze(0).squeeze(0)),
str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["decoder_proj"],
dynamic_axes={
"decoder_out": {0: "N"},
"decoder_proj": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")


def add_variables(
    model: nn.Module, combined_model: onnx.ModelProto
) -> onnx.ModelProto:
    """Embed decoder metadata into the combined ONNX model.

    The blank ID, unknown-token ID, and decoder context size are attached
    as attributes of a no-op ``Identity`` node named ``constants_lm`` so
    that downstream consumers of the all-in-one ONNX file can recover them
    without access to the original PyTorch checkpoint.

    Args:
      model:
        The transducer model whose decoder supplies the constants.
      combined_model:
        The merged ONNX model to annotate.

    Returns:
      A new ``onnx.ModelProto`` containing the extra metadata node.
    """
    gs_graph = gs.import_onnx(combined_model)

    blank = model.decoder.blank_id
    # Fall back to blank_id when the model does not define unk_id.
    unk = getattr(model, "unk_id", blank)
    ctx = model.decoder.context_size

    meta_node = gs.Node(
        op="Identity",
        name="constants_lm",
        attrs={"blank_id": blank, "unk_id": unk, "context_size": ctx},
        inputs=[],
        outputs=[],
    )
    gs_graph.nodes.append(meta_node)

    return gs.export_onnx(gs_graph)


def export_all_in_one_onnx(
model: nn.Module,
encoder_filename: str,
decoder_filename: str,
joiner_filename: str,
Expand All @@ -522,17 +579,36 @@ def export_all_in_one_onnx(
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
joiner_encoder_proj_onnx = onnx.load(
str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
)
joiner_decoder_proj_onnx = onnx.load(
str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
)

encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
joiner_encoder_proj_onnx = onnx.compose.add_prefix(
joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/"
)
joiner_decoder_proj_onnx = onnx.compose.add_prefix(
joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/"
)

combined_model = onnx.compose.merge_models(
encoder_onnx, decoder_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_encoder_proj_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_decoder_proj_onnx, io_map={}
)
combined_model = add_variables(model, combined_model)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")

Expand Down Expand Up @@ -631,6 +707,7 @@ def main():

all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
model,
encoder_filename,
decoder_filename,
joiner_filename,
Expand Down
85 changes: 80 additions & 5 deletions egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ def get_parser():
help="Path to the onnx joiner model",
)

parser.add_argument(
"--onnx-joiner-encoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner encoder projection model",
)

parser.add_argument(
"--onnx-joiner-decoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner decoder projection model",
)

return parser


Expand Down Expand Up @@ -126,17 +140,27 @@ def test_decoder(
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
joiner_encoder_proj_session: ort.InferenceSession,
joiner_decoder_proj_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].name == "encoder_out"
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[0].shape == ["N", 1, 1, 512]

assert joiner_inputs[1].name == "decoder_out"
assert joiner_inputs[1].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 1, 1, 512]

joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
assert joiner_encoder_proj_inputs[0].name == "encoder_out"
assert joiner_encoder_proj_inputs[0].shape == ["N", 512]

joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
assert joiner_decoder_proj_inputs[0].name == "decoder_out"
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]

for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
encoder_out = torch.rand(N, 1, 1, 512)
decoder_out = torch.rand(N, 1, 1, 512)

joiner_inputs = {
"encoder_out": encoder_out.numpy(),
Expand All @@ -154,6 +178,44 @@ def test_joiner(
(joiner_out - torch_joiner_out).abs().max()
)

joiner_encoder_proj_inputs = {
"encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
}
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
["encoder_proj"], joiner_encoder_proj_inputs
)[0]
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)

torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
encoder_out.squeeze(1).squeeze(1)
)
assert torch.allclose(
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
), (
(joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
.abs()
.max()
)

joiner_decoder_proj_inputs = {
"decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
}
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
["decoder_proj"], joiner_decoder_proj_inputs
)[0]
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)

torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
decoder_out.squeeze(1).squeeze(1)
)
assert torch.allclose(
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
), (
(joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
.abs()
.max()
)


@torch.no_grad()
def main():
Expand Down Expand Up @@ -185,7 +247,20 @@ def main():
args.onnx_joiner_filename,
sess_options=options,
)
test_joiner(model, joiner_session)
joiner_encoder_proj_session = ort.InferenceSession(
args.onnx_joiner_encoder_proj_filename,
sess_options=options,
)
joiner_decoder_proj_session = ort.InferenceSession(
args.onnx_joiner_decoder_proj_filename,
sess_options=options,
)
test_joiner(
model,
joiner_session,
joiner_encoder_proj_session,
joiner_decoder_proj_session,
)
logging.info("Finished checking ONNX models")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

Usage of this script:

./pruned_transducer_stateless3/jit_trace_pretrained.py \
./pruned_transducer_stateless3/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
Expand Down Expand Up @@ -194,6 +194,7 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
decoder_out = torch.from_numpy(decoder_out)

offset = 0
for batch_size in batch_size_list:
Expand All @@ -209,11 +210,17 @@ def greedy_search(
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: current_encoder_out.numpy(),
joiner_input_nodes[1].name: decoder_out,
joiner_input_nodes[0]
.name: current_encoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
joiner_input_nodes[1]
.name: decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
},
)[0]
logits = torch.from_numpy(logits)
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits' shape: (batch_size, vocab_size)

assert logits.ndim == 2, logits.shape
Expand All @@ -236,6 +243,7 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
decoder_out = torch.from_numpy(decoder_out)

sorted_ans = [h[context_size:] for h in hyps]
ans = []
Expand Down