
Update magics for new Neptune Analytics API #560

Merged · 4 commits · Feb 1, 2024

ChangeLog.md (4 changes: 3 additions & 1 deletion)

@@ -5,8 +5,10 @@ Starting with v1.31.6, this file will contain a record of major features and updates
 ## Upcoming
 - New Neptune Analytics notebook - Vector Similarity Algorithms ([Link to PR](https://github.com/aws/graph-notebook/pull/555))
   - Path: 02-Neptune-Analytics > 02-Graph-Algorithms > 06-Vector-Similarity-Algorithms
-- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))
+- Updated various Neptune magics for new Analytics API ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
+- Added `%graph_notebook_service` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
 - Added unit abbreviation support to `--max-content-length` ([Link to PR](https://github.com/aws/graph-notebook/pull/553))
+- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))
 
 ## Release 4.0.2 (Dec 14, 2023)
 - Fixed `neptune_ml_utils` imports in `03-Neptune-ML` samples ([Link to PR](https://github.com/aws/graph-notebook/pull/546))
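
As background for the `--max-content-length` entry above: the option now accepts unit abbreviations (for example `10MB`) in addition to a raw byte count. The snippet below is a sketch of that kind of conversion written from the changelog description only; it is not the `mcl_to_bytes` implementation from this PR, and the helper name is hypothetical.

    import re

    # Sketch only: byte multipliers for the unit abbreviations described above.
    UNIT_MULTIPLIERS = {'B': 1, 'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3, 'TB': 1024 ** 4}

    def max_content_length_to_bytes(mcl: str) -> int:
        # Accept a bare byte count ('1048576') or a count with a unit ('10MB').
        match = re.fullmatch(r'(\d+)\s*([KMGT]?B)?', str(mcl).strip(), re.IGNORECASE)
        if not match:
            raise ValueError(f'Invalid max content length: {mcl}')
        value, unit = match.groups()
        return int(value) * UNIT_MULTIPLIERS[(unit or 'B').upper()]

    assert max_content_length_to_bytes('10MB') == 10 * 1024 * 1024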

src/graph_notebook/magics/graph_magic.py (141 changes: 105 additions & 36 deletions)

@@ -42,13 +42,14 @@
 from graph_notebook.decorators.decorators import display_exceptions, magic_variables, neptune_db_only
 from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
 from graph_notebook.magics.streams import StreamViewer
-from graph_notebook.neptune.client import ClientBuilder, Client,PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
+from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
     LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
     DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
     FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
     NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
-    STATISTICS_LANGUAGE_INPUTS, STATISTICS_MODES, SUMMARY_MODES, \
-    SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT
+    STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
+    SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, \
+    OPENCYPHER_STATUS_STATE_MODES, normalize_service_name
 from graph_notebook.network import SPARQLNetwork
 from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
 from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns
@@ -255,22 +256,31 @@ def get_load_ids(neptune_client):
     return ids, res
 
 
-def process_statistics_400(is_summary: bool, response):
+def process_statistics_400(response, is_summary: bool = False, is_analytics: bool = False):
     bad_request_res = json.loads(response.text)
     res_code = bad_request_res['code']
     if res_code == 'StatisticsNotAvailableException':
-        print("No statistics found. Please ensure that auto-generation of DFE statistics is enabled by running "
-              "'%statistics' and checking if 'autoCompute' is set to True. Alternately, you can manually "
-              "trigger statistics generation by running: '%statistics --mode refresh'.")
+        print("No statistics found. ", end="")
+        if not is_analytics:
+            print("Please ensure that auto-generation of DFE statistics is enabled by running '%statistics' and "
+                  "checking if 'autoCompute' is set to True. Alternately, you can manually trigger statistics "
+                  "generation by running: '%statistics --mode refresh'.")
+        return
     elif res_code == "BadRequestException":
-        print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size r5.large or "
-              "greater in order to have DFE statistics enabled.")
-        if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
-            print("\nPlease also note that the Graph Summary API is only available in Neptune engine version 1.2.1.0 "
-                  "and later.")
-    else:
-        print("Query encountered 400 error, please see below.")
+        if is_analytics:
+            if bad_request_res["message"] == 'Bad route: /summary':
+                logger.debug("Encountered bad route exception for Analytics, retrying with legacy statistics endpoint.")
+                return 1
+        else:
+            print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size "
+                  "r5.large or greater in order to have DFE statistics enabled.")
+            if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
+                print("\nPlease also note that the Graph Summary API is only available in Neptune engine version "
+                      "1.2.1.0 and later.")
+            return
+    print("Query encountered 400 error, please see below.")
     print(f"\nFull response: {bad_request_res}")
+    return
 
 
 def mcl_to_bytes(mcl):
@@ -445,6 +455,7 @@ def stream_viewer(self,line):
     @line_magic
     @needs_local_scope
     @display_exceptions
+    @neptune_db_only
     def statistics(self, line, local_ns: dict = None):
         parser = argparse.ArgumentParser()
         parser.add_argument('language', nargs='?', type=str.lower, default="propertygraph",
@@ -476,9 +487,9 @@ def statistics(self, line, local_ns: dict = None):
         statistics_res = self.client.statistics(args.language, args.summary, mode)
         if statistics_res.status_code == 400:
             if args.summary:
-                process_statistics_400(True, statistics_res)
+                process_statistics_400(statistics_res)
             else:
-                process_statistics_400(False, statistics_res)
+                process_statistics_400(statistics_res)
             return
         statistics_res.raise_for_status()
         statistics_res_json = statistics_res.json()
@@ -508,10 +519,21 @@ def summary(self, line, local_ns: dict = None):
         else:
             mode = "basic"
 
-        summary_res = self.client.statistics(args.language, True, mode)
+        language_ep = args.language
+        if self.client.is_analytics_domain():
+            is_analytics = True
+            if language_ep in STATISTICS_LANGUAGE_INPUTS_SPARQL:
+                print("SPARQL is not supported for Neptune Analytics, defaulting to PropertyGraph.")
+                language_ep = 'propertygraph'
+        else:
+            is_analytics = False
+        summary_res = self.client.statistics(language_ep, True, mode, is_analytics)
         if summary_res.status_code == 400:
-            process_statistics_400(True, summary_res)
-            return
+            retry_legacy = process_statistics_400(summary_res, is_summary=True, is_analytics=is_analytics)
+            if retry_legacy == 1:
+                summary_res = self.client.statistics(language_ep, True, mode, False)
+            else:
+                return
         summary_res.raise_for_status()
         summary_res_json = summary_res.json()
         if not args.silent:
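
The `%summary` flow above tries the new Analytics endpoint first and only falls back to the legacy statistics route when `process_statistics_400` returns the retry sentinel (1). Hypothetical notebook usage against an Analytics endpoint, assuming the same positional language argument that the statistics magic defines:

    %summary            # PropertyGraph summary; retried against the legacy endpoint on 'Bad route: /summary'
    %summary sparql     # prints the "SPARQL is not supported for Neptune Analytics" notice, uses propertygraph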
@@ -530,6 +552,16 @@ def graph_notebook_host(self, line):
         self._generate_client_from_config(self.graph_notebook_config)
         print(f'set host to {self.graph_notebook_config.host}')
 
+    @line_magic
+    def graph_notebook_service(self, line):
+        if line == '':
+            print(f'current service name: {self.graph_notebook_config.neptune_service}')
+            return
+
+        self.graph_notebook_config.neptune_service = normalize_service_name(line)
+        self._generate_client_from_config(self.graph_notebook_config)
+        print(f'set service name to {self.graph_notebook_config.neptune_service}')
+
     @magic_variables
     @cell_magic
     @needs_local_scope
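
A quick sketch of the new line magic added above: called with no argument it reports the configured service, and called with an argument it switches the client. The service-name strings below are assumptions (they are normalized by `normalize_service_name`, whose definition is outside this diff):

    %graph_notebook_service                  # current service name: neptune-db
    %graph_notebook_service neptune-graph    # set service name to neptune-graph (assumed Analytics value)
    %graph_notebook_service neptune-db       # switch back to Neptune DB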
@@ -1177,6 +1209,7 @@ def opencypher_status(self, line='', local_ns: dict = None):
     @line_magic
     @needs_local_scope
     @display_exceptions
+    @neptune_db_only
     def status(self, line='', local_ns: dict = None):
         logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}')
         parser = argparse.ArgumentParser()
@@ -1547,6 +1580,7 @@ def load(self, line='', local_ns: dict = None):
             value=str(args.concurrency),
             placeholder=1,
             min=1,
+            max=2**16,
             disabled=False,
             layout=widgets.Layout(display=concurrency_hbox_visibility,
                                   width=widget_width)
@@ -1556,6 +1590,7 @@
             value=args.periodic_commit,
             placeholder=0,
             min=0,
+            max=1000000,
             disabled=False,
             layout=widgets.Layout(display=periodic_commit_hbox_visibility,
                                   width=widget_width)
@@ -1770,13 +1805,12 @@ def on_button_clicked(b):
                 source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
                 source_format_hbox.children += (source_format_validation_label,)
 
-            if not arn.value.startswith('arn:aws') and source.value.startswith(
-                    "s3://"):  # only do this validation if we are using an s3 bucket.
-                validated = False
-                arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
-                arn_hbox.children += (arn_validation_label,)
-
             if load_type == 'bulk':
+                if not arn.value.startswith('arn:aws') and source.value.startswith(
+                        "s3://"):  # only do this validation if we are using an s3 bucket.
+                    validated = False
+                    arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
+                    arn_hbox.children += (arn_validation_label,)
                 dependencies_list = list(filter(None, dependencies.value.split('\n')))
                 if not len(dependencies_list) < 64:
                     validated = False
@@ -3105,9 +3139,15 @@ def handle_opencypher_status(self, line, local_ns):
         parser.add_argument('-c', '--cancelQuery', action='store_true', default=False,
                             help='Tells the status command to cancel a query. This parameter does not take a value.')
         parser.add_argument('-w', '--includeWaiting', action='store_true', default=False,
-                            help='When set to true and other parameters are not present, causes status information '
-                                 'for waiting queries to be returned as well as for running queries. '
-                                 'This parameter does not take a value.')
+                            help='Neptune DB only. When set to true and other parameters are not present, causes '
+                                 'status information for waiting queries to be returned as well as for running '
+                                 'queries. This parameter does not take a value.')
+        parser.add_argument('--state', type=str.upper, default='ALL',
+                            help=f'Neptune Analytics only. Specifies what subset of query states to retrieve the '
+                                 f'status of. Default is ALL. Accepted values: {OPENCYPHER_STATUS_STATE_MODES}')
+        parser.add_argument('-m', '--maxResults', type=int, default=200,
+                            help='Neptune Analytics only. Sets an upper limit on the set of returned queries whose '
+                                 'status matches --state. Default is 200.')
         parser.add_argument('-s', '--silent-cancel', action='store_true', default=False,
                             help='If silent_cancel=true then the running query is cancelled and the HTTP response '
                                  'code is 200. If silent_cancel is not present or silent_cancel=false, '
@@ -3116,21 +3156,50 @@
         parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
         args = parser.parse_args(line.split())
 
+        using_analytics = self.client.is_analytics_domain()
         if not args.cancelQuery:
-            if args.includeWaiting and not args.queryId:
-                res = self.client.opencypher_status(include_waiting=args.includeWaiting)
+            query_id = ''
+            include_waiting = None
+            state = ''
+            max_results = None
+            if args.includeWaiting and not args.queryId and not self.client.is_analytics_domain():
+                include_waiting = args.includeWaiting
+            elif args.state and not args.queryId and self.client.is_analytics_domain():
+                state = args.state
+                max_results = args.maxResults
             else:
-                res = self.client.opencypher_status(query_id=args.queryId)
+                query_id = args.queryId
+            res = self.client.opencypher_status(query_id=query_id,
+                                                include_waiting=include_waiting,
+                                                state=state,
+                                                max_results=max_results,
+                                                use_analytics_endpoint=using_analytics)
+            if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
+                res = self.client.opencypher_status(query_id=query_id,
+                                                    include_waiting=include_waiting,
+                                                    state=state,
+                                                    max_results=max_results,
+                                                    use_analytics_endpoint=False)
             res.raise_for_status()
         else:
             if args.queryId == '':
                 if not args.silent:
                     print(OPENCYPHER_CANCEL_HINT_MSG)
                 return
             else:
-                res = self.client.opencypher_cancel(args.queryId, args.silent_cancel)
+                res = self.client.opencypher_cancel(args.queryId,
+                                                    silent=args.silent_cancel,
+                                                    use_analytics_endpoint=using_analytics)
+                if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
+                    res = self.client.opencypher_cancel(args.queryId,
+                                                        silent=args.silent_cancel,
+                                                        use_analytics_endpoint=False)
             res.raise_for_status()
-        js = res.json()
-        store_to_ns(args.store_to, js, local_ns)
-        if not args.silent:
-            print(json.dumps(js, indent=2))
+        if using_analytics and args.cancelQuery:
+            if not args.silent:
+                print(f'Submitted cancellation request for query ID: {args.queryId}')
+        else:
+            js = res.json()
+            store_to_ns(args.store_to, js, local_ns)
+            if not args.silent:
+                print(json.dumps(js, indent=2))
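
Taken together, `%opencypher_status` now branches per service: Neptune DB keeps `--includeWaiting`, while Analytics routes `--state`/`--maxResults` to the `/queries` API, with a one-time fallback to the legacy endpoint on a 'Bad route' 400. A hypothetical session (the query ID is invented, and `--queryId` is assumed to be the flag spelling for the `queryId` argument referenced above):

    %opencypher_status --includeWaiting                    # Neptune DB: running plus waiting queries
    %opencypher_status --state RUNNING --maxResults 50     # Analytics: GET /queries?state=RUNNING&maxResults=50
    %opencypher_status --cancelQuery --queryId 01234567-89ab-cdef-0123-456789abcdef
                                                           # Analytics: DELETE /queries/{id}, then prints the
                                                           # "Submitted cancellation request" confirmation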

src/graph_notebook/neptune/client.py (42 changes: 34 additions & 8 deletions)

@@ -122,12 +122,15 @@

 STATISTICS_MODES = ["", "status", "disableAutoCompute", "enableAutoCompute", "refresh", "delete"]
 SUMMARY_MODES = ["", "basic", "detailed"]
-STATISTICS_LANGUAGE_INPUTS = ["propertygraph", "pg", "gremlin", "oc", "opencypher", "sparql", "rdf"]
+STATISTICS_LANGUAGE_INPUTS_PG = ["propertygraph", "pg", "gremlin", "oc", "opencypher"]
+STATISTICS_LANGUAGE_INPUTS_SPARQL = ["sparql", "rdf"]
+STATISTICS_LANGUAGE_INPUTS = STATISTICS_LANGUAGE_INPUTS_PG + STATISTICS_LANGUAGE_INPUTS_SPARQL
 
 SPARQL_EXPLAIN_MODES = ['dynamic', 'static', 'details']
 OPENCYPHER_EXPLAIN_MODES = ['dynamic', 'static', 'details']
 OPENCYPHER_PLAN_CACHE_MODES = ['auto', 'enabled', 'disabled']
 OPENCYPHER_DEFAULT_TIMEOUT = 120000
+OPENCYPHER_STATUS_STATE_MODES = ['ALL', 'RUNNING', 'WAITING', 'CANCELLING']
 
 
 def is_allowed_neptune_host(hostname: str, host_allowlist: list):
@@ -405,7 +408,7 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
             if plan_cache:
                 data['planCache'] = plan_cache
             if query_timeout:
-                headers['query_timeout_millis'] = str(query_timeout)
+                data['queryTimeoutMilliseconds'] = str(query_timeout)
         else:
             url += 'db/neo4j/tx/commit'
             headers['content-type'] = 'application/json'
@@ -441,16 +444,20 @@ def opencyper_bolt(self, query: str, **kwargs):
             driver.close()
         return data
 
-    def opencypher_status(self, query_id: str = '', include_waiting: bool = False):
+    def opencypher_status(self, query_id: str = '', include_waiting: bool = False, state: str = '',
+                          max_results: int = None, use_analytics_endpoint: bool = False):
+        if use_analytics_endpoint:
+            return self._analytics_query_status(query_id=query_id, state=state, max_results=max_results)
         kwargs = {}
         if include_waiting:
             kwargs['includeWaiting'] = True
         return self._query_status('openCypher', query_id=query_id, **kwargs)
 
-    def opencypher_cancel(self, query_id, silent: bool = False):
+    def opencypher_cancel(self, query_id, silent: bool = False, use_analytics_endpoint: bool = False):
         if type(query_id) is not str or query_id == '':
             raise ValueError('query_id must be a non-empty string')
 
+        if use_analytics_endpoint:
+            return self._analytics_query_status(query_id=query_id, cancel_query=True)
         return self._query_status('openCypher', query_id=query_id, cancelQuery=True, silent=silent)
 
     def get_opencypher_driver(self):
@@ -808,7 +815,25 @@ def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response:
         res = self._http_session.send(req, verify=self.ssl_verify)
         return res
 
-    def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response:
+    def _analytics_query_status(self, query_id: str = '', state: str = '', max_results: int = None,
+                                cancel_query: bool = False) -> requests.Response:
+        url = f'{self._http_protocol}://{self.host}:{self.port}/queries'
+        if query_id != '':
+            url += f'/{query_id}'
+        elif state != '':
+            url += f'?state={state}&maxResults={max_results}'
+
+        method = 'DELETE' if cancel_query else 'GET'
+
+        headers = {
+            'Content-Type': 'application/x-www-form-urlencoded'
+        }
+        req = self._prepare_request(method, url, headers=headers)
+        res = self._http_session.send(req, verify=self.ssl_verify)
+        return res
+
+    def statistics(self, language: str, summary: bool = False, mode: str = '',
+                   use_analytics_endpoint: bool = False) -> requests.Response:
         headers = {
             'Accept': 'application/json'
         }
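
For reference, a sketch of how the new helper is reached through the public client API. The host and query ID are hypothetical, and the `ClientBuilder` methods shown (`with_host`, `with_port`, `build`) are assumed from the surrounding codebase rather than from this diff:

    from graph_notebook.neptune.client import ClientBuilder

    client = ClientBuilder().with_host('my-graph.us-east-1.neptune-graph.amazonaws.com').with_port(8182).build()

    # GET /queries?state=RUNNING&maxResults=50 via _analytics_query_status
    running = client.opencypher_status(state='RUNNING', max_results=50, use_analytics_endpoint=True)

    # DELETE /queries/{query_id} via _analytics_query_status(cancel_query=True)
    cancelled = client.opencypher_cancel('01234567-89ab-cdef-0123-456789abcdef', use_analytics_endpoint=True)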
@@ -817,11 +842,12 @@ def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response:
         elif language == "sparql":
             language = "rdf"
 
-        url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/statistics'
+        base_url = f'{self._http_protocol}://{self.host}:{self.port}'
+        url = base_url + f'/{language}/statistics'
         data = {'mode': mode}
 
         if summary:
-            summary_url = url + '/summary'
+            summary_url = (base_url if use_analytics_endpoint else url) + '/summary'
             if mode:
                 summary_mode_param = '?mode=' + mode
                 summary_url += summary_mode_param
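
One consequence of the `base_url` split above: on Neptune DB the graph summary stays at `/{language}/statistics/summary`, while on Neptune Analytics it is served from the service root as `/summary`. With hypothetical hosts, `statistics('propertygraph', summary=True, mode='detailed', ...)` builds:

    use_analytics_endpoint=False:  https://my-cluster.us-east-1.neptune.amazonaws.com:8182/propertygraph/statistics/summary?mode=detailed
    use_analytics_endpoint=True:   https://my-graph.us-east-1.neptune-graph.amazonaws.com:8182/summary?mode=detailed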