Skip to content

Make Gremlin traversal source configurable #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.

## Upcoming
- Added support for specifying the Gremlin traversal source ([Link to PR](https://github.com/aws/graph-notebook/pull/221))
- Added edge tooltips, and options for specifying edge label length ([Link to PR](https://github.com/aws/graph-notebook/pull/218))

## Release 3.0.7 (October 25, 2021)
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,17 @@ python -m graph_notebook.start_notebook --notebooks-dir ~/notebook/destination/d

### Gremlin Server

In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. For a local Gremlin server (HTTP or WebSockets), you can use the following command:
In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value. For a local Gremlin server (HTTP or WebSockets), you can use the following command:

```
%%graph_notebook_config
{
"host": "localhost",
"port": 8182,
"ssl": false
"ssl": false,
"gremlin": {
"traversal_source": "g"
}
}
```

Expand Down
30 changes: 27 additions & 3 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,31 @@ def to_dict(self):
return self.__dict__


class GremlinSection(object):
"""
Used for gremlin-specific settings in a notebook's configuration
"""

def __init__(self, traversal_source: str = ''):
"""
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
connected to an endpoint that can access multiple graphs.
"""

if traversal_source == '':
traversal_source = 'g'

self.traversal_source = traversal_source

def to_dict(self):
return self.__dict__


class Configuration(object):
def __init__(self, host: str, port: int,
auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT,
load_from_s3_arn='', ssl: bool = True, aws_region: str = 'us-east-1',
sparql_section: SparqlSection = None):
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None):
self.host = host
self.port = port
self.ssl = ssl
Expand All @@ -56,8 +76,10 @@ def __init__(self, host: str, port: int,
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
self.gremlin = GremlinSection()
else:
self.is_neptune_config = False
self.gremlin = gremlin_section if gremlin_section is not None else GremlinSection()

def to_dict(self) -> dict:
if self.is_neptune_config:
Expand All @@ -68,14 +90,16 @@ def to_dict(self) -> dict:
'load_from_s3_arn': self.load_from_s3_arn,
'ssl': self.ssl,
'aws_region': self.aws_region,
'sparql': self.sparql.to_dict()
'sparql': self.sparql.to_dict(),
'gremlin': self.gremlin.to_dict()
}
else:
return {
'host': self.host,
'port': self.port,
'ssl': self.ssl,
'sparql': self.sparql.to_dict()
'sparql': self.sparql.to_dict(),
'gremlin': self.gremlin.to_dict()
}

def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
Expand Down
11 changes: 8 additions & 3 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
import json

from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
SparqlSection
SparqlSection, GremlinSection


def get_config_from_dict(data: dict) -> Configuration:
sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('')
if "amazonaws.com" in data['host']:
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
aws_region=data['aws_region'], sparql_section=sparql_section)
aws_region=data['aws_region'], sparql_section=sparql_section,
gremlin_section=gremlin_section)
else:
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section)
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section,
gremlin_section=gremlin_section)
return config


Expand Down
3 changes: 2 additions & 1 deletion src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def _generate_client_from_config(self, config: Configuration):
.with_host(config.host) \
.with_port(config.port) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)

self.client = builder.build()

Expand Down
17 changes: 15 additions & 2 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from gremlin_python.driver import client
from gremlin_python.driver.protocol import GremlinServerError
from neo4j import GraphDatabase
import nest_asyncio

Expand Down Expand Up @@ -79,11 +80,12 @@

class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION,
sparql_path: str = '/sparql', auth=None, session: Session = None):
sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None):
self.host = host
self.port = port
self.ssl = ssl
self.sparql_path = sparql_path
self.gremlin_traversal_source = gremlin_traversal_source
self.region = region
self._auth = auth
self._session = session
Expand Down Expand Up @@ -174,7 +176,9 @@ def get_gremlin_connection(self) -> client.Client:
request = self._prepare_request('GET', uri)

ws_url = f'{self._ws_protocol}://{self.host}:{self.port}/gremlin'
return client.Client(ws_url, 'g', headers=dict(request.headers))

traversal_source = 'g' if "neptune.amazonaws.com" in self.host else self.gremlin_traversal_source
return client.Client(ws_url, traversal_source, headers=dict(request.headers))

def gremlin_query(self, query, bindings=None):
c = self.get_gremlin_connection()
Expand All @@ -185,6 +189,11 @@ def gremlin_query(self, query, bindings=None):
c.close()
return results
except Exception as e:
if isinstance(e, GremlinServerError):
if e.status_code == 499:
print("Error returned by the Gremlin Server for the traversal_source specified in notebook "
"configuration. Please ensure that your graph database endpoint supports re-naming of "
"GraphTraversalSource from the default of 'g' in Gremlin Server.")
c.close()
raise e

Expand Down Expand Up @@ -667,6 +676,10 @@ def with_sparql_path(self, path: str):
self.args['sparql_path'] = path
return ClientBuilder(self.args)

def with_gremlin_traversal_source(self, traversal_source: str):
self.args['gremlin_traversal_source'] = traversal_source
return ClientBuilder(self.args)

def with_tls(self, tls: bool):
self.args['ssl'] = tls
return ClientBuilder(self.args)
Expand Down
3 changes: 2 additions & 1 deletion test/integration/IntegrationTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
.with_port(config.port) \
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)

if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
Expand Down
2 changes: 2 additions & 0 deletions test/integration/iam/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ def setup_iam_client(config: Configuration) -> Client:
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_iam(get_session()) \
.build()

assert client.host == config.host
assert client.port == config.port
assert client.region == config.aws_region
assert client.sparql_path == config.sparql.path
assert client.gremlin_traversal_source == config.gremlin.traversal_source
assert client.ssl is config.ssl
return client