Support MaskedLM from HuggingFace
Guang Yang committed Oct 26, 2024
1 parent e93ad5f commit 1b48e83
Showing 1 changed file with 95 additions and 30 deletions.
125 changes: 95 additions & 30 deletions extension/export_util/export_hf_model.py
@@ -10,42 +10,62 @@
 import torch
 import torch.export._trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel
 
 
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-hfm",
-        "--hf_model_repo",
-        required=True,
-        default=None,
-        help="a valid huggingface model repo name",
+def _export_masked_lm(args):
+
+    device = "cpu"
+    attn_implementation = "sdpa"
+    max_length = 64
+
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    mask_token = tokenizer.mask_token
+    print(f"Mask token: {mask_token}")
+    inputs = tokenizer(
+        f"The goal of life is {mask_token}.",
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
     )
-    parser.add_argument(
-        "-d",
-        "--dtype",
-        type=str,
-        choices=["float32", "float16", "bfloat16"],
-        default="float32",
-        help="specify the dtype for loading the model",
+
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+        attn_implementation=attn_implementation,
+        use_cache=True,
     )
-    parser.add_argument(
-        "-o",
-        "--output_name",
-        required=False,
-        default=None,
-        help="output name of the exported model",
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    # pre-autograd export. eventually this will become torch.export
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["input_ids"],),
+        kwargs={"attention_mask": inputs["attention_mask"]},
+        strict=True,
     )
 
-    args = parser.parse_args()
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
 
-    # Configs to HF model
+
+def _export_causal_lm(args):
     device = "cpu"
     # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
     dtype = getattr(torch, args.dtype)
@@ -106,11 +126,56 @@ def _get_constant_methods(model: PreTrainedModel):
         .to_backend(XnnpackPartitioner())
         .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
     )
-    out_name = args.output_name if args.output_name else model.config.model_type
-    filename = os.path.join("./", f"{out_name}.pte")
-    with open(filename, "wb") as f:
-        prog.write_to_file(f)
-    print(f"Saved exported program to {filename}")
+
+    return model, prog
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+    parser.add_argument(
+        "-lm",
+        required=True,
+        type=str,
+        choices=["masked_lm", "causal_lm"],
+        help="type of lm to load from huggingface",
+    )
+
+    args = parser.parse_args()
+
+    if args.lm == "masked_lm":
+        model, prog = _export_masked_lm(args)
+    elif args.lm == "causal_lm":
+        model, prog = _export_causal_lm(args)
+    else:
+        raise ValueError(f"Unsupported LM type {args.lm}")
+
+    out_name = args.output_name if args.output_name else model.config.model_type
+    filename = os.path.join("./", f"{out_name}.pte")
+    with open(filename, "wb") as f:
+        prog.write_to_file(f)
+    print(f"Saved exported program to {filename}")
 
 
 if __name__ == "__main__":
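For context, here is a minimal end-to-end sketch of the new masked-LM path: export with the new -lm flag, then run the resulting program. Everything below is illustrative rather than part of the commit: the repo name google-bert/bert-base-uncased, the filename bert.pte (the script derives it from model.config.model_type), the assumption that the ExecuTorch pybindings (portable_lib) are built, and the assumption that output index 0 of the flattened MaskedLM output is the logits tensor.

# Export first (illustrative repo; writes bert.pte, since model_type == "bert"):
#   python extension/export_util/export_hf_model.py -hfm google-bert/bert-base-uncased -lm masked_lm
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# Mirror the export-time preprocessing: pad to the fixed max_length of 64.
inputs = tokenizer(
    f"The goal of life is {tokenizer.mask_token}.",
    return_tensors="pt",
    padding="max_length",
    max_length=64,
)

program = _load_for_executorch("bert.pte")
# The exported program takes positional inputs: (input_ids, attention_mask).
logits = program.forward((inputs["input_ids"], inputs["attention_mask"]))[0]

# Report the highest-scoring token at the mask position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
predicted_id = int(logits[0, mask_pos].argmax())
print(tokenizer.decode([predicted_id]))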
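Continuing the sketch above, a quick parity check against the eager HuggingFace model is one way to validate the XNNPACK-delegated program; the tolerances are a judgment call, not values taken from the commit.

from transformers import AutoModelForMaskedLM

# Eager reference under the same preprocessing as above.
eager = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
eager.eval()
with torch.no_grad():
    ref_logits = eager(**inputs).logits

# Delegated float32 inference should closely track the eager results.
torch.testing.assert_close(logits, ref_logits, atol=1e-3, rtol=1e-3)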
