Skip to content

Commit

Permalink
Fix What-If Tool PD plots in py3 (#2669)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameswex authored Sep 20, 2019
1 parent 09298ed commit ded6760
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
2 changes: 2 additions & 0 deletions tensorboard/plugins/interactive_inference/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ py_library(
":platform_utils",
"//tensorboard:expect_absl_logging_installed",
"//tensorboard:expect_tensorflow_installed",
"@org_pythonhosted_six",
"@org_tensorflow_serving_api",
],
)
Expand All @@ -55,6 +56,7 @@ py_test(
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"@org_pythonhosted_mock",
"@org_pythonhosted_six",
"@org_tensorflow_serving_api",
],
)
27 changes: 18 additions & 9 deletions tensorboard/plugins/interactive_inference/utils/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import numpy as np
import tensorflow as tf
from google.protobuf import json_format
from six import binary_type, string_types, integer_types
from six import iteritems
from six import string_types, integer_types
from six.moves import zip # pylint: disable=redefined-builtin

from tensorboard.plugins.interactive_inference.utils import common_utils
Expand Down Expand Up @@ -125,7 +125,8 @@ class OriginalFeatureList(object):
def __init__(self, feature_name, original_value, feature_type):
"""Inits OriginalFeatureList."""
self.feature_name = feature_name
self.original_value = original_value
self.original_value = [
ensure_not_binary(value) for value in original_value]
self.feature_type = feature_type

# Derived attributes.
Expand Down Expand Up @@ -164,7 +165,8 @@ def __init__(self, original_feature, index, mutant_value):
'index should be None or int, but had unexpected type: {}'.format(
type(index)))
self.index = index
self.mutant_value = mutant_value
self.mutant_value = (mutant_value.encode()
if isinstance(mutant_value, string_types) else mutant_value)


class ServingBundle(object):
Expand Down Expand Up @@ -226,6 +228,11 @@ def __init__(self, inference_address, model_name, model_type, model_version,
self.custom_predict_fn = custom_predict_fn


def ensure_not_binary(value):
"""Return non-binary version of value."""
return value.decode() if isinstance(value, binary_type) else value


def proto_value_for_feature(example, feature_name):
"""Get the value of a feature from Example regardless of feature type."""
feature = get_example_features(example)[feature_name]
Expand Down Expand Up @@ -563,9 +570,10 @@ def make_json_formatted_for_single_chart(mutant_features,
key += ' (index %d)' % index_to_mutate
if not key in series:
series[key] = {}
if not mutant_feature.mutant_value in series[key]:
series[key][mutant_feature.mutant_value] = []
series[key][mutant_feature.mutant_value].append(
mutant_val = ensure_not_binary(mutant_feature.mutant_value)
if not mutant_val in series[key]:
series[key][mutant_val] = []
series[key][mutant_val].append(
classification_class.score)

# Post-process points to have separate list for each class
Expand All @@ -589,9 +597,10 @@ def make_json_formatted_for_single_chart(mutant_features,
# results. So, modding by len(mutant_features) allows us to correctly
# lookup the mutant value for each inference.
mutant_feature = mutant_features[idx % len(mutant_features)]
if not mutant_feature.mutant_value in points:
points[mutant_feature.mutant_value] = []
points[mutant_feature.mutant_value].append(regression.value)
mutant_val = ensure_not_binary(mutant_feature.mutant_value)
if not mutant_val in points:
points[mutant_val] = []
points[mutant_val].append(regression.value)
key = 'value'
if (index_to_mutate != 0):
key += ' (index %d)' % index_to_mutate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_get_categorical_features_to_sampling(self):
examples[0: 3], top_k=1)
self.assertDictEqual({
'non_numeric': {
'samples': [b'cat']
'samples': ['cat']
}
}, data)

Expand All @@ -186,7 +186,7 @@ def test_get_categorical_features_to_sampling(self):
examples[0: 20], top_k=2)
self.assertDictEqual({
'non_numeric': {
'samples': [b'pony', b'cow']
'samples': ['pony', 'cow']
}
}, data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'google-api-python-client>=1.7.8',
'ipywidgets>=7.0.0',
'jupyter>=1.0,<2',
'six>=1.12.0',
] + _TF_REQ

def get_readme():
Expand Down

0 comments on commit ded6760

Please sign in to comment.