
4. API Guide

Jing Jin edited this page Sep 23, 2024 · 14 revisions

The model_explorer package provides the following APIs for visualizing models, and for creating and visualizing custom node data, with Model Explorer from Python code. Make sure to install it first by following the installation guide.




Visualize models

model_explorer provides convenient APIs to quickly visualize models from files or from a PyTorch model, and a lower-level API to visualize models from multiple sources.

Visualize models from files

Usage:

import model_explorer

model_explorer.visualize('/path/to/model/file')

API reference:

visualize(
  model_paths=[],
  host='localhost',
  port=8080,
  extensions: list[str] = [],
  node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [],
  colab_height=850,
  reuse_server: bool = False,
  reuse_server_host: str = DEFAULT_HOST,
  reuse_server_port: Union[int, None] = None)

Starts the Model Explorer local server and visualizes the models by the given paths.

When you pass multiple models to model_paths, the visualization page initially shows the largest subgraph from the first model. You can switch between models and their subgraphs using the model graph selector in the top-right corner.
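The initial-selection rule described above can be sketched in plain Python. This is a minimal illustration, assuming each model is represented as a hypothetical mapping from subgraph ID to node count and that "largest" means the subgraph with the most nodes; Model Explorer's internal representation and tie-breaking may differ, and initial_selection is not a library function.

```python
from typing import Dict, List, Tuple

# Hypothetical representation: (model_name, {subgraph_id: node_count}).
Model = Tuple[str, Dict[str, int]]


def initial_selection(models: List[Model]) -> Tuple[str, str]:
    """Return the (model_name, subgraph_id) that would be shown first.

    Assumes "largest" means most nodes and that only the first model counts.
    """
    if not models:
        raise ValueError("at least one model is required")
    name, subgraphs = models[0]
    if not subgraphs:
        raise ValueError(f"model {name!r} has no subgraphs")
    # max() returns the first subgraph with the highest node count on ties.
    largest = max(subgraphs, key=subgraphs.__getitem__)
    return name, largest


models = [
    ("model_a.tflite", {"main": 120, "while_body": 35}),
    ("model_b.tflite", {"main": 900}),
]
# model_b has a bigger subgraph, but only the first model is considered.
print(initial_selection(models))  # ('model_a.tflite', 'main')
```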

Args:

  • model_paths: str|list[str]: A model path or a list of model paths to visualize.
  • host: str: The host of the server. Defaults to localhost.
  • port: int: The port of the server. Defaults to 8080.
  • extensions: list[str]: A list of extension names to run with Model Explorer.
  • node_data: list[NodeDataInfo]|NodeDataInfo: The node data or a list of node data to display. Example: node_data={'name': 'my node data', 'node_data': node_data_json_str}.
  • colab_height: int: The height of the embedded iFrame when running in Colab. Defaults to 850.
  • reuse_server: bool: Whether to reuse the current server/browser tab(s) for visualization. Defaults to False.
  • reuse_server_host: str: The host of the server to reuse. Defaults to localhost.
  • reuse_server_port: int: The port of the server to reuse. If unspecified, Model Explorer tries to find a running server on ports 8080 through 8099.
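When reuse_server_port is left unspecified, Model Explorer looks for a running server on ports 8080 through 8099. The sketch below shows one way such a probe could work, assuming a server counts as "running" if it accepts a TCP connection. The library may check more than that (for example, that the listener really is a Model Explorer server), and find_running_server is a hypothetical name.

```python
import socket
from typing import Iterable, Optional


def find_running_server(host: str = "localhost",
                        ports: Iterable[int] = range(8080, 8100),
                        timeout: float = 0.2) -> Optional[int]:
    """Return the first port in `ports` that accepts a TCP connection, else None.

    range(8080, 8100) covers ports 8080 through 8099 inclusive.
    """
    for port in ports:
        try:
            # create_connection raises OSError (e.g. ConnectionRefusedError)
            # when nothing is listening on the port.
            with socket.create_connection((host, port), timeout=timeout):
                return port
        except OSError:
            continue
    return None
```

Calling find_running_server() with no arguments scans the same default range; a None result means no server was found, in which case a new one would need to be started.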

Visualize PyTorch models

Visualizing PyTorch models requires a slightly different approach due to their lack of a standard serialization format. Model Explorer offers a specialized API to visualize PyTorch models directly, using the ExportedProgram from torch.export.export.

Usage:

import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)

# Visualize.
model_explorer.visualize_pytorch('mobilenet', exported_program=ep)

API reference:

visualize_pytorch(
    name,
    exported_program,
    settings={'const_element_count_limit': 16},
    host='localhost', port=8080, colab_height=850)

Starts the Model Explorer local server and visualizes the given PyTorch ExportedProgram.

Args:

  • name: str: The name of the model, for display purposes.
  • exported_program: torch.export.ExportedProgram: The ExportedProgram produced by torch.export.export.
  • settings: Dict: Key-value pairs of settings. Currently only one key is supported, const_element_count_limit, which controls how many values of a constant the adapter returns.
  • See the visualize function above for host, port, extensions, node_data, and colab_height.
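The settings argument is a plain dictionary. Assuming the limit simply caps how many leading values of a constant are returned (the real adapter might instead, for example, omit values entirely for oversized constants), its effect can be sketched as follows. preview_const_values is a hypothetical helper, not part of model_explorer.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Same default as the `settings` parameter shown in the signature above.
DEFAULT_SETTINGS: Dict[str, Any] = {"const_element_count_limit": 16}


def preview_const_values(values: Sequence[float],
                         settings: Optional[Dict[str, Any]] = None
                         ) -> Tuple[List[float], bool]:
    """Hypothetical helper: keep at most `const_element_count_limit` values.

    Returns the kept values and a flag telling whether any were dropped.
    """
    if settings is None:
        settings = DEFAULT_SETTINGS
    limit = settings.get("const_element_count_limit", 16)
    kept = list(values[:limit])
    return kept, len(values) > limit


# A constant with 20 elements: only the first 16 are kept under the default.
kept, truncated = preview_const_values([float(i) for i in range(20)])
print(len(kept), truncated)  # 16 True
```

To see more values per constant, you would pass a larger limit, e.g. settings={'const_element_count_limit': 64}, to visualize_pytorch.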

Visualize models from multiple sources

Sometimes you may want to load both model files and a PyTorch model into Model Explorer at the same time. To do this, use the following lower-level APIs. The basic steps are:

  1. Create a config object and add models to it.
  2. Pass it to visualize_from_config API.

Usage:

import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs, then export it.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)

# Create Model Explorer config.
config = model_explorer.config()

# Add a model file path and the exported PyTorch program to the config.
config.add_model_from_path('/path/to/model').add_model_from_pytorch('mobilenet', exported_program=ep)

# Visualize with config.
model_explorer.visualize_from_config(config=config)

API reference:

  • config() -> ModelExplorerConfig

    Creates a new Model Explorer config object.

    • ModelExplorerConfig.add_model_from_path(self, path) -> ModelExplorerConfig

      Adds a model path to the config.

      Args:

      • path: str: The model file path to add.
    • ModelExplorerConfig.add_model_from_pytorch(
        self,
        name,
        exported_program,
        settings={'const_element_count_limit': 16}) -> ModelExplorerConfig

      Adds a PyTorch ExportedProgram to the config. After this method is called, Model Explorer invokes its internal adapter to convert the given program into Model Explorer graphs. This can take some time, depending on the complexity of the model.

      Args:

      • name: str: The name of the model, for display purposes.
      • See visualize_pytorch above for the exported_program and settings parameters.

    • ModelExplorerConfig.set_reuse_server(
        self,
        server_host: str = 'localhost',
        server_port: Union[int, None] = None) -> ModelExplorerConfig

      Configures Model Explorer to reuse an existing server instead of starting a new one.

      Args:

      • server_host: str: The host of the server to reuse. Defaults to localhost.
      • server_port: int|None: The port of the server to reuse. If unspecified, Model Explorer tries to find a running server on ports 8080 through 8099.
  • visualize_from_config(config=None, host='localhost', port=8080, colab_height=850)

    Starts the visualization from the given config.

    Args:

    • config: ModelExplorerConfig|None: The object that stores the models to be visualized.
    • See the visualize function above for host, port, and colab_height.
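Each add_* method and set_reuse_server is documented as returning the ModelExplorerConfig, which is what allows the calls to be chained as in the usage example. The sketch below illustrates that fluent-builder pattern with a hypothetical SketchConfig class; it is not the library's implementation and only records what was added.

```python
from typing import List, Optional, Tuple


class SketchConfig:
    """Hypothetical stand-in for ModelExplorerConfig that only records what is added."""

    def __init__(self) -> None:
        self.model_sources: List[Tuple[str, str]] = []  # (source kind, path or name)
        self.reuse_server: Optional[Tuple[str, Optional[int]]] = None

    def add_model_from_path(self, path: str) -> "SketchConfig":
        self.model_sources.append(("file", path))
        return self  # Returning self is what makes chaining possible.

    def add_model_from_pytorch(self, name: str, exported_program: object) -> "SketchConfig":
        # The real method converts the program into graphs; this sketch keeps only the name.
        self.model_sources.append(("pytorch", name))
        return self

    def set_reuse_server(self, server_host: str = "localhost",
                         server_port: Optional[int] = None) -> "SketchConfig":
        self.reuse_server = (server_host, server_port)
        return self


# Each call returns the same object, so the calls can be chained in one expression.
config = (SketchConfig()
          .add_model_from_path("/path/to/model")
          .add_model_from_pytorch("mobilenet", exported_program=None)
          .set_reuse_server())
print(config.model_sources)  # [('file', '/path/to/model'), ('pytorch', 'mobilenet')]
```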

Custom node data

model_explorer provides APIs to create custom node data and visualize it in a model graph. For more info about how custom node data works, see the user guide.

Create custom node data

We provide a set of data classes to help you build custom node data. The comments in node_data_builder.py serve as the official documentation.

At a high level, custom node data has the following structure:

  • ModelNodeData: The top-level container storing all the data for a model. It consists of one or more GraphNodeData objects indexed by graph IDs.
  • GraphNodeData: Holds the data for a specific graph within the model. It includes:
    • results: Stores the custom node values, indexed by either node IDs or output tensor names.
    • thresholds or gradient: Color configurations that map each node value to a node background color or label color, so the data can be represented visually.

Usage:

from model_explorer import node_data_builder as ndb

# Populate values for the main graph in a model.
main_graph_results: dict[str, ndb.NodeDataResult] = {}
main_graph_results['node_id1'] = ndb.NodeDataResult(value=100)
main_graph_results['node_id2'] = ndb.NodeDataResult(value=200)
main_graph_results['any/output/tensor/name/'] = ndb.NodeDataResult(value=300)

# Create a gradient color mapping.
#
# The minimum value in `main_graph_results` maps to the color with stop=0.
# The maximum value in `main_graph_results` maps to the color with stop=1.
# Other values map to an interpolated color in between.
gradient: list[ndb.GradientItem] = [
    ndb.GradientItem(stop=0, bgColor='yellow'),
    ndb.GradientItem(stop=1, bgColor='red'),
]

# Construct the data for the main graph.
main_graph_data = ndb.GraphNodeData(
    results=main_graph_results, gradient=gradient)

# Construct the data for the model.
model_data = ndb.ModelNodeData(graphsData={'main': main_graph_data})

# You can save the data to a json file.
model_data.save_to_file('path/to/file.json')
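The gradient comments above describe a min-max normalization followed by color interpolation. Below is a minimal sketch of that mapping, assuming linear interpolation of RGB components between adjacent stops; the renderer's actual color space and rounding may differ. gradient_color and hex_to_rgb are hypothetical helpers, and hex strings stand in for the named colors 'yellow' (#ffff00) and 'red' (#ff0000).

```python
from typing import List, Tuple

RGB = Tuple[int, int, int]


def hex_to_rgb(color: str) -> RGB:
    """Convert '#rrggbb' to an (r, g, b) tuple of ints."""
    c = color.lstrip("#")
    return (int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16))


def gradient_color(value: float, min_value: float, max_value: float,
                   stops: List[Tuple[float, str]]) -> str:
    """Map `value` to a hex color using (stop, hex_color) pairs sorted by stop."""
    # Normalize so the minimum value maps to stop 0 and the maximum to stop 1.
    t = 0.0 if max_value == min_value else (value - min_value) / (max_value - min_value)
    t = min(max(t, 0.0), 1.0)
    for (s0, c0), (s1, c1) in zip(stops, stops[1:]):
        if s0 <= t <= s1:
            # Linearly interpolate each RGB component between the two stops.
            f = 0.0 if s1 == s0 else (t - s0) / (s1 - s0)
            rgb0, rgb1 = hex_to_rgb(c0), hex_to_rgb(c1)
            mixed = [round(a + (b - a) * f) for a, b in zip(rgb0, rgb1)]
            return "#{:02x}{:02x}{:02x}".format(*mixed)
    # t lies outside the stops' range: use the nearest end stop.
    return stops[0][1] if t < stops[0][0] else stops[-1][1]


stops = [(0.0, "#ffff00"), (1.0, "#ff0000")]  # yellow -> red
values = {"node_id1": 100, "node_id2": 200, "any/output/tensor/name/": 300}
vmin, vmax = min(values.values()), max(values.values())
colors = {k: gradient_color(v, vmin, vmax, stops) for k, v in values.items()}
print(colors)
# {'node_id1': '#ffff00', 'node_id2': '#ff8000', 'any/output/tensor/name/': '#ff0000'}
```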

Visualize custom node data

You can visualize custom node data from data class objects (see the previous section), from a JSON string, or from a JSON file (see save_to_file above). The basic steps are:

  1. Create a config object and add various custom node data sources to it.
  2. Pass it to visualize_from_config API.

Usage:

import model_explorer
from model_explorer import node_data_builder as ndb

# Create a `ModelNodeData` as shown in the previous section.
model_node_data = ...

# A JSON string with the same content as a file written by `ModelNodeData.save_to_file`.
model_node_data_json_str = ...

# Create a config.
config = model_explorer.config()

# Add model and custom node data to it.
(config
 .add_model_from_path('/path/to/a/model')
 # Add node data from a json file.
 # A node data json file can be generated by calling `ModelNodeData.save_to_file`.
 .add_node_data_from_path('/path/to/node_data.json')
 # Add node data from a data class object.
 .add_node_data('my data', model_node_data)
 # Add node data from a json string (the content of `ModelNodeData.save_to_file`).
 .add_node_data('my data 2', model_node_data_json_str))

# Visualize
model_explorer.visualize_from_config(config)