From 572ef9456b9498a6b581195bc359c85c4bd46bb9 Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Tue, 25 Jun 2024 08:27:20 -0700 Subject: [PATCH] Fix `%%graph_notebook_config` error when excluding optional Gremlin section (#633) * Fix %%graph_notebook_config exception when excluding optional Gremlin section * Linter fix * update changelog --- ChangeLog.md | 1 + .../configuration/get_config.py | 7 +- test/unit/configuration/test_configuration.py | 225 +++++++++++++++++- 3 files changed, 228 insertions(+), 5 deletions(-) diff --git a/ChangeLog.md b/ChangeLog.md index 88d31037..1a66fbc6 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Fixed broken `--help` option for `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/630)) - Fixed openCypher query bug regression in the [`01-About-the-Neptune-Notebook`](https://github.com/aws/graph-notebook/blob/main/src/graph_notebook/notebooks/01-Getting-Started/01-About-the-Neptune-Notebook.ipynb) sample ([Link to PR](https://github.com/aws/graph-notebook/pull/631)) +- Fixed `%%graph_notebook_config` error when excluding optional Gremlin section ([Link to PR](https://github.com/aws/graph-notebook/pull/633)) ## Release 4.4.2 (June 18, 2024) - Set Gremlin `connection_protocol` defaults based on Neptune service when generating configuration via arguments ([Link to PR](https://github.com/aws/graph-notebook/pull/626)) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index c5efd1d2..e45de9e3 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -59,9 +59,10 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I for p in neptune_params: if p in data: excluded_params.append(p) - for gp in neptune_gremlin_params: - if gp in data['gremlin']: - excluded_params.append(gp) + if 'gremlin' in data: + for gp in neptune_gremlin_params: + if gp in data['gremlin']: + excluded_params.append(gp) if excluded_params: print(f"The provided configuration contains the following parameters that are incompatible with the " f"specified host: {str(excluded_params)}. These parameters have not been saved.\n") diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index e76647b8..b45265eb 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -6,11 +6,11 @@ import os import unittest -from graph_notebook.configuration.get_config import get_config +from graph_notebook.configuration.get_config import get_config, get_config_from_dict from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, AuthModeEnum, \ generate_config, generate_default_config, GremlinSection from graph_notebook.neptune.client import NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, \ - DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL + DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL, NEPTUNE_CONFIG_HOST_IDENTIFIERS class TestGenerateConfiguration(unittest.TestCase): @@ -121,6 +121,227 @@ def test_configuration_override_defaults_generic(self): config = Configuration(self.generic_host, self.port, ssl=ssl) self.assertEqual(ssl, config.ssl) + def test_get_configuration_empty_input(self): + input_config = {} + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_no_host(self): + input_config = { + "port": 8182, + "ssl": True + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_generic_no_port(self): + input_config = { + "host": "localhost", + "ssl": True + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_generic_no_ssl(self): + input_config = { + "host": "localhost", + "port": 8182 + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_generic_required_input(self): + input_config = { + "host": "localhost", + "port": 8182, + "ssl": True + } + expected_config = { + 'host': 'localhost', + 'port': 8182, + 'proxy_host': '', + 'proxy_port': 8182, + 'ssl': True, + 'ssl_verify': True, + 'sparql': { + 'path': '' + }, + 'gremlin': { + 'traversal_source': 'g', + 'username': '', + 'password': '', + 'message_serializer': 'graphsonv3' + }, + 'neo4j': { + 'username': 'neo4j', + 'password': 'password', + 'auth': True, + 'database': None + } + } + config = get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + self.assertEqual(config.to_dict(), expected_config) + + def test_get_configuration_generic_all_input(self): + input_and_expected_config = { + 'host': 'a_host', + 'port': 9999, + 'proxy_host': 'a_proxy_host', + 'proxy_port': 9999, + 'ssl': False, + 'ssl_verify': False, + 'sparql': { + 'path': 'a_path' + }, + 'gremlin': { + 'traversal_source': 'a', + 'username': 'user', + 'password': 'pass', + 'message_serializer': 'graphbinaryv1' + }, + 'neo4j': { + 'username': 'neo_user', + 'password': 'neo_pass', + 'auth': False, + 'database': 'neo_db' + } + } + config = get_config_from_dict(input_and_expected_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + self.assertEqual(config.to_dict(), input_and_expected_config) + + def test_get_configuration_neptune_no_auth_mode(self): + input_config = { + "host": "db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com", + "port": 8182, + "ssl": True, + "load_from_s3_arn": "", + "aws_region": "us-west-2" + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_neptune_no_load_arn(self): + input_config = { + "host": "db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com", + "port": 8182, + "ssl": True, + "aws_region": "us-west-2" + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_neptune_no_region(self): + input_config = { + "host": "db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com", + "port": 8182, + "ssl": True, + "load_from_s3_arn": "" + } + with self.assertRaises(KeyError): + get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + + def test_get_configuration_neptune_required_input(self): + input_config = { + "host": "db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com", + "port": 8182, + "auth_mode": "IAM", + "load_from_s3_arn": "", + "ssl": True, + "aws_region": "us-west-2" + } + expected_config = { + 'host': 'db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com', + 'neptune_service': 'neptune-db', + 'port': 8182, + 'proxy_host': '', + 'proxy_port': 8182, + 'auth_mode': 'IAM', + 'load_from_s3_arn': '', + 'ssl': True, + 'ssl_verify': True, + 'aws_region': 'us-west-2', + 'sparql': { + 'path': '' + }, + 'gremlin': { + 'traversal_source': 'g', + 'username': '', + 'password': '', + 'message_serializer': 'graphsonv3', + 'connection_protocol': 'websockets' + }, + 'neo4j': { + 'username': 'neo4j', + 'password': 'password', + 'auth': True, + 'database': None + } + } + + config = get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + self.assertEqual(config.to_dict(), expected_config) + + def test_get_configuration_neptune_all_input(self): + input_config = { + 'host': 'db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com', + 'neptune_service': 'neptune-graph', + 'port': 9999, + 'proxy_host': 'a_proxy+port', + 'proxy_port': 9999, + 'auth_mode': 'DEFAULT', + 'load_from_s3_arn': 'a_role', + 'ssl': False, + 'ssl_verify': False, + 'aws_region': 'us-west-2', + 'sparql': { + 'path': 'a_path' + }, + 'gremlin': { + 'traversal_source': 'a', + 'username': 'a_user', + 'password': 'a_pass', + 'message_serializer': 'graphbinaryv1', + 'connection_protocol': 'http' + }, + 'neo4j': { + 'username': 'a_user', + 'password': 'a_pass', + 'auth': False, + 'database': 'a_db' + } + } + expected_config = { + 'host': 'db.cluster-xxxxxxxxx.us-west-2.neptune.amazonaws.com', + 'neptune_service': 'neptune-graph', + 'port': 9999, + 'proxy_host': 'a_proxy+port', + 'proxy_port': 9999, + 'auth_mode': 'DEFAULT', + 'load_from_s3_arn': 'a_role', + 'ssl': False, + 'ssl_verify': False, + 'aws_region': 'us-west-2', + 'sparql': { + 'path': 'a_path' + }, + 'gremlin': { + 'traversal_source': 'g', + 'username': '', + 'password': '', + 'message_serializer': 'graphbinaryv1', + 'connection_protocol': 'http' + }, + 'neo4j': { + 'username': 'neo4j', + 'password': 'password', + 'auth': True, + 'database': None + } + } + + config = get_config_from_dict(input_config, neptune_hosts=NEPTUNE_CONFIG_HOST_IDENTIFIERS) + self.assertEqual(config.to_dict(), expected_config) + def test_generate_configuration_with_defaults_neptune_reg(self): config = Configuration(self.neptune_host_reg, self.port) c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl,