diff --git a/ChangeLog.md b/ChangeLog.md index 93c5a5ff..435f6a9b 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,8 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Disabled SigV4 signing for non-IAM AWS requests ([Link to PR](https://github.com/aws/graph-notebook/pull/179)) +- Added new `--nopoll` option to `%load` to disable status polling ([Link to PR](https://github.com/aws/graph-notebook/pull/180)) +- Made Neptune specific parameters optional for `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/181)) ## Release 3.0.3 (August 11, 2021) diff --git a/README.md b/README.md index 68cf5e3c..ac52136c 100644 --- a/README.md +++ b/README.md @@ -132,10 +132,7 @@ In a new cell in the Jupyter notebook, change the configuration using `%%graph_n { "host": "localhost", "port": 8182, - "auth_mode": "DEFAULT", - "load_from_s3_arn": "", - "ssl": false, - "aws_region": "us-east-1" + "ssl": false } ``` @@ -150,10 +147,7 @@ Change the configuration using `%%graph_notebook_config` and modify the fields f { "host": "localhost", "port": 9999, - "auth_mode": "DEFAULT", - "load_from_s3_arn": "", - "ssl": false, - "aws_region": "us-east-1" + "ssl": false } ``` @@ -165,10 +159,7 @@ You can also make use of namespaces for Blazegraph by specifying the path `graph { "host": "localhost", "port": 9999, - "auth_mode": "DEFAULT", - "load_from_s3_arn": "", "ssl": false, - "aws_region": "us-west-2", "sparql": { "path": "blazegraph/namespace/foo/sparql" } diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index f57aaa11..d4277b34 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -51,22 +51,34 @@ def __init__(self, host: str, port: int, sparql_section: SparqlSection = None): self.host = host self.port = port - self.auth_mode = auth_mode - self.load_from_s3_arn = load_from_s3_arn self.ssl = ssl - self.aws_region = aws_region self.sparql = sparql_section if sparql_section is not None else SparqlSection() + if ".neptune.amazonaws.com" in self.host: + self.is_neptune_config = True + self.auth_mode = auth_mode + self.load_from_s3_arn = load_from_s3_arn + self.aws_region = aws_region + else: + self.is_neptune_config = False def to_dict(self) -> dict: - return { - 'host': self.host, - 'port': self.port, - 'auth_mode': self.auth_mode.value, - 'load_from_s3_arn': self.load_from_s3_arn, - 'ssl': self.ssl, - 'aws_region': self.aws_region, - 'sparql': self.sparql.to_dict() - } + if self.is_neptune_config: + return { + 'host': self.host, + 'port': self.port, + 'auth_mode': self.auth_mode.value, + 'load_from_s3_arn': self.load_from_s3_arn, + 'ssl': self.ssl, + 'aws_region': self.aws_region, + 'sparql': self.sparql.to_dict() + } + else: + return { + 'host': self.host, + 'port': self.port, + 'ssl': self.ssl, + 'sparql': self.sparql.to_dict() + } def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): data = self.to_dict() @@ -76,7 +88,8 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): return -def generate_config(host, port, auth_mode, ssl, load_from_s3_arn, aws_region): +def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='', + aws_region: str = 'us-east-1'): use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region) return c @@ -93,10 +106,6 @@ def generate_default_config(): parser.add_argument("--port", help="the port to use when creating a connection", default="8182") parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value, help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM") - - # TODO: this can now be removed. - parser.add_argument("--iam_credentials_provider", default='ROLE', - help="The mode used to obtain credentials for IAM Authentication. Can be ROLE or ENV") parser.add_argument("--ssl", help="whether to make connections to the created endpoint with ssl or not [True|False]", default=True) @@ -107,7 +116,7 @@ def generate_default_config(): args = parser.parse_args() auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value - config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl , + config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, args.load_from_s3_arn, args.aws_region) config.write_to_file(args.config_destination) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index 72ab829a..d559deee 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -11,9 +11,12 @@ def get_config_from_dict(data: dict) -> Configuration: sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('') - config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']), - ssl=data['ssl'], - load_from_s3_arn=data['load_from_s3_arn'], aws_region=data['aws_region'], sparql_section=sparql_section) + if ".neptune.amazonaws.com" in data['host']: + config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']), + ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'], + aws_region=data['aws_region'], sparql_section=sparql_section) + else: + config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section) return config diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index eea9debf..93b16b41 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -26,8 +26,8 @@ from requests import HTTPError import graph_notebook -from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, AuthModeEnum, \ - Configuration +from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, \ + AuthModeEnum, Configuration from graph_notebook.decorators.decorators import display_exceptions, magic_variables from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \ @@ -151,14 +151,21 @@ def _generate_client_from_config(self, config: Configuration): if self.client: self.client.close() - builder = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_region(config.aws_region) \ - .with_tls(config.ssl) \ - .with_sparql_path(config.sparql.path) - if config.auth_mode == AuthModeEnum.IAM: - builder = builder.with_iam(get_session()) + if ".neptune.amazonaws.com" in config.host: + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) + if config.auth_mode == AuthModeEnum.IAM: + builder = builder.with_iam(get_session()) + else: + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) self.client = builder.build() diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index 0bbb4d50..d6d8724a 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -13,7 +13,8 @@ class TestGenerateConfiguration(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.host = 'blah' + cls.generic_host = 'blah' + cls.neptune_host = 'instance.cluster.us-west-2.neptune.amazonaws.com' cls.port = 8182 cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_configuration_file.json' @@ -21,40 +22,67 @@ def tearDown(self) -> None: if os.path.exists(self.test_file_path): os.remove(self.test_file_path) - def test_configuration_default_auth_defaults(self): - config = Configuration(self.host, self.port) - self.assertEqual(self.host, config.host) + def test_configuration_default_auth_defaults_neptune(self): + config = Configuration(self.neptune_host, self.port) + self.assertEqual(self.neptune_host, config.host) self.assertEqual(self.port, config.port) self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode) self.assertEqual(True, config.ssl) self.assertEqual('', config.load_from_s3_arn) - def test_configuration_override_defaults(self): + def test_configuration_default_auth_defaults_generic(self): + config = Configuration(self.generic_host, self.port) + self.assertEqual(self.generic_host, config.host) + self.assertEqual(self.port, config.port) + self.assertEqual(True, config.ssl) + + def test_configuration_override_defaults_neptune(self): auth_mode = AuthModeEnum.IAM ssl = False loader_arn = 'foo' - config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl) + config = Configuration(self.neptune_host, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl) self.assertEqual(auth_mode, config.auth_mode) self.assertEqual(ssl, config.ssl) self.assertEqual(loader_arn, config.load_from_s3_arn) - def test_generate_configuration_with_defaults(self): - config = Configuration(self.host, self.port) - c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.load_from_s3_arn, config.aws_region) + def test_configuration_override_defaults_generic(self): + ssl = False + config = Configuration(self.generic_host, self.port, ssl=ssl) + self.assertEqual(ssl, config.ssl) + + def test_generate_configuration_with_defaults_neptune(self): + config = Configuration(self.neptune_host, self.port) + c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl, + load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) self.assertEqual(config.to_dict(), config_from_file.to_dict()) - def test_generate_configuration_override_defaults(self): + def test_generate_configuration_with_defaults_generic(self): + config = Configuration(self.generic_host, self.port) + c = generate_config(config.host, config.port, ssl=config.ssl) + c.write_to_file(self.test_file_path) + config_from_file = get_config(self.test_file_path) + self.assertEqual(config.to_dict(), config_from_file.to_dict()) + + def test_generate_configuration_override_defaults_neptune(self): auth_mode = AuthModeEnum.IAM ssl = False loader_arn = 'foo' aws_region = 'us-west-2' - config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl, aws_region) + config = Configuration(self.neptune_host, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl, + aws_region=aws_region) - c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.load_from_s3_arn, config.aws_region) + c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl, + load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region) + c.write_to_file(self.test_file_path) + config_from_file = get_config(self.test_file_path) + self.assertEqual(config.to_dict(), config_from_file.to_dict()) + + def test_generate_configuration_override_defaults_generic(self): + ssl = False + config = Configuration(self.generic_host, self.port, ssl=ssl) + c = generate_config(config.host, config.port, ssl=config.ssl) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) self.assertEqual(config.to_dict(), config_from_file.to_dict()) diff --git a/test/unit/configuration/test_configuration_from_main.py b/test/unit/configuration/test_configuration_from_main.py index ebfd604c..41406dd0 100644 --- a/test/unit/configuration/test_configuration_from_main.py +++ b/test/unit/configuration/test_configuration_from_main.py @@ -13,35 +13,67 @@ class TestGenerateConfigurationMain(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.host = 'blah' + cls.generic_host = 'blah' + cls.neptune_host = 'instance.cluster.us-west-2.neptune.amazonaws.com' cls.port = 8182 cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_generate_from_main.json' - cls.python_cmd = os.environ.get('PYTHON_CMD', 'python3') # environment variable to let ToD hosts specify where the python command is that is being used for testing. + cls.python_cmd = os.environ.get('PYTHON_CMD', 'python3') # environment variable to let ToD hosts specify + # where the python command is that is being used for testing. def tearDown(self) -> None: if os.path.exists(self.test_file_path): os.remove(self.test_file_path) - def test_generate_configuration_main_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, '', True) + def test_generate_configuration_main_defaults_neptune(self): + expected_config = Configuration(self.neptune_host, self.port, auth_mode=AuthModeEnum.DEFAULT, + load_from_s3_arn='', ssl=True) + self.generate_config_from_main_and_test(expected_config, host_type='neptune') + + def test_generate_configuration_main_defaults_generic(self): + expected_config = Configuration(self.generic_host, self.port, ssl=True) self.generate_config_from_main_and_test(expected_config) - def test_generate_configuration_main_override_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, 'loader_arn', False) + def test_generate_configuration_main_override_defaults_neptune(self): + expected_config = Configuration(self.neptune_host, self.port, auth_mode=AuthModeEnum.IAM, + load_from_s3_arn='loader_arn', ssl=False) + self.generate_config_from_main_and_test(expected_config, host_type='neptune') + + def test_generate_configuration_main_override_defaults_generic(self): + expected_config = Configuration(self.generic_host, self.port, ssl=False) self.generate_config_from_main_and_test(expected_config) - def test_generate_configuration_main_empty_args(self): - expected_config = Configuration(self.host, self.port) - result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{expected_config.host}" --port "{expected_config.port}" --auth_mode "" --ssl "" --iam_credentials_provider "" --load_from_s3_arn "" --config_destination="{self.test_file_path}" ') + def test_generate_configuration_main_empty_args_neptune(self): + expected_config = Configuration(self.neptune_host, self.port) + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{expected_config.host}" --port "{expected_config.port}" --auth_mode "" --ssl "" ' + f'--load_from_s3_arn "" --config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + self.assertEqual(expected_config.to_dict(), config.to_dict()) + + def test_generate_configuration_main_empty_args_generic(self): + expected_config = Configuration(self.generic_host, self.port) + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{expected_config.host}" --port "{expected_config.port}" --ssl "" ' + f'--config_destination="{self.test_file_path}" ') self.assertEqual(0, result) config = get_config(self.test_file_path) self.assertEqual(expected_config.to_dict(), config.to_dict()) - def generate_config_from_main_and_test(self, source_config: Configuration): + def generate_config_from_main_and_test(self, source_config: Configuration, host_type=None): # This will run the main method that our install script runs on a Sagemaker notebook. # The return code should be 0, but more importantly, we need to assert that the # Configuration object we get from the resulting file is what we expect. - result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ') + if host_type == 'neptune': + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{source_config.host}" --port "{source_config.port}" ' + f'--auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" ' + f'--load_from_s3_arn "{source_config.load_from_s3_arn}" ' + f'--config_destination="{self.test_file_path}" ') + else: + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{source_config.host}" --port "{source_config.port}" ' + f'--ssl "{source_config.ssl}" --config_destination="{self.test_file_path}" ') self.assertEqual(result, 0) config = get_config(self.test_file_path) self.assertEqual(source_config.to_dict(), config.to_dict())