Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how Variables are Provided to Visualizations #1754

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions backend/src/apiserver/visualization/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from enum import Enum
import json
from pathlib import Path
Expand Down Expand Up @@ -90,27 +89,17 @@ def __init__(
)

@staticmethod
def create_cell_from_args(args: argparse.Namespace) -> NotebookNode:
"""Creates a NotebookNode object with provided arguments as variables.
def create_cell_from_args(variables: dict) -> NotebookNode:
"""Creates NotebookNode object containing dict of provided variables.

Args:
args: Arguments that need to be injected into a NotebookNode.
variables: Arguments that need to be injected into a NotebookNode.

Returns:
NotebookNode with provided arguments as variables.

"""
variables = ""
args = json.loads(args)
for key in sorted(args.keys()):
# Check type of variable to maintain type when converting from JSON
# to notebook cell
if args[key] is None or isinstance(args[key], bool):
variables += "{} = {}\n".format(key, args[key])
else:
variables += '{} = "{}"\n'.format(key, args[key])

return new_code_cell(variables)
return new_code_cell("variables = {}".format(repr(variables)))

@staticmethod
def create_cell_from_file(filepath: Text) -> NotebookNode:
Expand Down
22 changes: 11 additions & 11 deletions backend/src/apiserver/visualization/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,39 @@
# trueclass
# true_score_column

if is_generated is False:
if "is_generated" is not in variables or variables["is_generated"] is False:
# Create data from specified csv file(s).
# The schema file provides column names for the csv file that will be used
# to generate the roc curve.
schema_file = Path(source) / 'schema.json'
schema_file = Path(source) / "schema.json"
schema = json.loads(file_io.read_file_to_string(schema_file))
names = [x['name'] for x in schema]
names = [x["name"] for x in schema]

dfs = []
files = file_io.get_matching_files(source)
for f in files:
dfs.append(pd.read_csv(f, names=names))

df = pd.concat(dfs)
if target_lambda:
df['target'] = df.apply(eval(target_lambda), axis=1)
if variables["target_lambda"]:
df["target"] = df.apply(eval(variables["target_lambda"]), axis=1)
else:
df['target'] = df['target'].apply(lambda x: 1 if x == trueclass else 0)
fpr, tpr, thresholds = roc_curve(df['target'], df[true_score_column])
source = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
df["target"] = df["target"].apply(lambda x: 1 if x == variables["trueclass"] else 0)
fpr, tpr, thresholds = roc_curve(df["target"], df[variables["true_score_column"]])
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds})
else:
# Load data from generated csv file.
source = pd.read_csv(
df = pd.read_csv(
source,
header=None,
names=['fpr', 'tpr', 'thresholds']
names=["fpr", "tpr", "thresholds"]
)

# Create visualization.
output_notebook()

p = figure(tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave")
p.line('fpr', 'tpr', line_width=2, source=source)
p.line("fpr", "tpr", line_width=2, source=df)

hover = p.select(dict(type=HoverTool))
hover.tooltips = [("Threshold", "@thresholds")]
Expand Down
59 changes: 32 additions & 27 deletions backend/src/apiserver/visualization/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import argparse
from argparse import Namespace
import importlib
import json
from pathlib import Path
from typing import Text
import shlex
Expand Down Expand Up @@ -73,48 +75,47 @@ def initialize(self):
help="JSON string of arguments to be provided to visualizations."
)

def get_arguments_from_body(self) -> argparse.Namespace:
"""Converts arguments from post request to argparser.Namespace format.
def get_arguments_from_body(self) -> Namespace:
"""Converts arguments from post request to Namespace format.

This is done because arguments, by default are provided in the
x-www-form-urlencoded format. This format is difficult to parse compared
to argparser.Namespace, which is a dict.
to Namespace, which is a dict.

Returns:
Arguments provided from post request as arparser.Namespace object.
Arguments provided from post request as a Namespace object.
"""
split_arguments = shlex.split(self.get_body_argument("arguments"))
return self.requestParser.parse_args(split_arguments)

def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool:
"""Validates arguments from post request and sends error if invalid.
def is_valid_request_arguments(self, arguments: Namespace):
"""Validates arguments from post request and raises error if invalid.

Args:
arguments: x-www-form-urlencoded formatted arguments

Returns:
Boolean value representing if provided arguments are valid.
arguments: Namespace formatted arguments
"""
if arguments.type is None:
self.send_error(400, reason="No type specified.")
return False
raise Exception("No type specified.")
if arguments.source is None:
self.send_error(400, reason="No source specified.")
return False
raise Exception("No source specified.")
try:
json.loads(arguments.arguments)
except json.JSONDecodeError:
raise Exception("Invalid JSON provided as arguments.")

return True

def generate_notebook_from_arguments(
self,
arguments: argparse.Namespace,
arguments: dict,
source: Text,
visualization_type: Text
) -> NotebookNode:
"""Generates a NotebookNode from provided arguments.

Args:
arguments: x-www-form-urlencoded formatted arguments.
input_path: Path or path pattern to be used as data reference for
arguments: JSON object containing provided arguments.
source: Path or path pattern to be used as data reference for
visualization.
visualization_type: Name of visualization to be generated.

Expand All @@ -139,16 +140,20 @@ def post(self):
# Parse arguments from request.
request_arguments = self.get_arguments_from_body()
# Validate arguments from request.
if self.is_valid_request_arguments(request_arguments):
# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
request_arguments.arguments,
request_arguments.source,
request_arguments.type
)
# Generate visualization (output for notebook).
html = _exporter.generate_html_from_notebook(nb)
self.write(html)
try:
self.is_valid_request_arguments(request_arguments)
except Exception as e:
return self.send_error(400, reason=str(e))

# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
json.loads(request_arguments.arguments),
request_arguments.source,
request_arguments.type
)
# Generate visualization (output for notebook).
html = _exporter.generate_html_from_notebook(nb)
self.write(html)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,73 @@

snapshots = Snapshot()

snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''source = "gs://ml-pipeline/data.csv"
target_lambda = "lambda x: (x[\'target\'] > x[\'fare\'] * 0.2)"
snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;, &#34;lambda x: (x[&#39;target&#39;] &gt; x[&#39;fare&#39;] * 0.2)&#34;]
</pre>
</div>
</div>

</div>
</div>



'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>{}
</pre>
</div>
</div>

</div>
</div>



'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = ''
snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;]
</pre>
</div>
</div>

</div>
</div>



snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''source = "gs://ml-pipeline/data.csv"
'''

snapshots['TestExporterMethods::test_create_cell_from_file 1'] = '''# Copyright 2019 Google LLC
Expand Down
37 changes: 23 additions & 14 deletions backend/src/apiserver/visualization/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,33 @@ def setUp(self):

def test_create_cell_from_args_with_no_args(self):
self.maxDiff = None
args = "{}"
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print(variables)"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_args_with_one_arg(self):
self.maxDiff = None
args = '{"source": "gs://ml-pipeline/data.csv"}'
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {"source": "gs://ml-pipeline/data.csv"}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_args_with_multiple_args(self):
self.maxDiff = None
args = (
'{"source": "gs://ml-pipeline/data.csv", '
"\"target_lambda\": \"lambda x: (x['target'] > x['fare'] * 0.2)\"}"
)
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {
"source": "gs://ml-pipeline/data.csv",
"target_lambda": "lambda x: (x['target'] > x['fare'] * 0.2)"
}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_file(self):
self.maxDiff = None
Expand All @@ -55,9 +64,9 @@ def test_create_cell_from_file(self):
def test_generate_html_from_notebook(self):
self.maxDiff = None
nb = new_notebook()
args = '{"x": 2}'
args = {"x": 2}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print(x)"))
nb.cells.append(new_code_cell("print(variables['x'])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

Expand Down
11 changes: 11 additions & 0 deletions backend/src/apiserver/visualization/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def test_create_visualization_fails_when_missing_input_path(self):
response.body
)

def test_create_visualization_fails_when_invalid_json_is_provided(self):
response = self.fetch(
"/",
method="POST",
body='arguments=--type test --source gs://ml-pipeline/data.csv --arguments "{"')
self.assertEqual(400, response.code)
self.assertEqual(
wrap_error_in_html("400: Invalid JSON provided as arguments."),
response.body
)

def test_create_visualization(self):
response = self.fetch(
"/",
Expand Down