Support MaskedLM from HuggingFace #6509

Open · wants to merge 1 commit into main
217 changes: 181 additions & 36 deletions extension/export_util/export_hf_model.py
@@ -10,45 +10,33 @@
import torch
import torch.export._trace
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge,
    to_edge_transform_and_lower,
)
from torch.nn.attention import SDPBackend
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoModelForDepthEstimation,
    AutoModelForMaskedLM,
    AutoModelForSemanticSegmentation,
    AutoTokenizer,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache
from transformers.modeling_utils import PreTrainedModel

from .task_registry import register_task, task_registry


@register_task("causal_lm")
def export_causal_lm(args):
    # Configs to HF model
    device = "cpu"
    dtype = args.dtype
    batch_size = 1
    max_length = 123
    cache_implementation = "static"
@@ -106,11 +94,168 @@ def _get_constant_methods(model: PreTrainedModel):
        .to_backend(XnnpackPartitioner())
        .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
    )

    return model, prog


@register_task("masked_lm")
def export_masked_lm(args):
device = "cpu"
max_length = 64
attn_implementation = "sdpa"

config = AutoConfig.from_pretrained(args.hf_model_repo)
kwargs = {}
if hasattr(config, "use_cache"):
kwargs["use_cache"] = True

print(f"DEBUG: attn_implementation: {attn_implementation}")
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
mask_token = tokenizer.mask_token
print(f"Mask token: {mask_token}")
inputs = tokenizer(
f"The goal of life is {mask_token}.",
return_tensors="pt",
padding="max_length",
max_length=max_length,
)

model = AutoModelForMaskedLM.from_pretrained(
args.hf_model_repo,
device_map=device,
attn_implementation=attn_implementation,
**kwargs,
)
print(f"{model.config}")
print(f"{model.generation_config}")

# pre-autograd export. eventually this will become torch.export
exported_program = torch.export.export_for_training(
model,
args=(inputs["input_ids"],),
kwargs={"attention_mask": inputs["attention_mask"]},
strict=True,
)

return model, to_edge_transform_and_lower(
exported_program,
partitioner=[XnnpackPartitioner()],
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
),
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
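
Not part of this diff, but a handy sanity check for the new task: the ExportedProgram can be re-materialized with .module() and compared against the eager model before lowering. A minimal sketch, assuming a BERT-style checkpoint whose tokenizer defines mask_token_id; the output pytree structure can vary by transformers version, hence the fallback indexing:

# Sketch (not part of the PR): verify the export against eager before lowering.
eager_out = model(
    input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
exported_out = exported_program.module()(
    inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
eager_logits = eager_out.logits
# Exported outputs may come back as a plain tuple rather than a ModelOutput.
exported_logits = (
    exported_out.logits if hasattr(exported_out, "logits") else exported_out[0]
)
torch.testing.assert_close(eager_logits, exported_logits)

# Smoke test: decode the top prediction at the mask position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
predicted_id = int(exported_logits[0, mask_pos].argmax(-1))
print(tokenizer.decode([predicted_id]))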


@register_task("semantic_segmentation")
def export_semantic_segmentation(args):
import requests
from PIL import Image

device = "cpu"
model = AutoModelForSemanticSegmentation.from_pretrained(
args.hf_model_repo,
device_map=device,
)
image_processor = AutoImageProcessor.from_pretrained(
args.hf_model_repo,
device_map=device,
)
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

exported_program = torch.export.export_for_training(
model,
args=(inputs["pixel_values"],),
strict=True,
)

return model, to_edge_transform_and_lower(
exported_program,
partitioner=[XnnpackPartitioner()],
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
),
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))


@register_task("depth_estimation")
def export_depth_estimation(args):
import requests
from PIL import Image

device = "cpu"
model = AutoModelForDepthEstimation.from_pretrained(
args.hf_model_repo,
device_map=device,
)
image_processor = AutoImageProcessor.from_pretrained(
args.hf_model_repo,
device_map=device,
)
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

exported_program = torch.export.export_for_training(
model,
args=(inputs["pixel_values"],),
strict=True,
)

return model, to_edge_transform_and_lower(
exported_program,
partitioner=[XnnpackPartitioner()],
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
),
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-hfm",
        "--hf_model_repo",
        required=True,
        default=None,
        help="a valid Hugging Face model repo name",
    )
    parser.add_argument(
        "-d",
        "--dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="specify the dtype for loading the model",
    )
    parser.add_argument(
        "-o",
        "--output_name",
        required=False,
        default=None,
        help="output name of the exported model",
    )
    parser.add_argument(
        "-t",
        "--task",
        type=str,
        choices=list(task_registry.keys()),
        default="causal_lm",
        help=f"type of task of the model to load from Hugging Face. supported tasks: {list(task_registry.keys())}",
    )

    args = parser.parse_args()
    try:
        model, prog = task_registry[args.task](args)
    except KeyError:
        # A dict lookup on an unknown task raises KeyError, not AttributeError.
        raise ValueError(f"Unsupported task type {args.task}")

    out_name = args.output_name if args.output_name else model.config.model_type
    filename = os.path.join("./", f"{out_name}.pte")
    with open(filename, "wb") as f:
        prog.write_to_file(f)
    print(f"Saved exported program to {filename}")


if __name__ == "__main__":
    main()
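
For reference, a plausible invocation of the new task flag (the module path is assumed from the file's location; any BERT-style checkpoint with a mask token should work):

python -m extension.export_util.export_hf_model \
    -hfm google-bert/bert-base-uncased -t masked_lm -o bert_mlm
# Should write ./bert_mlm.pte, lowered to XNNPACK.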
16 changes: 16 additions & 0 deletions extension/export_util/task_registry.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


task_registry = {}


def register_task(task_name):
    def decorator(func):
        task_registry[task_name] = func
        return func

    return decorator
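
With this registry in place, adding a task becomes a one-decorator change in export_hf_model.py. A hypothetical example, with the task name and function invented for illustration:

from .task_registry import register_task


@register_task("sequence_classification")
def export_sequence_classification(args):
    # Hypothetical: load an AutoModelForSequenceClassification checkpoint,
    # export it with torch.export.export_for_training, lower it via
    # to_edge_transform_and_lower, and return (model, executorch_program)
    # following the same contract as the other tasks.
    ...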