diff --git a/ChangeLog.md b/ChangeLog.md index 150f2957..0b2a47df 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -4,6 +4,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Added `%statistics` magic for Neptune DFE engine ([Link to PR](https://github.com/aws/graph-notebook/pull/377)) +- Added custom Gremlin authentication and serializer support ([Link to PR](https://github.com/aws/graph-notebook/pull/356)) ## Release 3.6.2 (October 18, 2022) - New Sample Applications - Security Graphs notebooks ([Link to PR](https://github.com/aws/graph-notebook/pull/373)) diff --git a/README.md b/README.md index ddb07bac..4af1b3a4 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ Alternatively, the magic extensions can be manually reloaded for a single notebo ### Gremlin Server -In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value. For a local Gremlin server (HTTP or WebSockets), you can use the following command: +In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value, `username` and `password` if required by the graph store, or `message_serializer` for a specific data transfer format. For a local Gremlin server (HTTP or WebSockets), you can use the following command: ``` %%graph_notebook_config @@ -174,7 +174,10 @@ In a new cell in the Jupyter notebook, change the configuration using `%%graph_n "port": 8182, "ssl": false, "gremlin": { - "traversal_source": "g" + "traversal_source": "g", + "username": "", + "password": "", + "message_serializer": "graphsonv3" } } ``` diff --git a/additional-databases/gremlin-server/README.md b/additional-databases/gremlin-server/README.md index 2ead52ed..607790e8 100644 --- a/additional-databases/gremlin-server/README.md +++ b/additional-databases/gremlin-server/README.md @@ -36,7 +36,13 @@ Several of the steps below are optional but please read each step carefully and { "host": "localhost", "port": 8182, - "ssl": false + "ssl": false, + "gremlin": { + "traversal_source": "g", + "username": "", + "password": "", + "message_serializer": "graphsonv3" + } } ``` If the Gremlin Server you wish to connect to is remote, replacing `localhost` with the IP address or DNS of the remote server should work. This assumes you have access to that server from your local machine. diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index a660d51c..1e3eb940 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -8,8 +8,9 @@ import os from enum import Enum -from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, \ - NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host +from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, DEFAULT_GREMLIN_SERIALIZER, \ + DEFAULT_GREMLIN_TRAVERSAL_SOURCE, NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, GRAPHSONV3_VARIANTS, \ + GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') @@ -49,16 +50,38 @@ class GremlinSection(object): Used for gremlin-specific settings in a notebook's configuration """ - def __init__(self, traversal_source: str = ''): + def __init__(self, traversal_source: str = '', username: str = '', password: str = '', + message_serializer: str = ''): """ :param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are connected to an endpoint that can access multiple graphs. + :param username: used to specify a username for authenticating to Gremlin Server, if the endpoint supports it. + :param password: used to specify a password for authenticating to Gremlin Server, if the endpoint supports it. + :param message_serializer: used to specify a serializer for encoding the data to and from Gremlin Server. """ if traversal_source == '': - traversal_source = 'g' + traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE + + serializer_lower = message_serializer.lower() + if serializer_lower == '': + message_serializer = DEFAULT_GREMLIN_SERIALIZER + elif serializer_lower in GRAPHSONV3_VARIANTS: + message_serializer = 'graphsonv3' + elif serializer_lower in GRAPHSONV2_VARIANTS: + message_serializer = 'graphsonv2' + elif serializer_lower in GRAPHBINARYV1_VARIANTS: + message_serializer = 'graphbinaryv1' + else: + print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. ' + f'Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1].') + message_serializer = DEFAULT_GREMLIN_SERIALIZER + self.traversal_source = traversal_source + self.username = username + self.password = password + self.message_serializer = message_serializer def to_dict(self): return self.__dict__ @@ -141,10 +164,11 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='', aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT, + sparql_section: SparqlSection = SparqlSection(), gremlin_section: GremlinSection = GremlinSection(), neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS): use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port, - neptune_hosts=neptune_hosts) + sparql_section, gremlin_section, neptune_hosts) return c @@ -171,6 +195,15 @@ def generate_default_config(): parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION) parser.add_argument("--proxy_host", help="the proxy host url to route a connection through", default='') parser.add_argument("--proxy_port", help="the proxy port to use when creating proxy connection", default=8182) + parser.add_argument("--sparql_path", help="the namespace path to append to the SPARQL endpoint", + default=SPARQL_ACTION) + parser.add_argument("--gremlin_traversal_source", help="the traversal source to use for Gremlin queries", + default=DEFAULT_GREMLIN_TRAVERSAL_SOURCE) + parser.add_argument("--gremlin_username", help="the username to use when creating Gremlin connections", default='') + parser.add_argument("--gremlin_password", help="the password to use when creating Gremlin connections", default='') + parser.add_argument("--gremlin_serializer", + help="the serializer to use as the encoding format when creating Gremlin connections", + default=DEFAULT_GREMLIN_SERIALIZER) parser.add_argument("--neptune_hosts", help="list of host snippets to use for identifying neptune endpoints", default=DEFAULT_CONFIG_LOCATION) args = parser.parse_args() @@ -178,7 +211,10 @@ def generate_default_config(): auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port), - neptune_hosts=args.neptune_hosts) + SparqlSection(args.sparql_path, ''), + GremlinSection(args.gremlin_traversal_source, args.gremlin_username, + args.gremlin_serializer), + args.neptune_hosts) config.write_to_file(args.config_destination) exit(0) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index dfffb6f9..618bbde4 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -13,7 +13,7 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration: sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('') - gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('') + gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection() proxy_host = str(data['proxy_host']) if 'proxy_host' in data else '' proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182 diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 17de9330..fb105d90 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -277,7 +277,8 @@ def _generate_client_from_config(self, config: Configuration): .with_tls(config.ssl) \ .with_proxy_host(config.proxy_host) \ .with_proxy_port(config.proxy_port) \ - .with_sparql_path(config.sparql.path) + .with_sparql_path(config.sparql.path) \ + .with_gremlin_serializer(config.gremlin.message_serializer) if config.auth_mode == AuthModeEnum.IAM: builder = builder.with_iam(get_session()) if self.neptune_cfg_allowlist != NEPTUNE_CONFIG_HOST_IDENTIFIERS: @@ -288,7 +289,9 @@ def _generate_client_from_config(self, config: Configuration): .with_port(config.port) \ .with_tls(config.ssl) \ .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_traversal_source(config.gremlin.traversal_source) \ + .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ + .with_gremlin_serializer(config.gremlin.message_serializer) self.client = builder.build() diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py index 29ff0310..7db5fa36 100644 --- a/src/graph_notebook/neptune/client.py +++ b/src/graph_notebook/neptune/client.py @@ -14,16 +14,17 @@ from botocore.session import Session as botocoreSession from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest -from gremlin_python.driver import client +from gremlin_python.driver import client, serializer from gremlin_python.driver.protocol import GremlinServerError from neo4j import GraphDatabase import nest_asyncio -# from graph_notebook.magics.graph_magic import NEPTUNE_CONFIG_HOST_IDENTIFIERS # This patch is no longer needed when graph_notebook is using the a Gremlin Python # client >= 3.5.0 as the HashableDict is now part of that client driver. # import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 +DEFAULT_GREMLIN_SERIALIZER = 'graphsonv3' +DEFAULT_GREMLIN_TRAVERSAL_SOURCE = 'g' DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded' DEFAULT_PORT = 8182 DEFAULT_REGION = 'us-east-1' @@ -94,9 +95,14 @@ NEPTUNE_CONFIG_HOST_IDENTIFIERS = ["amazonaws.com"] +GRAPHSONV3_VARIANTS = ['graphsonv3', 'graphsonv3d0', 'graphsonserializersv3d0'] +GRAPHSONV2_VARIANTS = ['graphsonv2', 'graphsonv2d0', 'graphsonserializersv2d0'] +GRAPHBINARYV1_VARIANTS = ['graphbinaryv1', 'graphbinary', 'graphbinaryserializersv1'] + STATISTICS_MODES = ["status", "disableAutoCompute", "enableAutoCompute", "refresh", "delete"] STATISTICS_LANGUAGE_INPUTS = ["propertygraph", "pg", "gremlin", "sparql", "rdf"] + def is_allowed_neptune_host(hostname: str, host_allowlist: list): for host_snippet in host_allowlist: if host_snippet in hostname: @@ -104,9 +110,22 @@ def is_allowed_neptune_host(hostname: str, host_allowlist: list): return False +def get_gremlin_serializer(serializer_str: str): + serializer_lower = serializer_str.lower() + if serializer_lower == 'graphbinaryv1': + return serializer.GraphBinarySerializersV1() + elif serializer_lower == 'graphsonv2': + return serializer.GraphSONSerializersV2d0() + else: + return serializer.GraphSONSerializersV3d0() + + class Client(object): def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION, - sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None, + sparql_path: str = '/sparql', gremlin_traversal_source: str = DEFAULT_GREMLIN_TRAVERSAL_SOURCE, + gremlin_username: str = '', gremlin_password: str = '', + gremlin_serializer: str = DEFAULT_GREMLIN_SERIALIZER, + auth=None, session: Session = None, proxy_host: str = '', proxy_port: int = DEFAULT_PORT, neptune_hosts: list = None): self.target_host = host @@ -114,6 +133,9 @@ def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region self.ssl = ssl self.sparql_path = sparql_path self.gremlin_traversal_source = gremlin_traversal_source + self.gremlin_username = gremlin_username + self.gremlin_password = gremlin_password + self.gremlin_serializer = get_gremlin_serializer(gremlin_serializer) self.region = region self._auth = auth self._session = session @@ -225,7 +247,9 @@ def get_gremlin_connection(self, transport_kwargs) -> client.Client: ws_url = f'{self.get_uri_with_port(use_websocket=True)}/gremlin' request = self._prepare_request('GET', ws_url) traversal_source = 'g' if self.is_neptune_domain() else self.gremlin_traversal_source - return client.Client(ws_url, traversal_source, headers=dict(request.headers), **transport_kwargs) + return client.Client(ws_url, traversal_source, username=self.gremlin_username, + password=self.gremlin_password, message_serializer=self.gremlin_serializer, + headers=dict(request.headers), **transport_kwargs) def gremlin_query(self, query, transport_args=None, bindings=None): if transport_args is None: @@ -756,6 +780,15 @@ def with_gremlin_traversal_source(self, traversal_source: str): self.args['gremlin_traversal_source'] = traversal_source return ClientBuilder(self.args) + def with_gremlin_login(self, username: str, password: str): + self.args['gremlin_username'] = username + self.args['gremlin_password'] = password + return ClientBuilder(self.args) + + def with_gremlin_serializer(self, message_serializer: str): + self.args['gremlin_serializer'] = message_serializer + return ClientBuilder(self.args) + def with_tls(self, tls: bool): self.args['ssl'] = tls return ClientBuilder(self.args) diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py index e6ffe490..5134a023 100644 --- a/test/integration/IntegrationTest.py +++ b/test/integration/IntegrationTest.py @@ -23,7 +23,8 @@ def setup_client_builder(config: Configuration) -> ClientBuilder: .with_proxy_host(config.proxy_host) \ .with_proxy_port(config.proxy_port) \ .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_traversal_source(config.gremlin.traversal_source) \ + .with_gremlin_serializer(config.gremlin.message_serializer) if config.auth_mode == AuthModeEnum.IAM: builder = builder.with_iam(get_session()) else: @@ -34,7 +35,9 @@ def setup_client_builder(config: Configuration) -> ClientBuilder: .with_proxy_host(config.proxy_host) \ .with_proxy_port(config.proxy_port) \ .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_traversal_source(config.gremlin.traversal_source) \ + .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ + .with_gremlin_serializer(config.gremlin.message_serializer) return builder diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py index 06dc9cd8..d842b83e 100644 --- a/test/integration/iam/ml/__init__.py +++ b/test/integration/iam/ml/__init__.py @@ -19,6 +19,8 @@ def setup_iam_client(config: Configuration) -> Client: .with_proxy_port(config.proxy_port) \ .with_sparql_path(config.sparql.path) \ .with_gremlin_traversal_source(config.gremlin.traversal_source) \ + .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ + .with_gremlin_serializer(config.gremlin.message_serializer) \ .with_iam(get_session()) \ .build() @@ -29,5 +31,8 @@ def setup_iam_client(config: Configuration) -> Client: assert client.proxy_port == config.proxy_port assert client.sparql_path == config.sparql.path assert client.gremlin_traversal_source == config.gremlin.traversal_source + assert client.gremlin_username == config.gremlin.username + assert client.gremlin_password == config.gremlin.password + assert client.gremlin_serializer == config.gremlin.message_serializer assert client.ssl is config.ssl return client