Skip to content

Add neptune_config_allowlist magic #327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import os
from enum import Enum

from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION
from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')


Expand Down Expand Up @@ -67,14 +69,18 @@ def __init__(self, host: str, port: int,
auth_mode: AuthModeEnum = DEFAULT_AUTH_MODE,
load_from_s3_arn='', ssl: bool = True, aws_region: str = DEFAULT_REGION,
proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None):
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None,
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS):
self.host = host
self.port = port
self.ssl = ssl
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.sparql = sparql_section if sparql_section is not None else SparqlSection()
if "amazonaws.com" in self.host or "amazonaws.com" in self.proxy_host:

is_neptune_host = is_allowed_neptune_host(hostname=self.host, host_allowlist=neptune_hosts) \
or is_allowed_neptune_host(hostname=self.proxy_host, host_allowlist=neptune_hosts)
if is_neptune_host:
self.is_neptune_config = True
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
Expand Down Expand Up @@ -118,9 +124,11 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):


def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='',
aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT):
aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS):
use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port)
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port,
neptune_hosts=neptune_hosts)
return c


Expand All @@ -147,11 +155,14 @@ def generate_default_config():
parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION)
parser.add_argument("--proxy_host", help="the proxy host url to route a connection through", default='')
parser.add_argument("--proxy_port", help="the proxy port to use when creating proxy connection", default=8182)
# Accept zero or more host snippets. The fallback is the built-in list of Neptune host
# identifiers; the config-file path is not a host list and must not be used here.
parser.add_argument("--neptune_hosts", nargs="*",
                    help="list of host snippets to use for identifying neptune endpoints",
                    default=NEPTUNE_CONFIG_HOST_IDENTIFIERS)
args = parser.parse_args()

auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port))
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port),
neptune_hosts=args.neptune_hosts)
config.write_to_file(args.config_destination)

exit(0)
16 changes: 11 additions & 5 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,35 @@

from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
SparqlSection, GremlinSection
from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host


def get_config_from_dict(data: dict) -> Configuration:
def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:

sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('')
proxy_host = str(data['proxy_host']) if 'proxy_host' in data else ''
proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182
if "amazonaws.com" in data['host']:

is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)

if is_neptune_host:
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
aws_region=data['aws_region'], sparql_section=sparql_section,
gremlin_section=gremlin_section, proxy_host=proxy_host,
proxy_port=proxy_port)
proxy_port=proxy_port, neptune_hosts=neptune_hosts)
else:
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section,
gremlin_section=gremlin_section, proxy_host=proxy_host,
proxy_port=proxy_port)
return config


def get_config(path: str = DEFAULT_CONFIG_LOCATION) -> Configuration:
def get_config(path: str = DEFAULT_CONFIG_LOCATION,
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:
with open(path) as config_file:
data = json.load(config_file)
return get_config_from_dict(data)
return get_config_from_dict(data=data, neptune_hosts=neptune_hosts)
58 changes: 53 additions & 5 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import datetime
import os
import uuid
import ast
from enum import Enum
from copy import copy
from json import JSONDecodeError
from graph_notebook.network.opencypher.OCNetwork import OCNetwork

Expand All @@ -37,7 +39,8 @@
from graph_notebook.magics.streams import StreamViewer
from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host
from graph_notebook.network import SPARQLNetwork
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns
Expand Down Expand Up @@ -182,11 +185,12 @@ def __init__(self, shell):
# You must call the parent constructor
super(Graph, self).__init__(shell)

self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS)
self.graph_notebook_config = generate_default_config()
try:
self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION)
self.client: Client = None
self.graph_notebook_config = get_config(self.config_location)
self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
except FileNotFoundError:
print('Could not find a valid configuration. '
'Do not forget to validate your settings using %graph_notebook_config.')
Expand All @@ -201,7 +205,9 @@ def _generate_client_from_config(self, config: Configuration):
if self.client:
self.client.close()

if "amazonaws.com" in config.host:
is_neptune_host = is_allowed_neptune_host(hostname=config.host, host_allowlist=self.neptune_cfg_allowlist)

if is_neptune_host:
builder = ClientBuilder() \
.with_host(config.host) \
.with_port(config.port) \
Expand All @@ -212,6 +218,8 @@ def _generate_client_from_config(self, config: Configuration):
.with_sparql_path(config.sparql.path)
if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
if self.neptune_cfg_allowlist != NEPTUNE_CONFIG_HOST_IDENTIFIERS:
builder = builder.with_custom_neptune_hosts(self.neptune_cfg_allowlist)
else:
builder = ClientBuilder() \
.with_host(config.host) \
Expand All @@ -229,13 +237,13 @@ def _generate_client_from_config(self, config: Configuration):
def graph_notebook_config(self, line='', cell='', local_ns: dict = None):
if cell != '':
data = json.loads(cell)
config = get_config_from_dict(data)
config = get_config_from_dict(data, neptune_hosts=self.neptune_cfg_allowlist)
self.graph_notebook_config = config
self._generate_client_from_config(config)
print('set notebook config to:')
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
elif line == 'reset':
self.graph_notebook_config = get_config(self.config_location)
self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
print('reset notebook config to:')
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
elif line == 'silent':
Expand All @@ -251,6 +259,46 @@ def graph_notebook_config(self, line='', cell='', local_ns: dict = None):

return self.graph_notebook_config

@line_cell_magic
def neptune_config_allowlist(self, line='', cell=''):
    """View or edit the host snippets used to decide whether a host is a Neptune endpoint.

    ``line`` carries an optional mode (default ``add``):

    * ``add``       - append the hosts given in ``cell`` that are not already listed.
    * ``remove``    - drop the hosts given in ``cell``; hosts that are not listed are ignored.
    * ``overwrite`` - replace the allow list with the hosts given in ``cell``.
    * ``reset``     - restore a copy of ``NEPTUNE_CONFIG_HOST_IDENTIFIERS``.

    ``cell`` may be a bare host snippet (``localhost``), a quoted string literal, or a
    Python list/tuple/set literal of snippets. Blank snippets are discarded because an
    empty string is a substring of every hostname.

    When no hosts are supplied (and the mode is not ``reset``), or the mode is unknown,
    the allow list is left untouched and only printed. This magic only updates
    ``self.neptune_cfg_allowlist``; it does not rebuild the client.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', nargs='?', default='add',
                        help='mode (default=add) [add|remove|overwrite|reset]')
    args = parser.parse_args(line.split())

    # Try to read the cell as a Python literal; anything that is not one (for example a
    # bare hostname) is kept as raw text. Only the errors literal_eval is documented to
    # raise are caught, so unrelated failures are not silently swallowed.
    try:
        parsed = ast.literal_eval(cell)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = cell

    if isinstance(parsed, (list, tuple, set, frozenset)):
        candidates = list(parsed)
    elif isinstance(parsed, str):
        # A quoted string literal is one host, not a collection of characters.
        candidates = [parsed]
    else:
        # Numbers, dicts, None, ...: fall back to the raw cell text.
        candidates = [cell] if isinstance(cell, str) else []

    # Normalise: strip whitespace, drop blanks, de-duplicate while keeping input order.
    hosts = []
    for candidate in candidates:
        snippet = str(candidate).strip()
        if snippet and snippet not in hosts:
            hosts.append(snippet)

    current = list(self.neptune_cfg_allowlist)
    allowlist_modified = True
    if args.mode == 'reset':
        self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS)
    elif args.mode not in ('add', 'remove', 'overwrite'):
        print(f'Unrecognized mode: {args.mode}. Valid modes are: add, remove, overwrite, reset.')
        allowlist_modified = False
    elif not hosts:
        allowlist_modified = False
    elif args.mode == 'add':
        self.neptune_cfg_allowlist = current + [host for host in hosts if host not in current]
    elif args.mode == 'remove':
        self.neptune_cfg_allowlist = [host for host in current if host not in hosts]
    else:
        # overwrite
        self.neptune_cfg_allowlist = hosts

    if allowlist_modified:
        print(f'Set Neptune config allow list to: {self.neptune_cfg_allowlist}')
    else:
        print(f'Current Neptune config allow list: {self.neptune_cfg_allowlist}')

@line_magic
def stream_viewer(self,line):
Expand Down
20 changes: 18 additions & 2 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gremlin_python.driver.protocol import GremlinServerError
from neo4j import GraphDatabase
import nest_asyncio
# from graph_notebook.magics.graph_magic import NEPTUNE_CONFIG_HOST_IDENTIFIERS

# This patch is no longer needed when graph_notebook is using a Gremlin Python
# client >= 3.5.0, as the HashableDict is now part of that client driver.
Expand Down Expand Up @@ -91,11 +92,21 @@
STREAM_RDF = 'RDF'
STREAM_ENDPOINTS = {STREAM_PG: 'gremlin', STREAM_RDF: 'sparql'}

NEPTUNE_CONFIG_HOST_IDENTIFIERS = ["amazonaws.com"]


def is_allowed_neptune_host(hostname: str, host_allowlist: list):
    """Return True when ``hostname`` contains at least one snippet from ``host_allowlist``.

    Matching is a plain substring test. The following inputs never match:

    * a missing or empty ``hostname``;
    * a missing or empty ``host_allowlist``;
    * empty or non-string snippets (an empty string is a substring of every hostname,
      so it would otherwise allow every host).

    A bare string passed as ``host_allowlist`` is treated as a single snippet rather than
    being iterated character by character.
    """
    if not hostname or not host_allowlist:
        return False
    if isinstance(host_allowlist, str):
        host_allowlist = [host_allowlist]
    return any(
        isinstance(host_snippet, str) and host_snippet != '' and host_snippet in hostname
        for host_snippet in host_allowlist
    )


class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION,
sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None,
proxy_host: str = '', proxy_port: int = DEFAULT_PORT):
proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
neptune_hosts: list = None):
self.target_host = host
self.target_port = port
self.ssl = ssl
Expand All @@ -106,6 +117,7 @@ def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region
self._session = session
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.neptune_hosts = NEPTUNE_CONFIG_HOST_IDENTIFIERS if neptune_hosts is None else neptune_hosts

self._http_protocol = 'https' if self.ssl else 'http'
self._ws_protocol = 'wss' if self.ssl else 'ws'
Expand All @@ -125,7 +137,7 @@ def port(self):
return self.target_port

def is_neptune_domain(self):
return "neptune.amazonaws.com" in self.target_host
return is_allowed_neptune_host(hostname=self.target_host, host_allowlist=self.neptune_hosts)

def get_uri_with_port(self, use_websocket=False, use_proxy=False):
protocol = self._http_protocol
Expand Down Expand Up @@ -743,5 +755,9 @@ def with_proxy_port(self, proxy_port: int):
self.args['proxy_port'] = proxy_port
return ClientBuilder(self.args)

def with_custom_neptune_hosts(self, neptune_hosts: list):
    """Record a custom Neptune host allow list in the builder arguments.

    Stores ``neptune_hosts`` under the ``'neptune_hosts'`` key of ``self.args`` and
    returns a ``ClientBuilder`` constructed from those arguments.
    """
    builder_args = self.args
    builder_args['neptune_hosts'] = neptune_hosts
    return ClientBuilder(builder_args)

def build(self) -> Client:
return Client(**self.args)
49 changes: 47 additions & 2 deletions test/unit/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def setUpClass(cls) -> None:
cls.generic_host = 'blah'
cls.neptune_host_reg = 'instance.cluster.us-west-2.neptune.amazonaws.com'
cls.neptune_host_cn = 'instance.cluster.neptune.cn-north-1.amazonaws.com.cn'
cls.neptune_host_custom = 'localhost'
cls.port = 8182
cls.custom_hosts_list = ['localhost']
cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_configuration_file.json'

def tearDown(self) -> None:
Expand All @@ -39,6 +41,13 @@ def test_configuration_default_auth_defaults_neptune_cn(self):
self.assertEqual(True, config.ssl)
self.assertEqual('', config.load_from_s3_arn)

def test_configuration_default_auth_defaults_neptune_custom(self):
    """A host matched only by a custom allow list gets the Neptune defaults."""
    config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list)
    self.assertEqual(self.neptune_host_custom, config.host)
    self.assertEqual(self.port, config.port)
    # The point of the custom allow list: the host must be classified as Neptune.
    self.assertTrue(config.is_neptune_config)
    self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode)
    self.assertEqual(True, config.ssl)
    # Same default check as the sibling neptune_cn test.
    self.assertEqual('', config.load_from_s3_arn)

def test_configuration_default_auth_defaults_generic(self):
config = Configuration(self.generic_host, self.port)
self.assertEqual(self.generic_host, config.host)
Expand All @@ -49,7 +58,8 @@ def test_configuration_override_defaults_neptune_reg(self):
auth_mode = AuthModeEnum.IAM
ssl = False
loader_arn = 'foo'
config = Configuration(self.neptune_host_reg, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl)
config = Configuration(self.neptune_host_reg, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
ssl=ssl)
self.assertEqual(auth_mode, config.auth_mode)
self.assertEqual(ssl, config.ssl)
self.assertEqual(loader_arn, config.load_from_s3_arn)
Expand All @@ -58,7 +68,18 @@ def test_configuration_override_defaults_neptune_cn(self):
auth_mode = AuthModeEnum.IAM
ssl = False
loader_arn = 'foo'
config = Configuration(self.neptune_host_cn, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl)
config = Configuration(self.neptune_host_cn, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
ssl=ssl)
self.assertEqual(auth_mode, config.auth_mode)
self.assertEqual(ssl, config.ssl)
self.assertEqual(loader_arn, config.load_from_s3_arn)

def test_configuration_override_defaults_neptune_custom(self):
    """Explicit auth mode, SSL flag and loader ARN are kept for a custom allow-listed host."""
    expected_auth, expected_ssl, expected_arn = AuthModeEnum.IAM, False, 'foo'

    config = Configuration(
        self.neptune_host_custom,
        self.port,
        auth_mode=expected_auth,
        load_from_s3_arn=expected_arn,
        ssl=expected_ssl,
        neptune_hosts=self.custom_hosts_list,
    )

    self.assertEqual(expected_auth, config.auth_mode)
    self.assertEqual(expected_ssl, config.ssl)
    self.assertEqual(expected_arn, config.load_from_s3_arn)
Expand All @@ -84,6 +105,15 @@ def test_generate_configuration_with_defaults_neptune_cn(self):
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())

def test_generate_configuration_with_defaults_neptune_custom(self):
    """A default config for a custom allow-listed host survives a write/read round trip."""
    allowlist = self.custom_hosts_list
    expected = Configuration(self.neptune_host_custom, self.port, neptune_hosts=allowlist)

    generated = generate_config(
        expected.host,
        expected.port,
        auth_mode=expected.auth_mode,
        ssl=expected.ssl,
        load_from_s3_arn=expected.load_from_s3_arn,
        aws_region=expected.aws_region,
        neptune_hosts=allowlist,
    )
    generated.write_to_file(self.test_file_path)

    # Read the file back with the same allow list so the host is classified the same way.
    reloaded = get_config(self.test_file_path, neptune_hosts=allowlist)
    self.assertEqual(expected.to_dict(), reloaded.to_dict())

def test_generate_configuration_with_defaults_generic(self):
config = Configuration(self.generic_host, self.port)
c = generate_config(config.host, config.port, ssl=config.ssl)
Expand Down Expand Up @@ -119,6 +149,21 @@ def test_generate_configuration_override_defaults_neptune_cn(self):
config_from_file = get_config(self.test_file_path)
self.assertEqual(config.to_dict(), config_from_file.to_dict())

def test_generate_configuration_override_defaults_neptune_custom(self):
    """An overridden config for a custom allow-listed host survives a write/read round trip."""
    allowlist = self.custom_hosts_list
    expected = Configuration(
        self.neptune_host_custom,
        self.port,
        auth_mode=AuthModeEnum.IAM,
        load_from_s3_arn='foo',
        ssl=False,
        aws_region='cn-north-1',
        neptune_hosts=allowlist,
    )

    generated = generate_config(
        expected.host,
        expected.port,
        auth_mode=expected.auth_mode,
        ssl=expected.ssl,
        load_from_s3_arn=expected.load_from_s3_arn,
        aws_region=expected.aws_region,
        neptune_hosts=allowlist,
    )
    generated.write_to_file(self.test_file_path)

    # Read the file back with the same allow list so the host is classified the same way.
    reloaded = get_config(self.test_file_path, neptune_hosts=allowlist)
    self.assertEqual(expected.to_dict(), reloaded.to_dict())

def test_generate_configuration_override_defaults_generic(self):
ssl = False
config = Configuration(self.generic_host, self.port, ssl=ssl)
Expand Down
Loading