Fix perftuning to meet DLIS integration requirements (microsoft#820)
## Describe your changes

* Change the execution provider priority so that the execution providers passed as
arguments take precedence over the ones in `inference_settings`.
Bug description: in DLIS scenarios, `inference_settings` looks like
`{'execution_provider': [('MIGraphXExecutionProvider', {})]}`, while the provider
list specified in perf tuning is
`['ROCMExecutionProvider', 'MIGraphXExecutionProvider']`.

Originally, `inference_settings` took priority over the execution providers
specified in the arguments. As a result, when `OnnxEvaluator.evaluate` was called
with the provider list during baseline evaluation, the session was created with
`["MIGraphXExecutionProvider"]` only. Later,
`OnnxEvaluator.disable_ort_fallback(session, execution_providers)` checked whether
the underlying session's providers (`MIGraphXExecutionProvider`,
`CPUExecutionProvider`) contained all of the requested execution providers
(`ROCMExecutionProvider`, `MIGraphXExecutionProvider`). Since
`ROCMExecutionProvider` was excluded when the inference session was created, an
`OliveEvaluationError` exception was raised.

To fix this issue, the ORT fallback check is moved into `prepare_session` of the
ONNX model handler, and perf tuning now always passes `None` as the execution
provider to `evaluate`, which keeps the logic clear.

* Change perf tuning to always include execution provider info, regardless of
whether the best result comes from the baseline or from perf tuning.
* Change perf tuning to accept both a plain provider list and a provider list
with provider options.
* Add `force_evaluate_other_eps` to force evaluation of all other EPs when the
current EP is not the same as the additional EPs.
* Change `OnnxModel.prepare_session` to accept both an execution provider list
and a list of `(execution provider, provider options)` pairs, so that the
EP/provider-option handling logic is normalized (see the first sketch after
this list).
* Add `MIGraphXExecutionProvider` to the accelerator map.
* Fix a critical bug in `ort_inference.py` where provider options were not
respected when creating the inference session.
* Add `io_bind` inference for accuracy evaluation, since CUDA throws an illegal
memory access when `enable_cuda_graph` is on (see the second sketch after this
list). If `io_bind` is not enabled while `enable_cuda_graph` is `True`, the
following errors are raised:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=c93f1847c000000 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_graph.cc ; line=49 ; expr=cudaGraphLaunch(graph_exec_, stream_);
[E:onnxruntime:Default, cuda_call.cc:116 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=c93f1847c000000 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=286 ; expr=cudaStreamDestroy(stream_);
[E:onnxruntime:Default, cuda_call.cc:116 CudaCall] CUDNN failure 4: CUDNN_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=c93f1847c000000 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=181 ; expr=cudnnDestroy(cudnn_handle_);

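As an illustration of the normalized EP handling, here is a minimal, self-contained
sketch. `normalize_providers` is a hypothetical helper written only for this
description (the actual change relies on `check_and_normalize_provider_args` from
`olive.model.utils.onnx_utils`, as shown in the diff below), and `"model.onnx"` is a
placeholder path:

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

import onnxruntime as ort

# An EP may be given as a plain name or as a (name, provider_options) pair.
ProviderSpec = Union[str, Tuple[str, Dict[str, Any]]]


def normalize_providers(providers: Sequence[ProviderSpec]) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Split a mixed EP list into parallel name and options lists."""
    names: List[str] = []
    options: List[Dict[str, Any]] = []
    for spec in providers:
        if isinstance(spec, (tuple, list)):
            ep_name, ep_options = spec
            names.append(ep_name)
            options.append(dict(ep_options))
        else:
            names.append(spec)
            options.append({})
    return names, options


# The DLIS-style input from the bug description, mixed with plain EP names.
names, options = normalize_providers(
    [("MIGraphXExecutionProvider", {}), "ROCMExecutionProvider", "CPUExecutionProvider"]
)
# names   -> ["MIGraphXExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
# options -> [{}, {}, {}]

# ONNX Runtime consumes the parallel lists directly ("model.onnx" is a placeholder).
session = ort.InferenceSession("model.onnx", providers=names, provider_options=options)
print(session.get_providers())
```

Keeping names and options in parallel lists is what lets `prepare_session` inject
`device_id` for `CUDAExecutionProvider` per rank without disturbing the other EPs.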

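For the accuracy path with `enable_cuda_graph`, the evaluator now binds inputs and
outputs to device memory and runs with `run_with_iobinding`, so the captured CUDA
graph replays against device buffers instead of freshly copied host data. A hedged,
standalone sketch of that pattern (model path, input name, and shapes are
placeholders, not taken from this repository):

```python
import numpy as np
import onnxruntime as ort

# Placeholder model and input; in Olive this is driven by the metric's dataloader.
session = ort.InferenceSession(
    "model.onnx",
    providers=[("CUDAExecutionProvider", {"enable_cuda_graph": True}), "CPUExecutionProvider"],
)

io_binding = session.io_binding()
input_feed = {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}

# Bind inputs as OrtValues that live on the GPU (device_id 0).
for name, array in input_feed.items():
    io_binding.bind_ortvalue_input(name, ort.OrtValue.ortvalue_from_numpy(array, "cuda", 0))

# Let ORT allocate the outputs on the GPU as well.
for output in session.get_outputs():
    io_binding.bind_output(output.name, "cuda")

io_binding.synchronize_inputs()
session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

# Copy results back to host for post-processing and metric computation.
results = [out.numpy() for out in io_binding.get_outputs()]
io_binding.clear_binding_inputs()
```

With `io_bind` promoted to a common user-config option (see the `metric_config.py`
change), accuracy metrics can opt into this path the same way latency metrics
already did.
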
## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
guotuofeng authored Dec 21, 2023
1 parent 69c2aa7 commit 855be5d
Showing 11 changed files with 763 additions and 192 deletions.
15 changes: 14 additions & 1 deletion docs/source/features/passes/onnx.md
@@ -323,7 +323,20 @@ improve performance.
"config": {
"user_script": "user_script.py",
"dataloader_func": "create_dataloader",
"batch_size": 1
"batch_size": 1,
"providers_list" : [
[
"CUDAExecutionProvider",
{
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
"gpu_mem_limit": 2147483648, // 2 * 1024 * 1024 * 1024,
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": true,
},
],
"CPUExecutionProvider",
]
}
}
```
8 changes: 4 additions & 4 deletions olive/common/ort_inference.py
@@ -47,11 +47,11 @@ def get_ort_inference_session(
for key, value in extra_session_config.items():
sess_options.add_session_config_entry(key, value)

if isinstance(execution_provider, list):
# execution_provider may be a list of tuples/lists where the first item in each tuple is the EP name
execution_provider = [i[0] if isinstance(i, (tuple, list)) else i for i in execution_provider]
elif isinstance(execution_provider, str):
if isinstance(execution_provider, str):
execution_provider = [execution_provider]
else:
# execution providers should be list[str]
assert isinstance(execution_provider, list) and all(isinstance(ep, str) for ep in execution_provider)

for idx, ep in enumerate(execution_provider):
if ep == "QNNExecutionProvider":
8 changes: 1 addition & 7 deletions olive/evaluator/metric_config.py
@@ -29,18 +29,12 @@
"input_shapes": ConfigParam(type_=List),
"input_types": ConfigParam(type_=List),
"shared_kv_buffer": ConfigParam(type_=bool, default_value=False),
"io_bind": ConfigParam(type_=bool, default_value=False),
}

_common_user_config_validators = {}

_type_to_user_config = {
"latency": {
# TODO(anyone): extract io_bind to a _common_user_config
"io_bind": ConfigParam(type_=bool, default_value=False),
},
"throughput": {
"io_bind": ConfigParam(type_=bool, default_value=False),
},
"accuracy": {
"post_processing_func": ConfigParam(type_=Union[Callable, str], category=ParamCategory.OBJECT),
},
141 changes: 61 additions & 80 deletions olive/evaluator/olive_evaluator.py
@@ -34,7 +34,6 @@
joint_metric_key,
)
from olive.evaluator.metric_backend import MetricBackend
from olive.exception import OliveEvaluationError
from olive.hardware import Device
from olive.model import (
DistributedOnnxModelHandler,
@@ -45,6 +44,7 @@
SNPEModelHandler,
)
from olive.model.config.io_config import is_io_config_static
from olive.model.utils.onnx_utils import bind_input_data, bind_output_data, prepare_io_bindings
from olive.snpe.data_loader import SNPECommonDataLoader, SNPEDataLoader

logger = logging.getLogger(__name__)
@@ -70,16 +70,27 @@ def __init_subclass__(cls, framework: Framework, **kwargs) -> None:
def __init__(self):
pass

def get_inference_settings(self, metric: Metric) -> Dict[str, Any]:
@classmethod
def get_inference_settings(cls, metric: Metric) -> Dict[str, Any]:
# user.config.inference_settings > model.inference_settings > default inference_settings
# when user.config.inference_settings is None, the model.inference_settings
# will be used in model.prepare_session(..)
return (
metric.user_config.inference_settings.get(self.framework.lower())
metric.user_config.inference_settings.get(cls.framework.lower())
if metric.user_config.inference_settings
else None
)

@classmethod
def io_bind_enabled(cls, metric: Metric, inference_settings: Dict) -> bool:
if metric.user_config.io_bind:
return True

if inference_settings and inference_settings.get("io_bind"):
return True

return False

@abstractmethod
def _inference(
self,
@@ -404,55 +415,6 @@ def format_input(input_data, io_config)
if k in input_names
}

@staticmethod
def prepare_io_bindings(
session, input_data, device, device_id: int = 0, shared_kv_buffer: bool = False, kv_cache_ortvalues: dict = None
):
"""Convert input from numpy array to OrtValue.
session: ONNXRuntime session
input_data: dict of input data, value is numpy array
device: olive device
device_id: 0 by default. TODO(trajep): support user to specified device id
shared_kv_buffer: whether to share the key/value buffer across multiple runs, it is False by default,
and only used when we observe kv cache and fp16 is used.
TODO(trajep): how shared_kv_buffer works with generation task
"""
from onnxruntime import OrtValue

use_fp16 = any(v.dtype == np.float16 for v in input_data.values())
io_bind_op = session.io_binding()
io_bind_device = "cuda" if device == "gpu" else "cpu"

if shared_kv_buffer:
kv_cache_ortvalues = kv_cache_ortvalues or {}

for k, v in input_data.items():
# "cache": from microsoft llama model" https://github.com/microsoft/Llama-2-Onnx#before-you-start
# "past_key_values": from huggingface llama2 https://huggingface.co/meta-llama/Llama-2-13b-hf
if shared_kv_buffer and use_fp16 and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
kv_cache_ortvalues[k] = OrtValue.ortvalue_from_numpy(v, io_bind_device, device_id)
else:
kv_cache_ortvalues[k].update_inplace(v)
ort_v = kv_cache_ortvalues[k]
else:
ort_v = OrtValue.ortvalue_from_numpy(v, io_bind_device, device_id)
io_bind_op.bind_ortvalue_input(k, ort_v)

for item in session.get_outputs():
name = item.name
# "out": from microsoft llama model" https://github.com/microsoft/Llama-2-Onnx#before-you-start
# "present": from huggingface llama2 https://huggingface.co/meta-llama/Llama-2-13b-hf
if shared_kv_buffer and use_fp16 and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to use shared buffer
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_bind_op.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
else:
io_bind_op.bind_output(name, io_bind_device)

return io_bind_op

def _inference(
self,
model: ONNXModelHandler,
Expand All @@ -468,19 +430,53 @@ def _inference(
execution_providers=execution_providers,
)

OnnxEvaluator.disable_ort_fallback(session, execution_providers)

io_config = model.get_io_config()

preds = []
targets = []
logits = []
logits_dict = collections.defaultdict(list)
output_names = io_config["output_names"]
io_bind = self.io_bind_enabled(metric, model.inference_settings)
device = "cuda" if device == "gpu" else "cpu"
if io_bind and device == "cuda":
io_binding = session.io_binding()
input_data, _ = next(iter(dataloader))
input_dict = OnnxEvaluator.format_input(input_data, io_config)
use_fp16 = any(v.dtype == np.float16 for v in input_data.values())
# no deepcopy for kv_cache_ortvalues, will update the value inplace and keep it shared across runs
if metric.user_config.shared_kv_buffer:
kv_cache_ortvalues = {}
else:
kv_cache_ortvalues = None
bind_output_data(
io_binding,
session.get_outputs(),
use_fp16,
device,
shared_kv_buffer=metric.user_config.shared_kv_buffer,
kv_cache_ortvalues=kv_cache_ortvalues,
)

is_single_tensor_output = len(output_names) == 1
for input_data, labels in dataloader:
input_dict = OnnxEvaluator.format_input(input_data, io_config)
res = session.run(input_feed=input_dict, output_names=None)
if io_bind and device == "cuda":
bind_input_data(
io_binding,
input_dict,
use_fp16,
device,
shared_kv_buffer=metric.user_config.shared_kv_buffer,
kv_cache_ortvalues=kv_cache_ortvalues,
)
io_binding.synchronize_inputs()
session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
res = [i.numpy() for i in io_binding.get_outputs()]
io_binding.clear_binding_inputs()
else:
res = session.run(input_feed=input_dict, output_names=None)
if is_single_tensor_output:
result = torch.Tensor(res[0])
else:
@@ -517,7 +513,7 @@ def _evaluate_onnx_accuracy(

def _evaluate_onnx_latency(
self,
model: OliveModelHandler,
model: ONNXModelHandler,
metric: Metric,
dataloader: Dataset,
post_func=None,
@@ -531,17 +527,16 @@ def _evaluate_onnx_latency(
device=device,
execution_providers=execution_providers,
)
OnnxEvaluator.disable_ort_fallback(session, execution_providers)

io_config = model.get_io_config()

input_data, _ = next(iter(dataloader))
input_dict = OnnxEvaluator.format_input(input_data, io_config)
# no deepcopy for kv_cache_ortvalues, will update the value inplace and keep it shared across runs
kv_cache_ortvalues = {} if metric.user_config.shared_kv_buffer else None

if metric.user_config.io_bind:
io_bind_op = OnnxEvaluator.prepare_io_bindings(
io_bind = self.io_bind_enabled(metric, model.inference_settings)
if io_bind:
io_bind_op = prepare_io_bindings(
session,
input_dict,
device,
@@ -550,7 +545,7 @@
)

for _ in range(warmup_num):
if metric.user_config.io_bind:
if io_bind:
io_bind_op.synchronize_inputs()
session.run_with_iobinding(io_bind_op)
io_bind_op.synchronize_outputs()
@@ -559,7 +554,7 @@

latencies = []
for _ in range(repeat_test_num):
if metric.user_config.io_bind:
if io_bind:
io_bind_op.synchronize_inputs()
# count the time after data are all in gpu after data copy
t = time.perf_counter()
@@ -694,8 +689,9 @@ def _evaluate_distributed_latency_worker(data_root, config) -> List[float]:
input_feed = OnnxEvaluator.format_input(input_feed, io_config)
kv_cache_ortvalues = {} if metric.user_config.shared_kv_buffer else None

if metric.user_config.io_bind:
io_bind_op = OnnxEvaluator.prepare_io_bindings(
io_bind = OnnxEvaluator.io_bind_enabled(metric, model.inference_settings)
if io_bind:
io_bind_op = prepare_io_bindings(
session,
input_feed,
Device.GPU,
@@ -706,7 +702,7 @@ def _evaluate_distributed_latency_worker(data_root, config) -> List[float]:
for i in range(warmup_num + repeat_test_num):
MPI.COMM_WORLD.barrier() # Synchronize before starting each run
start_time = time.perf_counter()
if metric.user_config.io_bind:
if io_bind:
session.run_with_iobinding(io_bind_op)
else:
session.run(input_feed=input_feed, output_names=None)
@@ -788,21 +784,6 @@ def _evaluate_raw_latency(
else:
raise TypeError(f"Cannot evaluate latency for model of type: {type(model)}")

@staticmethod
def disable_ort_fallback(session, execution_providers):
# pylint: disable=protected-access
if execution_providers:
assert isinstance(execution_providers, (str, list))
execution_providers = [execution_providers] if isinstance(execution_providers, str) else execution_providers
session_providers = session.get_providers()
for ep in execution_providers:
if ep not in session_providers:
raise OliveEvaluationError(
f"The onnxruntime fallback happens. {ep} is not in the session providers {session_providers}."
f" session._enable_fallback = {session._enable_fallback}"
)
session.disable_fallback()


class PyTorchEvaluator(OliveEvaluator, framework=Framework.PYTORCH):
def __init__(self):
1 change: 1 addition & 0 deletions olive/hardware/accelerator.py
@@ -63,6 +63,7 @@ class AcceleratorLookup:
"DmlExecutionProvider",
"CUDAExecutionProvider",
"ROCMExecutionProvider",
"MIGraphXExecutionProvider",
"TensorrtExecutionProvider",
"CPUExecutionProvider",
"OpenVINOExecutionProvider",
36 changes: 24 additions & 12 deletions olive/model/handler/onnx.py
@@ -16,7 +16,7 @@
from olive.model.config.registry import model_handler_registry
from olive.model.handler.base import OliveModelHandler
from olive.model.handler.mixin import OnnxEpValidateMixin, OnnxGraphMixin
from olive.model.utils.onnx_utils import get_onnx_file_path
from olive.model.utils.onnx_utils import check_and_normalize_provider_args, check_ort_fallback, get_onnx_file_path
from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS

logger = logging.getLogger(__name__)
@@ -74,29 +74,41 @@ def prepare_session(
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
import onnxruntime as ort

# user provided inference_settings > model's inference_settings > default settings
inference_settings = inference_settings or self.inference_settings or {}
# deep copy to avoid modifying the original settings
inference_settings = deepcopy(inference_settings)

# if user doesn't not provide ep list, use default value([ep]). Otherwise, use the user's ep list
# user provided ep list > eps given by arguments > default eps
execution_providers = inference_settings.get("execution_provider") or execution_providers
if inference_settings.get("execution_provider") is not None:
execution_providers = inference_settings.get("execution_provider")
provider_options = inference_settings.get("provider_options")
else:
provider_options = None

if not execution_providers:
execution_providers = self.get_default_execution_providers(device)
elif isinstance(execution_providers, str):
provider_options = None
elif isinstance(execution_providers, (str, tuple)):
execution_providers = [execution_providers]
else:
# the execution_providers is a list
pass
inference_settings["execution_provider"] = execution_providers

if (device == Device.GPU) and (rank is not None) and not inference_settings.get("provider_options"):
inference_settings["provider_options"] = [
{"device_id": str(rank)} if ep == "CUDAExecutionProvider" else {} for ep in execution_providers
]
# split the execution_providers and provider_options
execution_providers, provider_options = check_and_normalize_provider_args(
execution_providers, provider_options, ort.get_available_providers()
)

return get_ort_inference_session(self.model_path, inference_settings, self.use_ort_extensions)
if (device == Device.GPU) and (rank is not None):
for i, ep in enumerate(execution_providers):
if ep == "CUDAExecutionProvider" and not provider_options[i]:
provider_options[i] = {"device_id": str(rank)}
inference_settings["execution_provider"] = execution_providers
inference_settings["provider_options"] = provider_options
session = get_ort_inference_session(self.model_path, inference_settings, self.use_ort_extensions)
check_ort_fallback(session, execution_providers)
return session

def get_default_execution_providers(self, device: Device):
# return firstly available ep as ort default ep