Skip to content

Commit

Permalink
support original model evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
guotuofeng committed Jan 9, 2024
1 parent a9f4cf7 commit 36e2c51
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/phi2/phi2_optimize.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
],
"engine": {
"search_strategy": false,
"evaluate_input_model": false,
"evaluate_input_model": true,
"evaluator": "common_evaluator",
"target": "local_system",
"cache_dir": "cache",
Expand Down
6 changes: 5 additions & 1 deletion examples/phi2/user_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from olive.constants import Framework

if TYPE_CHECKING:
from transformers import PhiConfig

Expand Down Expand Up @@ -127,8 +129,10 @@ def get_io_config(model):
def create_dataloader(data_dir, batch_size, *args, **kwargs):
sequence_length, past_sequence_length = 8, 0
max_sequence_length = 512
model_framework = kwargs.get("model_framework", Framework.PYTORCH)
engine = "ort" if model_framework == Framework.ONNX else "pt"

return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine="ort")
return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine=engine)


class RandomDataLoader:
Expand Down

0 comments on commit 36e2c51

Please sign in to comment.