Complete io_config, Pass model to model funcs (microsoft#750)
## Describe your changes
The `io_config` attribute can take both an io_config dict and a function name,
but these cases were not handled cleanly. There was a special case where a
string `io_config` is called with `hf_config.model_name`, but that behavior was
arbitrary and was initially implemented just to make the whisper example work.

In this PR:
- Handle all possible values for the `io_config` attribute.
- A `PyTorchModel.get_io_config` method is added to get the io config from
`io_config` and `hf_config` in decreasing order of priority.
- `io_config`, `component_func` and `dummy_inputs_func` functions are called with
`self` by the calling model. This standardizes them so they are no longer called
with arbitrary arguments (see the sketch below).
- `HFConfig` is updated to handle io_config and dummy input creation from the
onnx config itself.
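
For illustration (not part of this diff), a minimal sketch of the new behavior with a hypothetical user script; the model name, task, shapes and config field below are placeholders:

```python
# user_script.py -- hypothetical sketch
from olive.model import PyTorchModel


def get_io_config(model: PyTorchModel):
    # the function now receives the calling Olive model, so the HF config is
    # available without reloading the underlying torch model
    config = model.get_hf_model_config()
    return {
        "input_names": ["input_ids"],
        "input_shapes": [[1, config.max_position_embeddings]],
        "input_types": ["int64"],
        "output_names": ["logits"],
        "dynamic_axes": {"input_ids": {0: "batch_size", 1: "sequence_length"}},
    }
```

```python
# hypothetical usage in a separate script or workflow config
from olive.model import PyTorchModel

olive_model = PyTorchModel(
    model_script="user_script.py",
    io_config="get_io_config",  # a dict, IOConfig, function name, or callable all work now
    hf_config={"model_name": "gpt2", "task": "text-generation"},
)
print(olive_model.get_io_config())  # io_config takes priority over hf_config's onnx_config
```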

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

`component_func` and `io_config` functions in `PyTorchModel` are now
called with the calling `OliveModel` itself.
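
For example, the whisper user script in this PR changes its decoder component function as follows (excerpt from `examples/whisper/code/user_script.py`; imports from the example are omitted):

```python
# before: component functions were called with hf_config.model_name
def get_decoder(model_name):
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return WhisperDecoder(model, model.config)


# after: component functions are called with the Olive model itself
def get_decoder(olive_model: PyTorchModel):
    model = olive_model.load_model()  # the loaded WhisperForConditionalGeneration
    return WhisperDecoder(model, model.config)
```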
## (Optional) Issue link
jambayk authored Nov 28, 2023
1 parent 3a74d46 commit 2461ffe
Showing 8 changed files with 146 additions and 85 deletions.
6 changes: 3 additions & 3 deletions docs/source/features/huggingface_model_optimization.md
@@ -78,15 +78,15 @@ You can use your own custom components functions for your model. You will need t
#### Script example
```
# my_script.py
def get_dec_io_config(model_name: str):
def get_dec_io_config(model: OliveModel):
# return your io dict
...
def get_decoder(model_name: str):
def get_decoder(model: OliveModel):
# your component implementation
...
def dummy_inputs_func():
def dummy_inputs_func(model: OliveModel):
# return the dummy input for your component
...
```
4 changes: 2 additions & 2 deletions docs/source/overview/options.md
@@ -95,7 +95,7 @@ find more details in [Olive Models](https://microsoft.github.io/Olive/api/models

- `script_dir: [str]` The directory that contains dependencies for the model script.

- `io_config: [Dict[str, Any], IOConfig, str]`: The inputs and outputs information of the model. It can be a dictionary, an IOConfig object or a function string under `model_script`. Basically, it contains following items:
- `io_config: [Dict[str, Any] | IOConfig | str | Callable]`: The inputs and outputs information of the model. It can be a dictionary, an IOConfig object or a function string under `model_script`. Basically, it contains following items:
- `input_names: [List[str]]` The input names of the model.
- `input_types: [List[str]]` The input types of the model.
- `input_shapes: [List[List[int]]]` The input shapes of the model.
@@ -118,7 +118,7 @@ find more details in [Olive Models](https://microsoft.github.io/Olive/api/models
- `components: [List[HFComponent]]`: HFComponent list:
- `HFComponent`:
- `name: [str]`: Component name. Olive will generate a model class with this str as attribute name.
- `io_config: [str | Dict]`: The io_config of this component. If `str`, Olive will load `io_config` from `model_script`.
- `io_config: [Dict[str, Any] | IOConfig | str | Callable]`: The io_config of this component. If `str`, Olive will load `io_config` from `model_script`.
- `component_func: [str]`: The component function name will be loaded from `model_script`.
- `dummy_inputs_func: [str]`: The dummy input function name will be loaded from `model_script`.

4 changes: 2 additions & 2 deletions examples/llama2/user_script.py
@@ -152,8 +152,8 @@ def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str
return dynamic_axes


def get_merged_decoder_with_past_io_config(model_name):
config = LlamaConfig.from_pretrained(model_name)
def get_merged_decoder_with_past_io_config(model: PyTorchModel):
config = model.get_hf_model_config()

input_names = [
"input_ids",
38 changes: 20 additions & 18 deletions examples/whisper/code/user_script.py
@@ -3,14 +3,16 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from past_helper import PastKeyValuesHelper
from transformers import WhisperForConditionalGeneration
from whisper_dataset import WhisperDataset
from whisper_decoder import WhisperDecoder, WhisperDecoderInputs
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitInputs

from olive.model import PyTorchModel

def get_encoder_decoder_init(model_name):
model = WhisperForConditionalGeneration.from_pretrained(model_name)

def get_encoder_decoder_init(olive_model: PyTorchModel):
# model is WhisperForConditionalGeneration
model = olive_model.load_model()
return WhisperEncoderDecoderInit(
model,
model,
@@ -19,13 +21,15 @@ def get_encoder_decoder_init(model_name):
)


def get_decoder(model_name):
model = WhisperForConditionalGeneration.from_pretrained(model_name)
def get_decoder(olive_model: PyTorchModel):
# model is WhisperForConditionalGeneration
model = olive_model.load_model()
return WhisperDecoder(model, model.config)


def get_encdec_io_config(model_name):
model = get_encoder_decoder_init(model_name)
def get_encdec_io_config(olive_model: PyTorchModel):
# model is WhisperEncoderDecoderInit
model = olive_model.load_model()
use_decoder_input_ids = True

inputs = WhisperEncoderDecoderInitInputs.create_dummy(
@@ -96,13 +100,13 @@ def get_encdec_io_config(model_name):
}


def get_dec_io_config(model_name):
def get_dec_io_config(olive_model: PyTorchModel):
# Fix past disappearing bug - duplicate first past entry
# input_list.insert(2, input_list[2])
model = get_decoder(model_name)
past_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=True)
present_self_names = present_names[: 2 * model.config.decoder_layers]
config = olive_model.get_hf_model_config()
past_names = PastKeyValuesHelper.get_past_names(config.decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(config.decoder_layers, present=True)
present_self_names = present_names[: 2 * config.decoder_layers]

input_past_names = past_names
output_present_names = present_self_names
@@ -138,10 +142,9 @@ def get_dec_io_config(model_name):
}


def encoder_decoder_init_dummy_inputs(model):
model = model.load_model()
def encoder_decoder_init_dummy_inputs(olive_model: PyTorchModel):
inputs = WhisperEncoderDecoderInitInputs.create_dummy(
model.config,
olive_model.get_hf_model_config(),
batch_size=2,
encode_sequence_length=3000,
use_decoder_input_ids=True,
@@ -151,10 +154,9 @@ def encoder_decoder_init_dummy_inputs(model):
return tuple(inputs.to_list())


def decoder_dummy_inputs(model):
model = model.load_model()
def decoder_dummy_inputs(olive_model: PyTorchModel):
inputs = WhisperDecoderInputs.create_dummy(
model.config,
olive_model.get_hf_model_config(),
batch_size=2,
encode_sequence_length=3000,
past_decode_sequence_length=5,
84 changes: 58 additions & 26 deletions olive/model/__init__.py
@@ -23,7 +23,7 @@
from olive.common.user_module_loader import UserModuleLoader
from olive.constants import Framework, ModelFileFormat
from olive.hardware import AcceleratorLookup, Device
from olive.model.hf_utils import HFConfig, get_hf_model_dummy_input, huggingface_model_loader
from olive.model.hf_utils import HFConfig, huggingface_model_loader
from olive.model.model_config import IOConfig
from olive.resource_path import (
OLIVE_RESOURCE_ANNOTATIONS,
@@ -478,7 +478,7 @@ def __init__(
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IOConfig, str] = None,
io_config: Union[Dict[str, Any], IOConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HFConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
@@ -527,11 +527,10 @@ def __init__(
), "model_script must be a local file or a string name."

# io config for conversion to onnx
# TODO(trajep): support callable io_config
if isinstance(io_config, str):
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(io_config, self.hf_config.model_name)
self.io_config = validate_config(io_config, IOConfig).dict() if io_config else None
self.io_config = (
validate_config(io_config, IOConfig).dict() if isinstance(io_config, (IOConfig, dict)) else io_config
)

self.dummy_inputs_func = dummy_inputs_func

self.dummy_inputs = None
@@ -637,29 +636,71 @@ def prepare_session(
):
return self.load_model().eval()

def _resolve_io_config(self, io_config: Union[Dict[str, Any], IOConfig, str, Callable]) -> Dict[str, Any]:
"""Resolve io_config to a dictionary.
If io_config is a string name or a callable, it will be called to get io_config.
"""
if isinstance(io_config, dict):
# io_config is provided
return io_config

if isinstance(io_config, IOConfig):
# io_config is an IOConfig
return io_config.dict()

if isinstance(io_config, (str, Callable)):
# io_config is a string name or a callable
logger.debug(f"Calling {io_config} to get io_config")
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(io_config, self)
return validate_config(io_config, IOConfig).dict()

return None

def get_io_config(self) -> Dict[str, Any]:
"""Return io config of the model.
Priority: io_config > hf_config (using onnx_config)
"""
io_config = None
if self.io_config:
# io_config is provided
io_config = self._resolve_io_config(self.io_config)
elif self.hf_config and self.hf_config.task and not self.hf_config.components:
# hf_config is provided
logger.debug("Using hf onnx_config to get io_config")
io_config = self.hf_config.get_io_config(self.model_path)

return io_config

def get_dummy_inputs(self):
"""Return a dummy input for the model."""
if self.dummy_inputs is not None:
return self.dummy_inputs

# Priority: dummy_inputs_func > io_config.input_shapes > hf_config.dataset > onnx_config
dummy_inputs = None
# resolved self.io_config
# won't use self.get_io_config() since we don't want hf_config to be used
resolved_io_config = self._resolve_io_config(self.io_config) or {}
if self.dummy_inputs_func is not None:
logger.debug("Using dummy_inputs_func to get dummy inputs")
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
dummy_inputs = user_module_loader.call_object(self.dummy_inputs_func, self)
elif self.io_config and self.io_config["input_shapes"]:
elif resolved_io_config.get("input_shapes"):
logger.debug("Using io_config.input_shapes to get dummy inputs")
dummy_inputs, _ = (
# input_types is optional
data_config_template.dummy_data_config_template(
input_shapes=self.io_config["input_shapes"],
input_types=self.io_config.get("input_types"),
input_shapes=resolved_io_config["input_shapes"],
input_types=resolved_io_config.get("input_types"),
)
.to_data_container()
.get_first_batch(data_root_path=None)
)
elif self.hf_config and self.hf_config.model_name and self.hf_config.task:
# need both model_name and task to get dummy inputs
if self.hf_config.dataset:
logger.debug("Using hf_config.dataset to get dummy inputs")
dummy_inputs, _ = (
@@ -673,12 +714,7 @@ def get_dummy_inputs(self):
)
elif not self.hf_config.components:
logger.debug("Using hf onnx_config to get dummy inputs")
kwargs = {}
if self.hf_config.model_loading_args:
kwargs["trust_remote_code"] = self.hf_config.model_loading_args.trust_remote_code
dummy_inputs = get_hf_model_dummy_input(
self.hf_config.model_name, self.hf_config.task, self.hf_config.feature, **kwargs
)
dummy_inputs = self.hf_config.get_dummy_inputs(self.model_path)

if dummy_inputs is None:
raise ValueError(
@@ -719,13 +755,7 @@ def get_component(self, component_name: str) -> "PyTorchModel":
model_component = self.hf_config.load_model(self.model_path)
else:
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
model_component = user_module_loader.call_object(hf_component.component_func, self.hf_config.model_name)

io_config = hf_component.io_config
if isinstance(io_config, str):
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(hf_component.io_config, self.hf_config.model_name)
io_config = validate_config(io_config, IOConfig)
model_component = user_module_loader.call_object(hf_component.component_func, self)

def model_loader(_):
return model_component
@@ -735,7 +765,7 @@ def model_loader(_):

return PyTorchModel(
model_loader=model_loader,
io_config=io_config,
io_config=hf_component.io_config,
dummy_inputs_func=hf_component.dummy_inputs_func,
model_script=self.model_script,
script_dir=self.script_dir,
@@ -1043,7 +1073,7 @@ def __init__(
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IOConfig, str] = None,
io_config: Union[Dict[str, Any], IOConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HFConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
@@ -1062,7 +1092,9 @@ def __init__(
self.model_name_pattern = model_name_pattern
self.num_ranks = num_ranks
self.model_loader = model_loader
self.io_config = io_config
self.io_config = (
validate_config(io_config, IOConfig).dict() if isinstance(io_config, (IOConfig, dict)) else io_config
)
self.dummy_inputs_func = dummy_inputs_func
self.hf_config = hf_config

43 changes: 25 additions & 18 deletions olive/model/hf_utils.py
@@ -23,8 +23,7 @@

class HFComponent(ConfigBase):
name: str
# TODO(trajep): support callable io_config
io_config: Union[IOConfig, str, Dict[str, Any]]
io_config: Union[IOConfig, Dict[str, Any], str, Callable]
component_func: Union[str, Callable] = None
dummy_inputs_func: Union[str, Callable]

@@ -219,11 +218,14 @@ def task_or_model_class_required(cls, v, values):
raise ValueError("Either task or model_class must be specified")
return v

def _get_loading_args(self):
return self.model_loading_args.get_loading_args() if self.model_loading_args else {}

def load_model(self, model_path: str = None):
"""Load model from model_path or model_name."""
model_name_or_path = model_path or self.model_name
loading_args = self._get_loading_args()
logger.info(f"Loading Huggingface model from {model_name_or_path}")
loading_args = self.model_loading_args.get_loading_args() if self.model_loading_args else {}
if self.task:
model = load_huggingface_model_from_task(self.task, model_name_or_path, **loading_args)
elif self.model_class:
@@ -235,9 +237,19 @@ def load_model_config(self, model_path: str = None):

def load_model_config(self, model_path: str = None):
"""Load model config from model_path or model_name."""
model_name_or_path = model_path or self.model_name
loading_args = self.model_loading_args.get_loading_args() if self.model_loading_args else {}
return get_hf_model_config(model_name_or_path, **loading_args)
return get_hf_model_config(model_path or self.model_name, **self._get_loading_args())

def get_io_config(self, model_path: str = None):
"""Get IO config for the model."""
return get_hf_model_io_config(
model_path or self.model_name, self.task, self.feature, **self._get_loading_args()
)

def get_dummy_inputs(self, model_path: str = None):
"""Get dummy inputs for the model."""
return get_hf_model_dummy_input(
model_path or self.model_name, self.task, self.feature, **self._get_loading_args()
)


def load_huggingface_model_from_task(task: str, name: str, **kwargs):
@@ -326,7 +338,7 @@ def patched_supported_features_mapping(*supported_features: str, onnx_config_cls
return mapping


def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None):
def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None, **kwargs):
# pylint: disable=protected-access
from transformers.onnx import FeaturesManager

@@ -347,7 +359,7 @@ def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None):

# don't want to load the model here since all we need is the config
# model loading is expensive computationally and memory-wise for large models
config = get_hf_model_config(model_name)
config = get_hf_model_config(model_name, **kwargs)
# recreate the logic for FeaturesManager.check_supported_model_or_raise to get the model_onnx_config
# https://github.com/huggingface/transformers/blob/main/src/transformers/onnx/features.py#L712
model_type = config.model_type.replace("_", "-")
@@ -359,8 +371,8 @@ def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None):
return FeaturesManager.get_config(model_type, feature)(config)


def get_hf_model_io_config(model_name: str, task: str, feature: Optional[str] = None):
model_config = get_onnx_config(model_name, task, feature)
def get_hf_model_io_config(model_name: str, task: str, feature: Optional[str] = None, **kwargs):
model_config = get_onnx_config(model_name, task, feature, **kwargs)
inputs = model_config.inputs
outputs = model_config.outputs
io_config = {}
@@ -370,14 +382,9 @@ def get_hf_model_io_config(model_name: str, task: str, feature: Optional[str] =
return io_config


def get_hf_model_dummy_input(
model_name: str,
task: str,
feature: Optional[str] = None,
trust_remote_code: Optional[bool] = None,
):
model_config = get_onnx_config(model_name, task, feature)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
def get_hf_model_dummy_input(model_name: str, task: str, feature: Optional[str] = None, **kwargs):
model_config = get_onnx_config(model_name, task, feature, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
return model_config.generate_dummy_inputs(tokenizer, framework="pt")


