Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for training apis to support custom ops #16601

Merged
merged 19 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
198 changes: 141 additions & 57 deletions orttraining/orttraining/python/orttraining_pybind_state.cc

Large diffs are not rendered by default.

25 changes: 8 additions & 17 deletions orttraining/orttraining/python/training/api/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue, get_ort_device_type
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector, SessionOptions
from onnxruntime.training.api.checkpoint_state import CheckpointState


Expand All @@ -33,6 +33,7 @@ class Module:
state: The checkpoint state object.
eval_model_uri: The path to the evaluation model.
device: The device to run the model on. Default is "cpu".
session_options: The session options to use for the model.
"""

training: bool
Expand All @@ -43,11 +44,13 @@ def __init__(
state: CheckpointState,
eval_model_uri: os.PathLike | None = None,
device: str = "cpu",
session_options: SessionOptions | None = None,
) -> None:
self.training = True
options = device.split(":")
self._device_type = options[0]
device_id = 0 if len(options) < 2 else int(options[1])
self._session_options = session_options if session_options is not None else SessionOptions()

self._device = C.OrtDevice(
get_ort_device_type(self._device_type, device_id),
Expand All @@ -59,6 +62,7 @@ def __init__(
state._state,
os.fspath(eval_model_uri) if eval_model_uri is not None else None,
self._device,
self._session_options,
)
self._state = state

Expand All @@ -70,17 +74,7 @@ def __call__(self, *user_inputs) -> tuple[np.ndarray] | np.ndarray:
Returns:
The outputs of the model.
"""
is_np_input = False
forward_inputs = OrtValueVector()
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
forward_inputs.reserve(len(user_inputs))
for tensor in user_inputs:
if isinstance(tensor, np.ndarray):
is_np_input = True
forward_inputs.push_back(OrtValue.ortvalue_from_numpy(tensor)._ortvalue)
elif isinstance(tensor, OrtValue):
forward_inputs.push_back(tensor._ortvalue)
else:
raise ValueError(f"Expected input of type: numpy array or OrtValue, actual: {type(tensor)}")
forward_inputs = [user_input for user_input in user_inputs]
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
fetches = OrtValueVector()

if self.training:
Expand All @@ -89,12 +83,9 @@ def __call__(self, *user_inputs) -> tuple[np.ndarray] | np.ndarray:
self._model.eval_step(forward_inputs, fetches)

if len(fetches) == 1:
if is_np_input:
return fetches[0].numpy()
return fetches[0].numpy()

return fetches[0]

return tuple(val.numpy() for val in fetches) if is_np_input else tuple(fetches)
return tuple(val.numpy() for val in fetches)

def train(self, mode: bool = True) -> Module:
"""Sets the Module in training mode.
Expand Down
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/api/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class Optimizer:
"""

def __init__(self, optimizer_uri: str | os.PathLike, module: Module):
self._optimizer = C.Optimizer(os.fspath(optimizer_uri), module._state._state, module._device)
self._optimizer = C.Optimizer(
os.fspath(optimizer_uri), module._state._state, module._device, module._session_options
)

def step(self) -> None:
"""Updates the model parameters based on the computed gradients.
Expand Down
23 changes: 21 additions & 2 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import contextlib
import logging
import os
import pathlib
Expand Down Expand Up @@ -61,6 +62,8 @@ def generate_artifacts(
If None, the current working directory is used.
prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library (str | os.PathLike): The path to the custom op library.
If not specified, no custom op library is used.

Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
Expand Down Expand Up @@ -121,14 +124,30 @@ def build(self, *inputs_to_loss):
training_model = None
eval_model = None
model_params = None
with onnxblock.base(model):

custom_op_library = extra_options.get("custom_op_library", None)
if custom_op_library is not None:
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library = pathlib.Path(custom_op_library)

with onnxblock.base(model), onnxblock.custom_op_library(
custom_op_library
) if custom_op_library is not None else contextlib.nullcontext():
_ = training_block(*[output.name for output in model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

def _export_to_ort_format(model_path, output_dir, extra_options):
if extra_options.get("ort_format", False):
convert_onnx_models_to_ort(model_path, output_dir=output_dir, optimization_styles=[OptimizationStyle.Fixed])
custom_op_library = extra_options.get("custom_op_library", None)
if custom_op_library is not None:
custom_op_library = pathlib.Path(custom_op_library)
convert_onnx_models_to_ort(
model_path,
output_dir=output_dir,
custom_op_library_path=custom_op_library,
optimization_styles=[OptimizationStyle.Fixed],
)

if artifact_directory is None:
artifact_directory = pathlib.Path.cwd()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import onnxruntime.training.onnxblock.optim as optim
from onnxruntime.training.onnxblock.blocks import Block
from onnxruntime.training.onnxblock.checkpoint_utils import load_checkpoint_to_model, save_checkpoint
from onnxruntime.training.onnxblock.model_accessor import base, empty_base
from onnxruntime.training.onnxblock.model_accessor import base, custom_op_library, empty_base
from onnxruntime.training.onnxblock.onnxblock import ForwardBlock, TrainingBlock

__all__ = [
Expand All @@ -21,5 +21,6 @@
"load_checkpoint_to_model",
"save_checkpoint",
"base",
"custom_op_library",
"empty_base",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Licensed under the MIT License.

import copy
import os
from typing import List, Optional, Set, Tuple, Union

import onnx

from onnxruntime import SessionOptions
from onnxruntime.capi._pybind_state import GradientGraphBuilder, get_optimized_model


Expand Down Expand Up @@ -66,17 +68,25 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti


def _gradient_model_for(
model: onnx.ModelProto, requires_grad: Set[str], output_names: List[str], loss_name: str
model: onnx.ModelProto,
requires_grad: Set[str],
output_names: List[str],
loss_name: str,
options: Optional[SessionOptions] = None,
) -> onnx.ModelProto:
"""Builds the gradient graph on top of the given input forward only graph."""

builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name)
builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options)
builder.build()
return onnx.load_from_string(builder.get_model())


def build_gradient_graph(
model: onnx.ModelProto, requires_grad: Set[str], frozen_params: Set[str], output_names: Union[List[str], str]
model: onnx.ModelProto,
requires_grad: Set[str],
frozen_params: Set[str],
output_names: Union[List[str], str],
custom_op_library: Optional[str] = None,
) -> Tuple[onnx.ModelProto, onnx.ModelProto]:
"""Prepare the training model and the eval model.

Expand Down Expand Up @@ -106,10 +116,14 @@ def build_gradient_graph(
eval_model = copy.deepcopy(model)
_disable_training_mode(eval_model)

optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad))
options = SessionOptions()
if custom_op_library is not None:
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
options.register_custom_ops_library(os.fspath(custom_op_library))

optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))

# Assumption is that the first graph output is the loss output
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0])
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)

_reorder_outputs(gradient_model, output_names, requires_grad)

Expand Down
43 changes: 43 additions & 0 deletions orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import numpy as np
import onnx

import onnxruntime.training.onnxblock._graph_utils as _graph_utils
Expand Down Expand Up @@ -427,3 +428,45 @@ def build(self, cast_input_name: str):
self.base.graph.node.append(cast_node)

return cast_output_name


class Gemm(Block):
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()

self._in_features = in_features
self._out_features = out_features
self._alpha = alpha
self._beta = beta

def build(self, gemm_input_name: str):
# Weight initializer
gemm_node_weight_name = _graph_utils.generate_graph_name("gemm.weight")

self.base.graph.initializer.append(
onnx.numpy_helper.from_array(
np.random.randn(self._in_features, self._out_features).astype(np.float32), gemm_node_weight_name
)
)

# Bias initializer
gemm_node_bias_name = _graph_utils.generate_graph_name("gemm.bias")
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
self.base.graph.initializer.append(
onnx.numpy_helper.from_array(np.random.randn(self._out_features).astype(np.float32), gemm_node_bias_name)
)

gemm_node_input_names = [gemm_input_name, gemm_node_weight_name, gemm_node_bias_name]
gemm_node_output_name = _graph_utils.generate_graph_name("gemm.output")
gemm_node_output_names = [gemm_node_output_name]
gemm_node = onnx.helper.make_node(
"Gemm",
gemm_node_input_names,
gemm_node_output_names,
_graph_utils.generate_graph_name("Gemm"),
alpha=self._alpha,
beta=self._beta,
)

self.base.graph.node.append(gemm_node)

return gemm_node_output_name
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import annotations

import copy
import os
from contextlib import contextmanager
from typing import Optional

import onnx

Expand All @@ -29,10 +31,11 @@ def model(self) -> onnx.ModelProto:
return self._model


# This variable resides in the global namespace.
# These variable resides in the global namespace.
# Different methods can access this global model and manipulate it.
# Its construction and destruction is managed by the base and empty_base contextmanagers
_GLOBAL_ACCESSOR = None
_GLOBAL_CUSTOM_OP_LIBRARY = None


@contextmanager
Expand Down Expand Up @@ -74,7 +77,7 @@ def base(model: onnx.ModelProto):


@contextmanager
def empty_base(opset_version: Optional[int] = None):
def empty_base(opset_version: int | None = None):
"""Registers an empty base model to be manipulated by the onnx blocks.

Example:
Expand All @@ -89,8 +92,7 @@ def empty_base(opset_version: Optional[int] = None):
model_handle.

Args:
opset_version (int, optional): The opset version to use for the model.
Defaults to onnx.defs.onnx_opset_version()
opset_version: The opset version to use for the model. Defaults to onnx.defs.onnx_opset_version()

Returns:
ModelAccessor: The model accessor that contains the modified model.
Expand All @@ -115,3 +117,35 @@ def empty_base(opset_version: Optional[int] = None):
yield _GLOBAL_ACCESSOR
finally:
_GLOBAL_ACCESSOR = None


@contextmanager
def custom_op_library(custom_op_library_path: os.PathLike):
"""Registers the custom op library to be used by the onnx blocks.

Example:
>>> with onnxblock.custom_op_library(custom_op_library_path):
>>> # manipulate the model using blocks
>>> ...

In this example, custom_op_library will register the given input custom op library path to be used
during the model manipulation (gradient graph building and optimization).

Args:
custom_op_library_path: The path to the custom op library.

Returns:
ModelAccessor: The model accessor that contains the modified model.
"""
global _GLOBAL_CUSTOM_OP_LIBRARY # pylint: disable=global-statement # noqa: PLW0603
if _GLOBAL_CUSTOM_OP_LIBRARY is not None:
raise RuntimeError("CustomOp library already set. Cannot set multiple custom op libraries.")

if not os.path.exists(custom_op_library_path):
raise RuntimeError(f"Custom op library path {custom_op_library_path} does not exist.")

_GLOBAL_CUSTOM_OP_LIBRARY = copy.copy(custom_op_library_path) # noqa: PLW0603
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
try:
yield _GLOBAL_CUSTOM_OP_LIBRARY
finally:
_GLOBAL_CUSTOM_OP_LIBRARY = None

Check notice

Code scanning / CodeQL

Unused global variable

The global variable '_GLOBAL_CUSTOM_OP_LIBRARY' is not used.
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ def __call__(self, *args, **kwargs):
# The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
# The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
model,
self._requires_grad,
self._frozen_params,
output,
model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
)

_training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)
Expand Down
Loading