diff --git a/backend/src/apiserver/visualization/exporter.py b/backend/src/apiserver/visualization/exporter.py index abb81921c8a..238a5c2ba65 100644 --- a/backend/src/apiserver/visualization/exporter.py +++ b/backend/src/apiserver/visualization/exporter.py @@ -17,7 +17,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse from enum import Enum import json from pathlib import Path @@ -90,27 +89,17 @@ def __init__( ) @staticmethod - def create_cell_from_args(args: argparse.Namespace) -> NotebookNode: - """Creates a NotebookNode object with provided arguments as variables. + def create_cell_from_args(variables: dict) -> NotebookNode: + """Creates NotebookNode object containing dict of provided variables. Args: - args: Arguments that need to be injected into a NotebookNode. + variables: Arguments that need to be injected into a NotebookNode. Returns: NotebookNode with provided arguments as variables. """ - variables = "" - args = json.loads(args) - for key in sorted(args.keys()): - # Check type of variable to maintain type when converting from JSON - # to notebook cell - if args[key] is None or isinstance(args[key], bool): - variables += "{} = {}\n".format(key, args[key]) - else: - variables += '{} = "{}"\n'.format(key, args[key]) - - return new_code_cell(variables) + return new_code_cell("variables = {}".format(repr(variables))) @staticmethod def create_cell_from_file(filepath: Text) -> NotebookNode: diff --git a/backend/src/apiserver/visualization/roc_curve.py b/backend/src/apiserver/visualization/roc_curve.py index 581914d9f84..be38bc0a27c 100644 --- a/backend/src/apiserver/visualization/roc_curve.py +++ b/backend/src/apiserver/visualization/roc_curve.py @@ -34,13 +34,13 @@ # trueclass # true_score_column -if is_generated is False: +if "is_generated" is not in variables or variables["is_generated"] is False: # Create data from specified csv file(s). # The schema file provides column names for the csv file that will be used # to generate the roc curve. - schema_file = Path(source) / 'schema.json' + schema_file = Path(source) / "schema.json" schema = json.loads(file_io.read_file_to_string(schema_file)) - names = [x['name'] for x in schema] + names = [x["name"] for x in schema] dfs = [] files = file_io.get_matching_files(source) @@ -48,25 +48,25 @@ dfs.append(pd.read_csv(f, names=names)) df = pd.concat(dfs) - if target_lambda: - df['target'] = df.apply(eval(target_lambda), axis=1) + if variables["target_lambda"]: + df["target"] = df.apply(eval(variables["target_lambda"]), axis=1) else: - df['target'] = df['target'].apply(lambda x: 1 if x == trueclass else 0) - fpr, tpr, thresholds = roc_curve(df['target'], df[true_score_column]) - source = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds}) + df["target"] = df["target"].apply(lambda x: 1 if x == variables["trueclass"] else 0) + fpr, tpr, thresholds = roc_curve(df["target"], df[variables["true_score_column"]]) + df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds}) else: # Load data from generated csv file. - source = pd.read_csv( + df = pd.read_csv( source, header=None, - names=['fpr', 'tpr', 'thresholds'] + names=["fpr", "tpr", "thresholds"] ) # Create visualization. output_notebook() p = figure(tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave") -p.line('fpr', 'tpr', line_width=2, source=source) +p.line("fpr", "tpr", line_width=2, source=df) hover = p.select(dict(type=HoverTool)) hover.tooltips = [("Threshold", "@thresholds")] diff --git a/backend/src/apiserver/visualization/server.py b/backend/src/apiserver/visualization/server.py index f894326d001..fe02c6820d6 100644 --- a/backend/src/apiserver/visualization/server.py +++ b/backend/src/apiserver/visualization/server.py @@ -13,7 +13,9 @@ # limitations under the License. import argparse +from argparse import Namespace import importlib +import json from pathlib import Path from typing import Text import shlex @@ -73,48 +75,47 @@ def initialize(self): help="JSON string of arguments to be provided to visualizations." ) - def get_arguments_from_body(self) -> argparse.Namespace: - """Converts arguments from post request to argparser.Namespace format. + def get_arguments_from_body(self) -> Namespace: + """Converts arguments from post request to Namespace format. This is done because arguments, by default are provided in the x-www-form-urlencoded format. This format is difficult to parse compared - to argparser.Namespace, which is a dict. + to Namespace, which is a dict. Returns: - Arguments provided from post request as arparser.Namespace object. + Arguments provided from post request as a Namespace object. """ split_arguments = shlex.split(self.get_body_argument("arguments")) return self.requestParser.parse_args(split_arguments) - def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool: - """Validates arguments from post request and sends error if invalid. + def is_valid_request_arguments(self, arguments: Namespace): + """Validates arguments from post request and raises error if invalid. Args: - arguments: x-www-form-urlencoded formatted arguments - - Returns: - Boolean value representing if provided arguments are valid. + arguments: Namespace formatted arguments """ if arguments.type is None: - self.send_error(400, reason="No type specified.") - return False + raise Exception("No type specified.") if arguments.source is None: - self.send_error(400, reason="No source specified.") - return False + raise Exception("No source specified.") + try: + json.loads(arguments.arguments) + except json.JSONDecodeError: + raise Exception("Invalid JSON provided as arguments.") return True def generate_notebook_from_arguments( self, - arguments: argparse.Namespace, + arguments: dict, source: Text, visualization_type: Text ) -> NotebookNode: """Generates a NotebookNode from provided arguments. Args: - arguments: x-www-form-urlencoded formatted arguments. - input_path: Path or path pattern to be used as data reference for + arguments: JSON object containing provided arguments. + source: Path or path pattern to be used as data reference for visualization. visualization_type: Name of visualization to be generated. @@ -139,16 +140,20 @@ def post(self): # Parse arguments from request. request_arguments = self.get_arguments_from_body() # Validate arguments from request. - if self.is_valid_request_arguments(request_arguments): - # Create notebook with arguments from request. - nb = self.generate_notebook_from_arguments( - request_arguments.arguments, - request_arguments.source, - request_arguments.type - ) - # Generate visualization (output for notebook). - html = _exporter.generate_html_from_notebook(nb) - self.write(html) + try: + self.is_valid_request_arguments(request_arguments) + except Exception as e: + return self.send_error(400, reason=str(e)) + + # Create notebook with arguments from request. + nb = self.generate_notebook_from_arguments( + json.loads(request_arguments.arguments), + request_arguments.source, + request_arguments.type + ) + # Generate visualization (output for notebook). + html = _exporter.generate_html_from_notebook(nb) + self.write(html) if __name__ == "__main__": diff --git a/backend/src/apiserver/visualization/snapshots/snap_test_exporter.py b/backend/src/apiserver/visualization/snapshots/snap_test_exporter.py index 8a43b6a1b11..b4955ec2df3 100644 --- a/backend/src/apiserver/visualization/snapshots/snap_test_exporter.py +++ b/backend/src/apiserver/visualization/snapshots/snap_test_exporter.py @@ -7,13 +7,73 @@ snapshots = Snapshot() -snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''source = "gs://ml-pipeline/data.csv" -target_lambda = "lambda x: (x[\'target\'] > x[\'fare\'] * 0.2)" +snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = ''' +
+
+ + +
+ +
+ + +
+
['gs://ml-pipeline/data.csv', "lambda x: (x['target'] > x['fare'] * 0.2)"]
+
+
+
+ +
+
+ + + +''' + +snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = ''' +
+
+ + +
+ +
+ + +
+
{}
+
+
+
+ +
+
+ + + ''' -snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = '' +snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = ''' +
+
+ + +
+ +
+ + +
+
['gs://ml-pipeline/data.csv']
+
+
+
+ +
+
+ + -snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''source = "gs://ml-pipeline/data.csv" ''' snapshots['TestExporterMethods::test_create_cell_from_file 1'] = '''# Copyright 2019 Google LLC diff --git a/backend/src/apiserver/visualization/test_exporter.py b/backend/src/apiserver/visualization/test_exporter.py index 52e4dced1fe..c490b26d44a 100644 --- a/backend/src/apiserver/visualization/test_exporter.py +++ b/backend/src/apiserver/visualization/test_exporter.py @@ -28,24 +28,33 @@ def setUp(self): def test_create_cell_from_args_with_no_args(self): self.maxDiff = None - args = "{}" - cell = self.exporter.create_cell_from_args(args) - self.assertMatchSnapshot(cell.source) + nb = new_notebook() + args = {} + nb.cells.append(self.exporter.create_cell_from_args(args)) + nb.cells.append(new_code_cell("print(variables)")) + html = self.exporter.generate_html_from_notebook(nb) + self.assertMatchSnapshot(html) def test_create_cell_from_args_with_one_arg(self): self.maxDiff = None - args = '{"source": "gs://ml-pipeline/data.csv"}' - cell = self.exporter.create_cell_from_args(args) - self.assertMatchSnapshot(cell.source) + nb = new_notebook() + args = {"source": "gs://ml-pipeline/data.csv"} + nb.cells.append(self.exporter.create_cell_from_args(args)) + nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])")) + html = self.exporter.generate_html_from_notebook(nb) + self.assertMatchSnapshot(html) def test_create_cell_from_args_with_multiple_args(self): self.maxDiff = None - args = ( - '{"source": "gs://ml-pipeline/data.csv", ' - "\"target_lambda\": \"lambda x: (x['target'] > x['fare'] * 0.2)\"}" - ) - cell = self.exporter.create_cell_from_args(args) - self.assertMatchSnapshot(cell.source) + nb = new_notebook() + args = { + "source": "gs://ml-pipeline/data.csv", + "target_lambda": "lambda x: (x['target'] > x['fare'] * 0.2)" + } + nb.cells.append(self.exporter.create_cell_from_args(args)) + nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])")) + html = self.exporter.generate_html_from_notebook(nb) + self.assertMatchSnapshot(html) def test_create_cell_from_file(self): self.maxDiff = None @@ -55,9 +64,9 @@ def test_create_cell_from_file(self): def test_generate_html_from_notebook(self): self.maxDiff = None nb = new_notebook() - args = '{"x": 2}' + args = {"x": 2} nb.cells.append(self.exporter.create_cell_from_args(args)) - nb.cells.append(new_code_cell("print(x)")) + nb.cells.append(new_code_cell("print(variables['x'])")) html = self.exporter.generate_html_from_notebook(nb) self.assertMatchSnapshot(html) diff --git a/backend/src/apiserver/visualization/test_server.py b/backend/src/apiserver/visualization/test_server.py index a70e2a8e126..222090ddc88 100644 --- a/backend/src/apiserver/visualization/test_server.py +++ b/backend/src/apiserver/visualization/test_server.py @@ -70,6 +70,17 @@ def test_create_visualization_fails_when_missing_input_path(self): response.body ) + def test_create_visualization_fails_when_invalid_json_is_provided(self): + response = self.fetch( + "/", + method="POST", + body='arguments=--type test --source gs://ml-pipeline/data.csv --arguments "{"') + self.assertEqual(400, response.code) + self.assertEqual( + wrap_error_in_html("400: Invalid JSON provided as arguments."), + response.body + ) + def test_create_visualization(self): response = self.fetch( "/",