Skip to content

Make Neptune specific parameters optional when setting %%graph_notebook_config #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Starting with v1.31.6, this file will contain a record of major features and upd
## Upcoming

- Disabled SigV4 signing for non-IAM AWS requests ([Link to PR](https://github.com/aws/graph-notebook/pull/179))
- Added new `--nopoll` option to `%load` to disable status polling ([Link to PR](https://github.com/aws/graph-notebook/pull/180))
- Made Neptune specific parameters optional for `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/181))

## Release 3.0.3 (August 11, 2021)

Expand Down
13 changes: 2 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ In a new cell in the Jupyter notebook, change the configuration using `%%graph_n
{
"host": "localhost",
"port": 8182,
"auth_mode": "DEFAULT",
"load_from_s3_arn": "",
"ssl": false,
"aws_region": "us-east-1"
"ssl": false
}
```

Expand All @@ -150,10 +147,7 @@ Change the configuration using `%%graph_notebook_config` and modify the fields f
{
"host": "localhost",
"port": 9999,
"auth_mode": "DEFAULT",
"load_from_s3_arn": "",
"ssl": false,
"aws_region": "us-east-1"
"ssl": false
}
```

Expand All @@ -165,10 +159,7 @@ You can also make use of namespaces for Blazegraph by specifying the path `graph
{
"host": "localhost",
"port": 9999,
"auth_mode": "DEFAULT",
"load_from_s3_arn": "",
"ssl": false,
"aws_region": "us-west-2",
"sparql": {
"path": "blazegraph/namespace/foo/sparql"
}
Expand Down
45 changes: 27 additions & 18 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,34 @@ def __init__(self, host: str, port: int,
sparql_section: SparqlSection = None):
self.host = host
self.port = port
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.ssl = ssl
self.aws_region = aws_region
self.sparql = sparql_section if sparql_section is not None else SparqlSection()
if ".neptune.amazonaws.com" in self.host:
self.is_neptune_config = True
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
else:
self.is_neptune_config = False

def to_dict(self) -> dict:
return {
'host': self.host,
'port': self.port,
'auth_mode': self.auth_mode.value,
'load_from_s3_arn': self.load_from_s3_arn,
'ssl': self.ssl,
'aws_region': self.aws_region,
'sparql': self.sparql.to_dict()
}
if self.is_neptune_config:
return {
'host': self.host,
'port': self.port,
'auth_mode': self.auth_mode.value,
'load_from_s3_arn': self.load_from_s3_arn,
'ssl': self.ssl,
'aws_region': self.aws_region,
'sparql': self.sparql.to_dict()
}
else:
return {
'host': self.host,
'port': self.port,
'ssl': self.ssl,
'sparql': self.sparql.to_dict()
}

def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
data = self.to_dict()
Expand All @@ -76,7 +88,8 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
return


def generate_config(host, port, auth_mode, ssl, load_from_s3_arn, aws_region):
def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='',
aws_region: str = 'us-east-1'):
use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region)
return c
Expand All @@ -93,10 +106,6 @@ def generate_default_config():
parser.add_argument("--port", help="the port to use when creating a connection", default="8182")
parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value,
help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM")

# TODO: this can now be removed.
parser.add_argument("--iam_credentials_provider", default='ROLE',
help="The mode used to obtain credentials for IAM Authentication. Can be ROLE or ENV")
parser.add_argument("--ssl",
help="whether to make connections to the created endpoint with ssl or not [True|False]",
default=True)
Expand All @@ -107,7 +116,7 @@ def generate_default_config():
args = parser.parse_args()

auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl ,
config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
args.load_from_s3_arn, args.aws_region)
config.write_to_file(args.config_destination)

Expand Down
9 changes: 6 additions & 3 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@

def get_config_from_dict(data: dict) -> Configuration:
sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
ssl=data['ssl'],
load_from_s3_arn=data['load_from_s3_arn'], aws_region=data['aws_region'], sparql_section=sparql_section)
if ".neptune.amazonaws.com" in data['host']:
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
aws_region=data['aws_region'], sparql_section=sparql_section)
else:
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section)
return config


Expand Down
27 changes: 17 additions & 10 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from requests import HTTPError

import graph_notebook
from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, AuthModeEnum, \
Configuration
from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, \
AuthModeEnum, Configuration
from graph_notebook.decorators.decorators import display_exceptions, magic_variables
from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
Expand Down Expand Up @@ -151,14 +151,21 @@ def _generate_client_from_config(self, config: Configuration):
if self.client:
self.client.close()

builder = ClientBuilder() \
.with_host(config.host) \
.with_port(config.port) \
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
if ".neptune.amazonaws.com" in config.host:
builder = ClientBuilder() \
.with_host(config.host) \
.with_port(config.port) \
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
else:
builder = ClientBuilder() \
.with_host(config.host) \
.with_port(config.port) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)

self.client = builder.build()

Expand Down
56 changes: 42 additions & 14 deletions test/unit/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,76 @@
class TestGenerateConfiguration(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.host = 'blah'
cls.generic_host = 'blah'
cls.neptune_host = 'instance.cluster.us-west-2.neptune.amazonaws.com'
cls.port = 8182
cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_configuration_file.json'

def tearDown(self) -> None:
if os.path.exists(self.test_file_path):
os.remove(self.test_file_path)

def test_configuration_default_auth_defaults(self):
config = Configuration(self.host, self.port)
self.assertEqual(self.host, config.host)
def test_configuration_default_auth_defaults_neptune(self):
config = Configuration(self.neptune_host, self.port)
self.assertEqual(self.neptune_host, config.host)
self.assertEqual(self.port, config.port)
self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode)
self.assertEqual(True, config.ssl)
self.assertEqual('', config.load_from_s3_arn)

def test_configuration_override_defaults(self):
def test_configuration_default_auth_defaults_generic(self):
config = Configuration(self.generic_host, self.port)
self.assertEqual(self.generic_host, config.host)
self.assertEqual(self.port, config.port)
self.assertEqual(True, config.ssl)

def test_configuration_override_defaults_neptune(self):
auth_mode = AuthModeEnum.IAM
ssl = False
loader_arn = 'foo'
config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl)
config = Configuration(self.neptune_host, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl)
self.assertEqual(auth_mode, config.auth_mode)
self.assertEqual(ssl, config.ssl)
self.assertEqual(loader_arn, config.load_from_s3_arn)

def test_generate_configuration_with_defaults(self):
config = Configuration(self.host, self.port)
c = generate_config(config.host, config.port, config.auth_mode, config.ssl,
config.load_from_s3_arn, config.aws_region)
def test_configuration_override_defaults_generic(self):
ssl = False
config = Configuration(self.generic_host, self.port, ssl=ssl)
self.assertEqual(ssl, config.ssl)

def test_generate_configuration_with_defaults_neptune(self):
config = Configuration(self.neptune_host, self.port)
c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl,
load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region)
c.write_to_file(self.test_file_path)
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())

def test_generate_configuration_override_defaults(self):
def test_generate_configuration_with_defaults_generic(self):
config = Configuration(self.generic_host, self.port)
c = generate_config(config.host, config.port, ssl=config.ssl)
c.write_to_file(self.test_file_path)
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())

def test_generate_configuration_override_defaults_neptune(self):
auth_mode = AuthModeEnum.IAM
ssl = False
loader_arn = 'foo'
aws_region = 'us-west-2'
config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl, aws_region)
config = Configuration(self.neptune_host, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl,
aws_region=aws_region)

c = generate_config(config.host, config.port, config.auth_mode, config.ssl,
config.load_from_s3_arn, config.aws_region)
c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl,
load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region)
c.write_to_file(self.test_file_path)
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())

def test_generate_configuration_override_defaults_generic(self):
ssl = False
config = Configuration(self.generic_host, self.port, ssl=ssl)
c = generate_config(config.host, config.port, ssl=config.ssl)
c.write_to_file(self.test_file_path)
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())
54 changes: 43 additions & 11 deletions test/unit/configuration/test_configuration_from_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,67 @@
class TestGenerateConfigurationMain(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.host = 'blah'
cls.generic_host = 'blah'
cls.neptune_host = 'instance.cluster.us-west-2.neptune.amazonaws.com'
cls.port = 8182
cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_generate_from_main.json'
cls.python_cmd = os.environ.get('PYTHON_CMD', 'python3') # environment variable to let ToD hosts specify where the python command is that is being used for testing.
cls.python_cmd = os.environ.get('PYTHON_CMD', 'python3') # environment variable to let ToD hosts specify
# where the python command is that is being used for testing.

def tearDown(self) -> None:
if os.path.exists(self.test_file_path):
os.remove(self.test_file_path)

def test_generate_configuration_main_defaults(self):
expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, '', True)
def test_generate_configuration_main_defaults_neptune(self):
expected_config = Configuration(self.neptune_host, self.port, auth_mode=AuthModeEnum.DEFAULT,
load_from_s3_arn='', ssl=True)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_defaults_generic(self):
expected_config = Configuration(self.generic_host, self.port, ssl=True)
self.generate_config_from_main_and_test(expected_config)

def test_generate_configuration_main_override_defaults(self):
expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, 'loader_arn', False)
def test_generate_configuration_main_override_defaults_neptune(self):
expected_config = Configuration(self.neptune_host, self.port, auth_mode=AuthModeEnum.IAM,
load_from_s3_arn='loader_arn', ssl=False)
self.generate_config_from_main_and_test(expected_config, host_type='neptune')

def test_generate_configuration_main_override_defaults_generic(self):
expected_config = Configuration(self.generic_host, self.port, ssl=False)
self.generate_config_from_main_and_test(expected_config)

def test_generate_configuration_main_empty_args(self):
expected_config = Configuration(self.host, self.port)
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{expected_config.host}" --port "{expected_config.port}" --auth_mode "" --ssl "" --iam_credentials_provider "" --load_from_s3_arn "" --config_destination="{self.test_file_path}" ')
def test_generate_configuration_main_empty_args_neptune(self):
expected_config = Configuration(self.neptune_host, self.port)
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{expected_config.host}" --port "{expected_config.port}" --auth_mode "" --ssl "" '
f'--load_from_s3_arn "" --config_destination="{self.test_file_path}" ')
self.assertEqual(0, result)
config = get_config(self.test_file_path)
self.assertEqual(expected_config.to_dict(), config.to_dict())

def test_generate_configuration_main_empty_args_generic(self):
expected_config = Configuration(self.generic_host, self.port)
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{expected_config.host}" --port "{expected_config.port}" --ssl "" '
f'--config_destination="{self.test_file_path}" ')
self.assertEqual(0, result)
config = get_config(self.test_file_path)
self.assertEqual(expected_config.to_dict(), config.to_dict())

def generate_config_from_main_and_test(self, source_config: Configuration):
def generate_config_from_main_and_test(self, source_config: Configuration, host_type=None):
# This will run the main method that our install script runs on a Sagemaker notebook.
# The return code should be 0, but more importantly, we need to assert that the
# Configuration object we get from the resulting file is what we expect.
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ')
if host_type == 'neptune':
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{source_config.host}" --port "{source_config.port}" '
f'--auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" '
f'--load_from_s3_arn "{source_config.load_from_s3_arn}" '
f'--config_destination="{self.test_file_path}" ')
else:
result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '
f'--host "{source_config.host}" --port "{source_config.port}" '
f'--ssl "{source_config.ssl}" --config_destination="{self.test_file_path}" ')
self.assertEqual(result, 0)
config = get_config(self.test_file_path)
self.assertEqual(source_config.to_dict(), config.to_dict())