4. API Guide
The `model_explorer` package provides the following APIs to let you visualize models, and to create and visualize custom node data, from Python code. Make sure to install it first by following the installation guide.
`model_explorer` provides convenient APIs to quickly visualize models from files or from a PyTorch module, and a lower-level API to visualize models from multiple sources.
Usage:

```python
import model_explorer

model_explorer.visualize('/path/to/model/file')
```
API reference:

```python
visualize(
    model_paths=[],
    host='localhost',
    port=8080,
    extensions: list[str] = [],
    node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [],
    colab_height=850,
    reuse_server: bool = False,
    reuse_server_host: str = DEFAULT_HOST,
    reuse_server_port: Union[int, None] = None)
```
Starts the Model Explorer local server and visualizes the models at the given paths. When you pass multiple models to `model_paths`, the visualization page initially shows the largest subgraph of the first model. You can switch between models and their subgraphs using the model graph selector in the top-right corner.
Args:

- `model_paths: str|list[str]`: a model path, or a list of model paths, to visualize.
- `host: str`: the host of the server. Defaults to `localhost`.
- `port: int`: the port of the server. Defaults to 8080.
- `extensions: list[str]`: a list of extension names to run with Model Explorer.
- `node_data: list[NodeDataInfo]|NodeDataInfo`: the node data, or a list of node data, to display. Example: `node_data={'name': 'my node data', 'node_data': node_data_json_str}`.
- `colab_height: int`: the height of the embedded iFrame when running in Colab. Defaults to 850.
- `reuse_server: bool`: whether to reuse the current server/browser tab(s) to visualize.
- `reuse_server_host: str`: the host of the server to reuse. Defaults to `localhost`.
- `reuse_server_port: int`: the port of the server to reuse. If unspecified, it tries to find a running server on ports 8080 through 8099.
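The port-probing fallback described for `reuse_server_port` (looking for a running server on ports 8080 through 8099) can be sketched with a small stand-alone helper. `find_running_server` is a hypothetical name for illustration, not part of the `model_explorer` API:

```python
import socket
from typing import Optional


def find_running_server(host: str, ports: range) -> Optional[int]:
    """Return the first port in `ports` accepting TCP connections, else None."""
    for port in ports:
        try:
            # A short timeout keeps the scan quick when nothing is listening.
            with socket.create_connection((host, port), timeout=0.2):
                return port
        except OSError:
            continue
    return None


# Example: scan the same default range that `visualize` checks.
port = find_running_server('localhost', range(8080, 8100))
```

If no server is found, `port` is `None`, which mirrors the documented behavior of falling back to starting a fresh server.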
Visualizing PyTorch models requires a slightly different approach because they lack a standard serialization format. Model Explorer offers a specialized API to visualize PyTorch models directly, using the `ExportedProgram` from `torch.export.export`.
Usage:

```python
import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)

# Visualize.
model_explorer.visualize_pytorch('mobilenet', exported_program=ep)
```
API reference:

```python
visualize_pytorch(
    name,
    exported_program,
    settings={'const_element_count_limit': 16},
    host='localhost',
    port=8080,
    colab_height=850)
```

Starts the Model Explorer local server and visualizes the given PyTorch `ExportedProgram`.
Args:

- `name: str`: the name of the model, for display purposes.
- `exported_program: torch.export.ExportedProgram`: the `ExportedProgram` from `torch.export.export`.
- `settings: Dict`: key-value pairs of settings. For now it only supports one key, `const_element_count_limit`, which controls how many values the adapter returns for a constant.
- See the `visualize` function above for `host`, `port`, `extensions`, `node_data`, and `colab_height`.
Sometimes you want to load models from files and a PyTorch model into Model Explorer at the same time. To accomplish this, use the following lower-level APIs. The basic steps are:

- Create a `config` object and add models to it.
- Pass it to the `visualize_from_config` API.
Usage:

```python
import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)

# Create Model Explorer config.
config = model_explorer.config()

# Add a model file path and a PyTorch model to the config.
config.add_model_from_path('/path/to/model').add_model_from_pytorch('mobilenet', model, inputs)

# Visualize with config.
model_explorer.visualize_from_config(config=config)
```
API reference:

- `config() -> ModelExplorerConfig`

  Creates a new Model Explorer config object.

- `ModelExplorerConfig.add_model_from_path(self, path) -> ModelExplorerConfig`

  Adds a model path to the config.

  Args:
  - `path: str`: the model file path to add.

- `ModelExplorerConfig.add_model_from_pytorch(self, name, exported_program, settings={'const_element_count_limit': 16}) -> ModelExplorerConfig`

  Adds a PyTorch model with inputs to the config. After calling this method, Model Explorer invokes the internal adapter to convert the given PyTorch model into Model Explorer graphs. This process might take some time, depending on the complexity of the model.

  Args:
  - `name: str`: the name of the model, for display purposes.
  - See `visualize_pytorch` above for the `exported_program` and `settings` parameters.

- `ModelExplorerConfig.set_reuse_server(self, server_host: str = 'localhost', server_port: Union[int, None] = None) -> ModelExplorerConfig`

  Makes the config reuse an existing server instead of starting a new one.

  Args:
  - `server_host: str`: the host of the server to reuse.
  - `server_port: int|None`: the port of the server to reuse. If unspecified, it tries to find a running server on ports 8080 through 8099.

- `visualize_from_config(config=None, host='localhost', port=8080, colab_height=850)`

  Starts the visualization from the given config.

  Args:
  - `config: ModelExplorerConfig|None`: the object that stores the models to be visualized.
  - See the `visualize` function above for `host`, `port`, and `colab_height`.
`model_explorer` provides APIs to create custom node data and visualize it in a model graph. For more info about how custom node data works, see the user guide.
We provide a set of data classes to help you build custom node data. See the comments in `node_data_builder.py` as the official documentation.

At a high level, custom node data has the following structure:

- `ModelNodeData`: the top-level container storing all the data for a model. It consists of one or more `GraphNodeData` objects indexed by graph ids.
- `GraphNodeData`: holds the data for a specific graph within the model. It includes:
  - `results`: stores the custom node values, indexed by either node ids or output tensor names.
  - `thresholds` or `gradient`: color configurations that associate each node value with a corresponding node background color or label color, enabling visual representation of the data.
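The serialized form of this structure is plain JSON. The sketch below mirrors the nesting implied by the class names above using stdlib dicts; the field spellings (`graphsData`, `results`, `gradient`) follow the data classes, but treat `node_data_builder.py` as the authoritative schema:

```python
import json

# ModelNodeData -> GraphNodeData (keyed by graph id) -> results/gradient,
# mirrored as plain dicts for illustration only.
model_node_data = {
    'graphsData': {
        'main': {
            'results': {
                'node_id1': {'value': 100},
                'node_id2': {'value': 200},
            },
            'gradient': [
                {'stop': 0, 'bgColor': 'yellow'},
                {'stop': 1, 'bgColor': 'red'},
            ],
        }
    }
}

# This is the kind of JSON string a saved node data file contains.
model_node_data_json_str = json.dumps(model_node_data)
```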
Usage:

```python
from model_explorer import node_data_builder as ndb

# Populate values for the main graph in a model.
main_graph_results: dict[str, ndb.NodeDataResult] = {}
main_graph_results['node_id1'] = ndb.NodeDataResult(value=100)
main_graph_results['node_id2'] = ndb.NodeDataResult(value=200)
main_graph_results['any/output/tensor/name/'] = ndb.NodeDataResult(value=300)

# Create a gradient color mapping.
#
# The minimum value in `main_graph_results` maps to the color with stop=0.
# The maximum value in `main_graph_results` maps to the color with stop=1.
# Other values map to an interpolated color in between.
gradient: list[ndb.GradientItem] = [
    ndb.GradientItem(stop=0, bgColor='yellow'),
    ndb.GradientItem(stop=1, bgColor='red'),
]

# Construct the data for the main graph.
main_graph_data = ndb.GraphNodeData(
    results=main_graph_results, gradient=gradient)

# Construct the data for the model.
model_data = ndb.ModelNodeData(graphsData={'main': main_graph_data})

# You can save the data to a json file.
model_data.save_to_file('path/to/file.json')
```
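The min/max-to-stop mapping described in the comments above can be illustrated with a stand-alone helper. `value_to_stop` is a hypothetical name for illustration; Model Explorer performs this interpolation internally when rendering:

```python
def value_to_stop(value: float, values: list[float]) -> float:
    """Linearly map `value` onto [0, 1]: min(values) -> 0, max(values) -> 1."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return 0.0  # Degenerate case: all node values are equal.
    return (value - lo) / (hi - lo)


# With the values 100/200/300 from the snippet above and the yellow->red
# gradient, 100 lands on yellow (stop 0), 300 on red (stop 1), and 200
# on the interpolated color halfway between (stop 0.5).
values = [100, 200, 300]
stops = [value_to_stop(v, values) for v in values]
```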
You can visualize custom node data from the data classes (see the previous section), from a JSON string, or from a JSON file (see `save_to_file` above). The basic steps are:

- Create a `config` object and add various custom node data sources to it.
- Pass it to the `visualize_from_config` API.
Usage:

```python
import model_explorer
from model_explorer import node_data_builder as ndb

# Create a `ModelNodeData` as shown in the previous section.
model_node_data = ...

# Create a config.
config = model_explorer.config()

# Add a model and custom node data to it.
(config
    .add_model_from_path('/path/to/a/model')
    # Add node data from a json file.
    #
    # A node data json file can be generated by calling `ModelNodeData.save_to_file`.
    .add_node_data_from_path('/path/to/node_data.json')
    # Add node data from a data class object.
    .add_node_data('my data', model_node_data)
    # Add node data from a json string (the content saved by `ModelNodeData.save_to_file`).
    .add_node_data('my data 2', model_node_data_json_str))

# Visualize.
model_explorer.visualize_from_config(config)
```