Skip to content

Commit 39590fb

Browse files
authored
Add neptune_config_allowlist magic (#327)
* Add neptune_config_allowlist magic * Refactor duplicate code Co-authored-by: Michael Chin <chnmch@amazon.com>
1 parent 677de34 commit 39590fb

File tree

6 files changed

+158
-20
lines changed

6 files changed

+158
-20
lines changed

src/graph_notebook/configuration/generate_config.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import os
99
from enum import Enum
1010

11-
from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION
11+
from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, \
12+
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host
13+
1214
DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')
1315

1416

@@ -67,14 +69,18 @@ def __init__(self, host: str, port: int,
6769
auth_mode: AuthModeEnum = DEFAULT_AUTH_MODE,
6870
load_from_s3_arn='', ssl: bool = True, aws_region: str = DEFAULT_REGION,
6971
proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
70-
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None):
72+
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None,
73+
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS):
7174
self.host = host
7275
self.port = port
7376
self.ssl = ssl
7477
self.proxy_host = proxy_host
7578
self.proxy_port = proxy_port
7679
self.sparql = sparql_section if sparql_section is not None else SparqlSection()
77-
if "amazonaws.com" in self.host or "amazonaws.com" in self.proxy_host:
80+
81+
is_neptune_host = is_allowed_neptune_host(hostname=self.host, host_allowlist=neptune_hosts) \
82+
or is_allowed_neptune_host(hostname=self.proxy_host, host_allowlist=neptune_hosts)
83+
if is_neptune_host:
7884
self.is_neptune_config = True
7985
self.auth_mode = auth_mode
8086
self.load_from_s3_arn = load_from_s3_arn
@@ -118,9 +124,11 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
118124

119125

120126
def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='',
121-
aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT):
127+
aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
128+
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS):
122129
use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True
123-
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port)
130+
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port,
131+
neptune_hosts=neptune_hosts)
124132
return c
125133

126134

@@ -147,11 +155,14 @@ def generate_default_config():
147155
parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION)
148156
parser.add_argument("--proxy_host", help="the proxy host url to route a connection through", default='')
149157
parser.add_argument("--proxy_port", help="the proxy port to use when creating proxy connection", default=8182)
158+
parser.add_argument("--neptune_hosts", help="list of host snippets to use for identifying neptune endpoints",
159+
default=DEFAULT_CONFIG_LOCATION)
150160
args = parser.parse_args()
151161

152162
auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
153163
config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
154-
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port))
164+
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port),
165+
neptune_hosts=args.neptune_hosts)
155166
config.write_to_file(args.config_destination)
156167

157168
exit(0)

src/graph_notebook/configuration/get_config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,35 @@
77

88
from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
99
SparqlSection, GremlinSection
10+
from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host
1011

1112

12-
def get_config_from_dict(data: dict) -> Configuration:
13+
def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:
14+
1315
sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
1416
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('')
1517
proxy_host = str(data['proxy_host']) if 'proxy_host' in data else ''
1618
proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182
17-
if "amazonaws.com" in data['host']:
19+
20+
is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)
21+
22+
if is_neptune_host:
1823
if gremlin_section.to_dict()['traversal_source'] != 'g':
1924
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
2025
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
2126
ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
2227
aws_region=data['aws_region'], sparql_section=sparql_section,
2328
gremlin_section=gremlin_section, proxy_host=proxy_host,
24-
proxy_port=proxy_port)
29+
proxy_port=proxy_port, neptune_hosts=neptune_hosts)
2530
else:
2631
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section,
2732
gremlin_section=gremlin_section, proxy_host=proxy_host,
2833
proxy_port=proxy_port)
2934
return config
3035

3136

32-
def get_config(path: str = DEFAULT_CONFIG_LOCATION) -> Configuration:
37+
def get_config(path: str = DEFAULT_CONFIG_LOCATION,
38+
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:
3339
with open(path) as config_file:
3440
data = json.load(config_file)
35-
return get_config_from_dict(data)
41+
return get_config_from_dict(data=data, neptune_hosts=neptune_hosts)

src/graph_notebook/magics/graph_magic.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import datetime
1414
import os
1515
import uuid
16+
import ast
1617
from enum import Enum
18+
from copy import copy
1719
from json import JSONDecodeError
1820
from graph_notebook.network.opencypher.OCNetwork import OCNetwork
1921

@@ -37,7 +39,8 @@
3739
from graph_notebook.magics.streams import StreamViewer
3840
from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
3941
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
40-
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS
42+
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
43+
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host
4144
from graph_notebook.network import SPARQLNetwork
4245
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
4346
from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns
@@ -182,11 +185,12 @@ def __init__(self, shell):
182185
# You must call the parent constructor
183186
super(Graph, self).__init__(shell)
184187

188+
self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS)
185189
self.graph_notebook_config = generate_default_config()
186190
try:
187191
self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION)
188192
self.client: Client = None
189-
self.graph_notebook_config = get_config(self.config_location)
193+
self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
190194
except FileNotFoundError:
191195
print('Could not find a valid configuration. '
192196
'Do not forget to validate your settings using %graph_notebook_config.')
@@ -201,7 +205,9 @@ def _generate_client_from_config(self, config: Configuration):
201205
if self.client:
202206
self.client.close()
203207

204-
if "amazonaws.com" in config.host:
208+
is_neptune_host = is_allowed_neptune_host(hostname=config.host, host_allowlist=self.neptune_cfg_allowlist)
209+
210+
if is_neptune_host:
205211
builder = ClientBuilder() \
206212
.with_host(config.host) \
207213
.with_port(config.port) \
@@ -212,6 +218,8 @@ def _generate_client_from_config(self, config: Configuration):
212218
.with_sparql_path(config.sparql.path)
213219
if config.auth_mode == AuthModeEnum.IAM:
214220
builder = builder.with_iam(get_session())
221+
if self.neptune_cfg_allowlist != NEPTUNE_CONFIG_HOST_IDENTIFIERS:
222+
builder = builder.with_custom_neptune_hosts(self.neptune_cfg_allowlist)
215223
else:
216224
builder = ClientBuilder() \
217225
.with_host(config.host) \
@@ -229,13 +237,13 @@ def _generate_client_from_config(self, config: Configuration):
229237
def graph_notebook_config(self, line='', cell='', local_ns: dict = None):
230238
if cell != '':
231239
data = json.loads(cell)
232-
config = get_config_from_dict(data)
240+
config = get_config_from_dict(data, neptune_hosts=self.neptune_cfg_allowlist)
233241
self.graph_notebook_config = config
234242
self._generate_client_from_config(config)
235243
print('set notebook config to:')
236244
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
237245
elif line == 'reset':
238-
self.graph_notebook_config = get_config(self.config_location)
246+
self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
239247
print('reset notebook config to:')
240248
print(json.dumps(self.graph_notebook_config.to_dict(), indent=2))
241249
elif line == 'silent':
@@ -251,6 +259,46 @@ def graph_notebook_config(self, line='', cell='', local_ns: dict = None):
251259

252260
return self.graph_notebook_config
253261

262+
@line_cell_magic
263+
def neptune_config_allowlist(self, line='', cell=''):
264+
parser = argparse.ArgumentParser()
265+
parser.add_argument('mode', nargs='?', default='add',
266+
help='mode (default=add) [add|remove|overwrite|reset]')
267+
args = parser.parse_args(line.split())
268+
269+
try:
270+
cell_new = ast.literal_eval(cell)
271+
input_type = 'list'
272+
except:
273+
cell_new = cell
274+
input_type = 'string'
275+
276+
allowlist_modified = True
277+
if args.mode == 'reset':
278+
self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS)
279+
elif cell != '':
280+
if args.mode == 'add':
281+
if input_type == 'string':
282+
self.neptune_cfg_allowlist.append(cell_new.strip())
283+
else:
284+
self.neptune_cfg_allowlist = list(set(self.neptune_cfg_allowlist) | set(cell_new))
285+
elif args.mode == 'remove':
286+
if input_type == 'string':
287+
self.neptune_cfg_allowlist.remove(cell_new.strip())
288+
else:
289+
self.neptune_cfg_allowlist = list(set(self.neptune_cfg_allowlist) - set(cell_new))
290+
elif args.mode == 'overwrite':
291+
if input_type == 'string':
292+
self.neptune_cfg_allowlist = [cell_new.strip()]
293+
else:
294+
self.neptune_cfg_allowlist = cell_new
295+
else:
296+
allowlist_modified = False
297+
298+
if allowlist_modified:
299+
print(f'Set Neptune config allow list to: {self.neptune_cfg_allowlist}')
300+
else:
301+
print(f'Current Neptune config allow list: {self.neptune_cfg_allowlist}')
254302

255303
@line_magic
256304
def stream_viewer(self,line):

src/graph_notebook/neptune/client.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from gremlin_python.driver.protocol import GremlinServerError
1919
from neo4j import GraphDatabase
2020
import nest_asyncio
21+
# from graph_notebook.magics.graph_magic import NEPTUNE_CONFIG_HOST_IDENTIFIERS
2122

2223
# This patch is no longer needed when graph_notebook is using the a Gremlin Python
2324
# client >= 3.5.0 as the HashableDict is now part of that client driver.
@@ -91,11 +92,21 @@
9192
STREAM_RDF = 'RDF'
9293
STREAM_ENDPOINTS = {STREAM_PG: 'gremlin', STREAM_RDF: 'sparql'}
9394

95+
NEPTUNE_CONFIG_HOST_IDENTIFIERS = ["amazonaws.com"]
96+
97+
98+
def is_allowed_neptune_host(hostname: str, host_allowlist: list):
99+
for host_snippet in host_allowlist:
100+
if host_snippet in hostname:
101+
return True
102+
return False
103+
94104

95105
class Client(object):
96106
def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION,
97107
sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None,
98-
proxy_host: str = '', proxy_port: int = DEFAULT_PORT):
108+
proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
109+
neptune_hosts: list = None):
99110
self.target_host = host
100111
self.target_port = port
101112
self.ssl = ssl
@@ -106,6 +117,7 @@ def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region
106117
self._session = session
107118
self.proxy_host = proxy_host
108119
self.proxy_port = proxy_port
120+
self.neptune_hosts = NEPTUNE_CONFIG_HOST_IDENTIFIERS if neptune_hosts is None else neptune_hosts
109121

110122
self._http_protocol = 'https' if self.ssl else 'http'
111123
self._ws_protocol = 'wss' if self.ssl else 'ws'
@@ -125,7 +137,7 @@ def port(self):
125137
return self.target_port
126138

127139
def is_neptune_domain(self):
128-
return "neptune.amazonaws.com" in self.target_host
140+
return is_allowed_neptune_host(hostname=self.target_host, host_allowlist=self.neptune_hosts)
129141

130142
def get_uri_with_port(self, use_websocket=False, use_proxy=False):
131143
protocol = self._http_protocol
@@ -743,5 +755,9 @@ def with_proxy_port(self, proxy_port: int):
743755
self.args['proxy_port'] = proxy_port
744756
return ClientBuilder(self.args)
745757

758+
def with_custom_neptune_hosts(self, neptune_hosts: list):
759+
self.args['neptune_hosts'] = neptune_hosts
760+
return ClientBuilder(self.args)
761+
746762
def build(self) -> Client:
747763
return Client(**self.args)

test/unit/configuration/test_configuration.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def setUpClass(cls) -> None:
1616
cls.generic_host = 'blah'
1717
cls.neptune_host_reg = 'instance.cluster.us-west-2.neptune.amazonaws.com'
1818
cls.neptune_host_cn = 'instance.cluster.neptune.cn-north-1.amazonaws.com.cn'
19+
cls.neptune_host_custom = 'localhost'
1920
cls.port = 8182
21+
cls.custom_hosts_list = ['localhost']
2022
cls.test_file_path = f'{os.path.abspath(os.path.curdir)}/test_configuration_file.json'
2123

2224
def tearDown(self) -> None:
@@ -39,6 +41,13 @@ def test_configuration_default_auth_defaults_neptune_cn(self):
3941
self.assertEqual(True, config.ssl)
4042
self.assertEqual('', config.load_from_s3_arn)
4143

44+
def test_configuration_default_auth_defaults_neptune_custom(self):
45+
config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list)
46+
self.assertEqual(self.neptune_host_custom, config.host)
47+
self.assertEqual(self.port, config.port)
48+
self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode)
49+
self.assertEqual(True, config.ssl)
50+
4251
def test_configuration_default_auth_defaults_generic(self):
4352
config = Configuration(self.generic_host, self.port)
4453
self.assertEqual(self.generic_host, config.host)
@@ -49,7 +58,8 @@ def test_configuration_override_defaults_neptune_reg(self):
4958
auth_mode = AuthModeEnum.IAM
5059
ssl = False
5160
loader_arn = 'foo'
52-
config = Configuration(self.neptune_host_reg, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl)
61+
config = Configuration(self.neptune_host_reg, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
62+
ssl=ssl)
5363
self.assertEqual(auth_mode, config.auth_mode)
5464
self.assertEqual(ssl, config.ssl)
5565
self.assertEqual(loader_arn, config.load_from_s3_arn)
@@ -58,7 +68,18 @@ def test_configuration_override_defaults_neptune_cn(self):
5868
auth_mode = AuthModeEnum.IAM
5969
ssl = False
6070
loader_arn = 'foo'
61-
config = Configuration(self.neptune_host_cn, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn, ssl=ssl)
71+
config = Configuration(self.neptune_host_cn, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
72+
ssl=ssl)
73+
self.assertEqual(auth_mode, config.auth_mode)
74+
self.assertEqual(ssl, config.ssl)
75+
self.assertEqual(loader_arn, config.load_from_s3_arn)
76+
77+
def test_configuration_override_defaults_neptune_custom(self):
78+
auth_mode = AuthModeEnum.IAM
79+
ssl = False
80+
loader_arn = 'foo'
81+
config = Configuration(self.neptune_host_custom, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
82+
ssl=ssl, neptune_hosts=self.custom_hosts_list)
6283
self.assertEqual(auth_mode, config.auth_mode)
6384
self.assertEqual(ssl, config.ssl)
6485
self.assertEqual(loader_arn, config.load_from_s3_arn)
@@ -84,6 +105,15 @@ def test_generate_configuration_with_defaults_neptune_cn(self):
84105
config_from_file = get_config(self.test_file_path)
85106
self.assertEqual(config.to_dict(), config_from_file.to_dict())
86107

108+
def test_generate_configuration_with_defaults_neptune_custom(self):
109+
config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list)
110+
c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl,
111+
load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region,
112+
neptune_hosts=self.custom_hosts_list)
113+
c.write_to_file(self.test_file_path)
114+
config_from_file = get_config(self.test_file_path, neptune_hosts=self.custom_hosts_list)
115+
self.assertEqual(config.to_dict(), config_from_file.to_dict())
116+
87117
def test_generate_configuration_with_defaults_generic(self):
88118
config = Configuration(self.generic_host, self.port)
89119
c = generate_config(config.host, config.port, ssl=config.ssl)
@@ -119,6 +149,21 @@ def test_generate_configuration_override_defaults_neptune_cn(self):
119149
config_from_file = get_config(self.test_file_path)
120150
self.assertEqual(config.to_dict(), config_from_file.to_dict())
121151

152+
def test_generate_configuration_override_defaults_neptune_custom(self):
153+
auth_mode = AuthModeEnum.IAM
154+
ssl = False
155+
loader_arn = 'foo'
156+
aws_region = 'cn-north-1'
157+
config = Configuration(self.neptune_host_custom, self.port, auth_mode=auth_mode, load_from_s3_arn=loader_arn,
158+
ssl=ssl, aws_region=aws_region, neptune_hosts=self.custom_hosts_list)
159+
160+
c = generate_config(config.host, config.port, auth_mode=config.auth_mode, ssl=config.ssl,
161+
load_from_s3_arn=config.load_from_s3_arn, aws_region=config.aws_region,
162+
neptune_hosts=self.custom_hosts_list)
163+
c.write_to_file(self.test_file_path)
164+
config_from_file = get_config(self.test_file_path, neptune_hosts=self.custom_hosts_list)
165+
self.assertEqual(config.to_dict(), config_from_file.to_dict())
166+
122167
def test_generate_configuration_override_defaults_generic(self):
123168
ssl = False
124169
config = Configuration(self.generic_host, self.port, ssl=ssl)

0 commit comments

Comments
 (0)