Skip to content

Commit

Permalink
feat: Add XAI SDK integration to TensorFlow models with LIT integrati…
Browse files Browse the repository at this point in the history
…on (#917)

Automatically add feature attributions for TensorFlow 2 models in the LIT integration when running on Vertex Notebooks. Vertex Notebooks are detected by checking the same environment variable that the LIT library itself uses for this purpose.

Fixes b/210943910 🦕

go/local-explanations-lit-xai-notebook
  • Loading branch information
taiseiak authored Jan 19, 2022
1 parent 235fbf9 commit ea2b5cf
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 52 deletions.
167 changes: 127 additions & 40 deletions google/cloud/aiplatform/explain/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple, Union
import logging
import os

from typing import Dict, List, Optional, Tuple, Union

try:
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes as lit_dtypes
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp import notebook
Expand Down Expand Up @@ -82,6 +86,7 @@ def __init__(
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
):
"""Construct a VertexLitModel.
Args:
Expand All @@ -94,39 +99,33 @@ def __init__(
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
"""
self._loaded_model = tf.saved_model.load(model)
serving_default = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
]
_, self._kwargs_signature = serving_default.structured_input_signature
self._output_signature = serving_default.structured_outputs

if len(self._kwargs_signature) != 1:
raise ValueError("Please use a model with only one input tensor.")

if len(self._output_signature) != 1:
raise ValueError("Please use a model with only one output tensor.")

self._load_model(model)
self._input_types = input_types
self._output_types = output_types
self._input_tensor_name = next(iter(self._kwargs_signature))
self._attribution_explainer = None
if os.environ.get("LIT_PROXY_URL"):
self._set_up_attribution_explainer(model, attribution_method)

@property
def attribution_explainer(self) -> Optional["AttributionExplainer"]:  # noqa: F821
    """Returns the stored attribution explainer, or None if none has been set."""
    explainer = self._attribution_explainer
    return explainer

def predict_minibatch(
self, inputs: List[lit_types.JsonDict]
) -> List[lit_types.JsonDict]:
"""Returns predictions for a single batch of examples.
Args:
inputs:
sequence of inputs, following model.input_spec()
Returns:
list of outputs, following model.output_spec()
"""
instances = []
for input in inputs:
instance = [input[feature] for feature in self._input_types]
instances.append(instance)
prediction_input_dict = {
next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
self._input_tensor_name: tf.convert_to_tensor(instances)
}
prediction_dict = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Expand All @@ -140,6 +139,15 @@ def predict_minibatch(
for label, value in zip(self._output_types.keys(), prediction)
}
)
# Get feature attributions
if self.attribution_explainer:
attributions = self.attribution_explainer.explain(
[{self._input_tensor_name: i} for i in instances]
)
for i, attribution in enumerate(attributions):
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
attribution.feature_importance()
)
return outputs

def input_spec(self) -> lit_types.Spec:
Expand All @@ -148,7 +156,70 @@ def input_spec(self) -> lit_types.Spec:

def output_spec(self) -> lit_types.Spec:
    """Return a spec describing model outputs.

    Returns:
        A new dict copied from the configured output types, so the caller's
        mapping is never mutated. When an attribution explainer is set, the
        dict also carries a "feature_attribution" entry typed as a signed
        FeatureSalience.
    """
    # Copy first: the extra key below must not leak into self._output_types.
    output_spec_dict = dict(self._output_types)
    if self.attribution_explainer:
        output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
            signed=True
        )
    return output_spec_dict

def _load_model(self, model: str):
    """Loads a TensorFlow saved model and records its serving signature.

    Stores the loaded model plus the keyword-argument input signature and the
    structured outputs of the default serving signature on the instance, then
    checks that each of them has exactly one entry.

    Args:
        model: Required. A string reference to a TensorFlow saved model directory.
    Raises:
        ValueError if the model does not have exactly one input tensor or
        exactly one output tensor.
    """
    self._loaded_model = tf.saved_model.load(model)
    signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    default_signature = self._loaded_model.signatures[signature_key]

    # structured_input_signature is an (args, kwargs) pair; only kwargs is kept.
    _positional, keyword_inputs = default_signature.structured_input_signature
    self._kwargs_signature = keyword_inputs
    self._output_signature = default_signature.structured_outputs

    input_count = len(self._kwargs_signature)
    if input_count != 1:
        raise ValueError("Please use a model with only one input tensor.")

    output_count = len(self._output_signature)
    if output_count != 1:
        raise ValueError("Please use a model with only one output tensor.")

def _set_up_attribution_explainer(
    self, model: str, attribution_method: str = "integrated_gradients"
):
    """Populates the attribution explainer attribute of the class.

    If the Explainable AI SDK cannot be imported, an informational message is
    logged and the method returns without setting an explainer.

    Args:
        model: Required. A string reference to a TensorFlow saved model directory.
        attribution_method:
            Optional. A string to choose what attribution configuration to
            set up the explainer with. Valid options are 'sampled_shapley'
            or 'integrated_gradients'. Any value other than
            'integrated_gradients' selects the sampled Shapley configuration.
    """
    # NOTE(review): the default here ('integrated_gradients') differs from the
    # 'sampled_shapley' default of the public entry points; left unchanged to
    # keep the signature stable - confirm which default is intended.
    try:
        # Imported lazily so the SDK stays an optional dependency.
        import explainable_ai_sdk
        from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
    except ImportError:
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, and without it the two sentences run together.
        logging.info(
            "Skipping explanations because the Explainable AI SDK is not installed. "
            'Please install the SDK using "pip install explainable-ai-sdk"'
        )
        return

    builder = SavedModelMetadataBuilder(model)
    builder.get_metadata()
    # Associate the single input tensor with the feature names, in the order
    # they appear in the input types mapping.
    builder.set_numeric_metadata(
        self._input_tensor_name,
        index_feature_mapping=list(self._input_types.keys()),
    )
    builder.save_metadata(model)
    if attribution_method == "integrated_gradients":
        explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
    else:
        explainer_config = explainable_ai_sdk.SampledShapleyConfig()

    self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
        model, explainer_config
    )
    # Reload the saved model and its signatures after the explainer is built.
    self._load_model(model)


def create_lit_dataset(
Expand All @@ -172,22 +243,27 @@ def create_lit_model(
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
) -> lit_model.Model:
"""Creates a LIT Model object.
Args:
model:
Required. A string reference to a local TensorFlow saved model directory.
The model must have at most one input and one output tensor.
Required. A string reference to a local TensorFlow saved model directory.
The model must have at most one input and one output tensor.
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
Returns:
A LIT Model object that has the same functionality as the model provided.
"""
return _VertexLitModel(model, input_types, output_types)
return _VertexLitModel(model, input_types, output_types, attribution_method)


def open_lit(
Expand All @@ -198,11 +274,11 @@ def open_lit(
"""Open LIT from the provided models and datasets.
Args:
models:
Required. A list of LIT models to open LIT with.
Required. A list of LIT models to open LIT with.
input_types:
Required. A list of LIT datasets to open LIT with.
Required. A list of LIT datasets to open LIT with.
open_in_new_tab:
Optional. A boolean to choose if LIT open in a new tab or not.
Optional. A boolean to choose if LIT open in a new tab or not.
Raises:
ImportError if LIT is not installed.
"""
Expand All @@ -216,24 +292,31 @@ def set_up_and_open_lit(
model: Union[str, lit_model.Model],
input_types: Union[List[str], Dict[str, lit_types.LitType]],
output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
attribution_method: str = "sampled_shapley",
open_in_new_tab: bool = True,
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
"""Creates a LIT dataset and model and opens LIT.
Args:
dataset:
dataset:
Required. A Pandas DataFrame that includes feature column names and data.
column_types:
column_types:
Required. An OrderedDict of string names matching the columns of the dataset
as the key, and the associated LitType of the column.
model:
model:
Required. A string reference to a TensorFlow saved model directory.
The model must have at most one input and one output tensor.
input_types:
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
open_in_new_tab:
Optional. A boolean to choose if LIT open in a new tab or not.
Returns:
A Tuple of the LIT dataset and model created.
Raises:
Expand All @@ -244,8 +327,12 @@ def set_up_and_open_lit(
dataset = create_lit_dataset(dataset, column_types)

if not isinstance(model, lit_model.Model):
model = create_lit_model(model, input_types, output_types)
model = create_lit_model(
model, input_types, output_types, attribution_method=attribution_method
)

open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)
open_lit(
{"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab,
)

return dataset, model
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.7.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
lit_extra_require = ["tensorflow >= 2.3.0", "pandas >= 1.0.0", "lit-nlp >= 0.4.0"]
lit_extra_require = [
"tensorflow >= 2.3.0",
"pandas >= 1.0.0",
"lit-nlp >= 0.4.0",
"explainable-ai-sdk >= 1.0.0",
]
profiler_extra_require = [
"tensorboard-plugin-profile >= 2.4.0",
"werkzeug >= 2.0.0",
Expand Down
Loading

0 comments on commit ea2b5cf

Please sign in to comment.