Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
onnx bi-transformer (#385)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch/fairseq#385

Pull Request resolved: #6

Pull Request resolved: pytorch/pytorch#14292

Differential Revision: D10517864

fbshipit-source-id: d491b91703461baae69c8c9a1d52d9bcfda75528
  • Loading branch information
Haoran Li authored and facebook-github-bot committed Nov 27, 2018
1 parent f10fa86 commit 34bfef2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
9 changes: 9 additions & 0 deletions pytext/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ def get_model_params_for_optimizer(
dense_grads_params[name] = param

return sparse_grads_params, dense_grads_params

def prepare_for_onnx_export_(self, **kwargs):
"""Make model exportable via ONNX trace."""

def apply_prepare_for_onnx_export_(module):
if module != self and hasattr(module, "prepare_for_onnx_export_"):
module.prepare_for_onnx_export_(**kwargs)

self.apply(apply_prepare_for_onnx_export_)
28 changes: 17 additions & 11 deletions pytext/utils/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,28 @@ def create_vocab_index(vocab_list, net, net_workspace, index_name):
return vocab_index


def create_vocab_indices_map(c2_prepared, init_net, vocab_map):
vocab_indices = {}
for feat_name, vocab in vocab_map.items():
assert len(vocab) > 1
vocab_indices[feat_name] = create_vocab_index(
# Skip index 0 as it is reserved for unkwon tokens
# in Caffe2's index implementation
np.array(vocab[1:], dtype=str),
init_net,
c2_prepared.workspace,
feat_name + "_index",
)
return vocab_indices


def add_feats_numericalize_ops(c2_prepared, vocab_map, input_names):
predict_net = c2_prepared.predict_net # Protobuf of the predict_net
init_net = core.Net(c2_prepared.init_net)
final_input_names = input_names.copy()
with c2_prepared.workspace._ctx:
vocab_indices = {}
for feat_name, vocab in vocab_map.items():
assert len(vocab) > 1
vocab_indices[feat_name] = create_vocab_index(
# Skip index 0 as it is reserved for unkwon tokens
# in Caffe2's index implementation
np.array(vocab[1:], dtype=str),
init_net,
c2_prepared.workspace,
feat_name + "_index",
)
vocab_indices = create_vocab_indices_map(c2_prepared, init_net, vocab_map)

# Add operators to convert string features to ids based on the vocab
final_predict_net = core.Net(c2_prepared.predict_net.name + "_processed")
final_inputs = set(
Expand Down

0 comments on commit 34bfef2

Please sign in to comment.