Enable evaluation of the input model in phi2_optimize.json, and make
create_dataloader pick the RandomDataLoader engine from the "model_framework"
kwarg: "ort" when it equals Framework.ONNX, otherwise "pt" (the kwarg defaults
to Framework.PYTORCH when absent).

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Indentation of context lines and the position of blank lines were
reconstructed from the hunk line counts; confirm against the target files (or
apply with --ignore-whitespace) before relying on it.

diff --git a/examples/phi2/phi2_optimize.json b/examples/phi2/phi2_optimize.json
index ef9860f0a..9ec3e048a 100644
--- a/examples/phi2/phi2_optimize.json
+++ b/examples/phi2/phi2_optimize.json
@@ -90,7 +90,7 @@
     ],
     "engine": {
         "search_strategy": false,
-        "evaluate_input_model": false,
+        "evaluate_input_model": true,
         "evaluator": "common_evaluator",
         "target": "local_system",
         "cache_dir": "cache",
diff --git a/examples/phi2/user_script.py b/examples/phi2/user_script.py
index 8b0c44893..e5387dea9 100644
--- a/examples/phi2/user_script.py
+++ b/examples/phi2/user_script.py
@@ -10,6 +10,8 @@
 import torch
 from transformers import AutoConfig, AutoModelForCausalLM
 
+from olive.constants import Framework
+
 if TYPE_CHECKING:
     from transformers import PhiConfig
 
@@ -127,8 +129,10 @@ def get_io_config(model):
 def create_dataloader(data_dir, batch_size, *args, **kwargs):
     sequence_length, past_sequence_length = 8, 0
     max_sequence_length = 512
+    model_framework = kwargs.get("model_framework", Framework.PYTORCH)
+    engine = "ort" if model_framework == Framework.ONNX else "pt"
 
-    return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine="ort")
+    return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine=engine)
 
 
 class RandomDataLoader: