From f4aeba8ee26be7dd3f58b77b7908341fe5e6d1cb Mon Sep 17 00:00:00 2001
From: Kline
Date: Thu, 18 Mar 2021 10:52:53 -0700
Subject: [PATCH 01/16] - Refactor all modules calling into various api
 endpoints to coalesce into one `client` object.

- Add a builder object to facilitate creating the client with various options
- Remove specification of `iam_credentials_provider_type` and instead make use
  of the default boto3 session for obtaining aws credentials (as we do for
  SageMaker integration)
- Organize all tests using pytest to more easily filter which tests should be
  run

The Neptune client can be built either directly with its constructor:

```python
from graph_notebook.neptune.client import Client

c = Client(host='your-neptune-endpoint')
c.status()
```

It can also be created using our builder class:

```python
from botocore.session import get_session
from graph_notebook.neptune.client import ClientBuilder

builder = ClientBuilder() \
    .with_host(config.host) \
    .with_port(config.port) \
    .with_region(config.aws_region) \
    .with_tls(config.ssl) \
    .with_iam(get_session())

c = builder.build()
c.status()
```

The `Client` object has some components which are Neptune-specific, and some which are not:

- `sparql` - takes any SPARQL query and interprets whether it should be issued as type `query` or type `update`
- `sparql_query` - sends a query request to the configured SPARQL endpoint with the payload `{'query': 'YOUR QUERY'}`
- `sparql_update` - sends an update request to the configured SPARQL endpoint with the payload `{'update': 'YOUR QUERY'}`
- `do_sparql_request` - submits the given payload to the configured SPARQL endpoint
- `get_gremlin_connection` - returns a websocket connection to the configured Gremlin endpoint
- `gremlin_query` - obtains a new Gremlin connection and submits the given query. The opened connection is closed after query results are obtained
- `gremlin_http_query` - executes the given Gremlin query via http(s) instead of websocket
- `gremlin_status` - returns the status of running Gremlin queries on the configured Neptune endpoint. Takes an optional `query_id` input to obtain the status of a specific query
- `sparql_explain` - obtains an explain query plan for the given SPARQL query (can be of type update or query)
- `sparql_status` - returns the status of running SPARQL queries on the configured Neptune endpoint. Takes an optional `query_id` input to obtain the status of a specific query
- `sparql_cancel` - cancels the running SPARQL query with the provided `query_id`
- `gremlin_cancel` - cancels the running Gremlin query with the provided `query_id`
- `gremlin_explain` - obtains an explain query plan for a given Gremlin query
- `gremlin_profile` - obtains a profile query plan for a given Gremlin query
- `status` - retrieves the status of the configured Neptune endpoint
- `load` - submits a new bulk load job with the provided parameters
- `load_status` - obtains the status of the bulk loader. Takes an optional `load_id` to obtain the status of a specific loader job
- `cancel_load` - cancels the provided bulk loader job id
- `initiate_reset` - obtains a token needed to execute a fast reset of your configured Neptune endpoint
- `perform_reset` - takes a token obtained from `initiate_reset` and performs the reset
- `dataprocessing_start` - starts a NeptuneML dataprocessing job with the provided parameters
- `dataprocessing_job_status` - obtains the status of a given dataprocessing job id
- `dataprocessing_status` - obtains the status of the configured Neptune dataprocessing endpoint
- `dataprocessing_stop` - stops the given dataprocessing job id
- `modeltraining_start` - starts a NeptuneML modeltraining job with the provided parameters
- `modeltraining_job_status` - obtains the status of a given modeltraining job id
- `modeltraining_status` - obtains the status of the configured Neptune modeltraining endpoint
- `modeltraining_stop` - stops the given modeltraining job id
- `endpoints_create` - creates a NeptuneML endpoint with the provided parameters
- `endpoints_status` - obtains the status of a given endpoint job
- `endpoints_delete` - deletes a given endpoint id
- `endpoints` - obtains the status of all endpoints on the configured Neptune database
- `export` - helper function to call the Neptune exporter for NeptuneML. Note that this is not a Neptune endpoint
- `export_status` - obtains the status of the configured exporter endpoint
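To make the surface area concrete, here is a minimal sketch of a round trip through a few of these methods. The hostname and the queries are placeholders, not part of this change; the response handling follows the usage in the magics below (HTTP-backed calls return `requests.Response` objects, while `gremlin_query` returns the materialized result set):

```python
from graph_notebook.neptune.client import ClientBuilder

# Placeholder endpoint; any reachable Neptune cluster endpoint works here.
builder = ClientBuilder() \
    .with_host('my-cluster.cluster-xxxxxxxx.us-east-1.neptune.amazonaws.com') \
    .with_port(8182) \
    .with_tls(True)
c = builder.build()

# Health check against the status endpoint; returns a requests.Response.
print(c.status().json())

# `sparql` inspects the query text and dispatches to sparql_query or sparql_update.
sparql_res = c.sparql('SELECT * WHERE { ?s ?p ?o } LIMIT 5')
print(sparql_res.json())

# `gremlin_query` opens a websocket connection, submits the query, and closes
# the connection once the results have been obtained.
results = c.gremlin_query('g.V().limit(5)')
print(results)
```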
---
 .github/workflows/integration.yml | 29 +-
 .github/workflows/unit.yml | 2 +-
 pytest.ini | 11 +
 requirements.txt | 3 +-
 setup.py | 9 +-
 .../iam_credentials_provider/__init__.py | 4 -
 .../credentials_factory.py | 24 -
 .../credentials_provider.py | 20 -
 .../ec2_metadata_credentials_provider.py | 26 -
 .../env_credentials_provider.py | 35 --
 .../authentication/iam_headers.py | 217 --------
 .../configuration/generate_config.py | 36 +-
 .../configuration/get_config.py | 3 -
 .../gremlin/client_provider/default_client.py | 21 -
 .../gremlin/client_provider/factory.py | 21 -
 .../gremlin/client_provider/iam_client.py | 52 --
 src/graph_notebook/gremlin/query.py | 57 --
 src/graph_notebook/gremlin/status.py | 44 --
 src/graph_notebook/loader/load.py | 90 ----
 src/graph_notebook/magics/graph_magic.py | 296 ++++++-----
 src/graph_notebook/magics/ml.py | 138 ++---
 src/graph_notebook/ml/sagemaker.py | 86 ---
 .../nbextensions/neptune_menu/static/main.js | 6 +-
 .../{authentication => neptune}/__init__.py | 0
 src/graph_notebook/neptune/client.py | 498 ++++++++++++++++++
 .../{ml => neptune/gremlin}/__init__.py | 0
 .../graphsonV3d0_MapType_objectify_patch.py | 0
 .../00-Sample-Applications-Overview.ipynb | 2 +-
 .../04-Machine-Learning/neptune_ml_utils.py | 98 ++--
 .../call_and_get_response.py | 32 --
 .../default_request_generator.py | 18 -
 .../request_param_generator/factory.py | 23 -
 .../iam_request_generator.py | 21 -
 .../sparql_request_generator.py | 22 -
 src/graph_notebook/sparql/query.py | 82 ---
 src/graph_notebook/sparql/status.py | 49 --
 src/graph_notebook/status/get_status.py | 17 -
 src/graph_notebook/system/database_reset.py | 21 -
 .../sparql_rows_and_columns.py} | 4 +-
 test/integration/DataDrivenGremlinTest.py | 8 +-
 test/integration/DataDrivenSparqlTest.py | 13 +-
 test/integration/IntegrationTest.py | 31 +-
 .../NeptuneIntegrationWorkflowSteps.py | 27 +-
 test/integration/__init__.py | 2 +-
 test/integration/gremlin/__init__.py | 4 -
 .../gremlin/client_provider/__init__.py | 4 -
 .../client_provider_factory.py | 22 -
 .../gremlin/gremlin_query_with_iam.py | 37 --
 .../gremlin/gremlin_query_without_iam.py | 35 --
 .../gremlin/gremlin_status_with_iam.py | 134 -----
 .../gremlin/gremlin_status_without_iam.py | 129 -----
 .../{network/gremlin => iam}/__init__.py | 1 +
 .../integration/iam}/gremlin/__init__.py | 0
 .../gremlin/test_gremlin_status_with_iam.py | 122 +++++
 .../iam/gremlin/test_gremlin_with_iam.py | 55 ++
 .../{system => iam/load}/__init__.py | 0
 .../iam/load/test_load_with_iam.py | 58 ++
 test/integration/iam/ml/__init__.py | 27 +
 .../iam/ml/test_neptune_client_with_iam.py | 25 +
 .../iam/ml/test_neptune_ml_with_iam.py | 179 +++++++
 test/integration/iam/sparql/__init__.py | 0
 .../iam/sparql/test_sparql_query_with_iam.py | 54 ++
 .../iam/sparql/test_sparql_status_with_iam.py | 128 +++++
 test/integration/iam/status/__init__.py | 0
 .../iam/status/test_status_with_iam.py | 29 +
 test/integration/iam/system/__init__.py | 0
 .../iam/system/test_system_with_iam.py | 60 +++
 test/integration/network/__init__.py | 4 -
 test/integration/notebook/__init__.py | 4 -
 test/integration/sparql/__init__.py | 4 -
 .../sparql/sparql_query_with_iam.py | 30 --
 .../sparql/sparql_status_with_iam.py | 134 -----
 .../sparql/sparql_status_without_iam.py | 137 -----
 test/integration/status/__init__.py | 4 -
 test/integration/status/status_with_iam.py | 24 -
 test/integration/system/system_with_iam.py | 25 -
 test/integration/system/system_without_iam.py | 15 -
 test/integration/without_iam/__init__.py | 0
 .../without_iam/gremlin/__init__.py | 0
 .../gremlin/test_gremlin_patching.py} | 37 +-
 .../without_iam/gremlin/test_gremlin_query.py | 37 ++
 .../test_gremlin_status_without_iam.py | 113 ++++
 .../without_iam/network}/__init__.py | 0
 .../without_iam/network/gremlin}/__init__.py | 0
 .../test_gremlin_network_from_queries.py} | 32 +-
 .../test_gremlin_network_with_pattern.py} | 9 +-
 .../notebook/GraphNotebookIntegrationTest.py | 0
 .../without_iam/notebook}/__init__.py | 0
 .../notebook/test_gremlin_graph_notebook.py | 5 +-
 .../notebook/test_sparql_graph_notebook.py | 17 +-
 .../notebook/test_status_graph_notebook.py | 9 +-
 .../without_iam}/sparql/__init__.py | 0
 .../sparql/test_sparql_query_without_iam.py} | 17 +-
 .../sparql/test_sparql_status_without_iam.py | 118 +++++
 .../without_iam}/status/__init__.py | 0
 .../status/test_status_without_iam.py} | 6 +-
 .../without_iam/system/__init__.py | 0
 .../system/test_system_without_iam.py | 18 +
 test/unit/configuration/test_configuration.py | 14 +-
 .../test_configuration_from_main.py | 7 +-
 test/unit/gremlin/__init__.py | 4 -
 test/unit/request_param_generator/__init__.py | 4 -
 .../test_default_request_generator.py | 32 --
 .../test_factory_generator.py | 33 --
 .../test_sparql_request_generator.py | 58 --
 test/unit/sparql/test_sparql.py | 2 +-
 106 files changed, 1971 insertions(+), 2283 deletions(-)
 create mode 100644 pytest.ini
 delete mode 100644 src/graph_notebook/authentication/iam_credentials_provider/__init__.py
 delete mode 100644 src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py
 delete mode 100644 src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py
 delete mode 100644 src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
 delete mode 100644 src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
 delete mode 100644 src/graph_notebook/authentication/iam_headers.py
 delete mode 100644 src/graph_notebook/gremlin/client_provider/default_client.py
 delete mode 100644 src/graph_notebook/gremlin/client_provider/factory.py
 delete mode 100644 src/graph_notebook/gremlin/client_provider/iam_client.py
 delete mode 100644 src/graph_notebook/gremlin/query.py
 delete mode 100644 src/graph_notebook/gremlin/status.py
 delete mode 100644 src/graph_notebook/loader/load.py
 delete mode 100644 src/graph_notebook/ml/sagemaker.py
 rename src/graph_notebook/{authentication => neptune}/__init__.py (100%)
 create mode 100644 src/graph_notebook/neptune/client.py
 rename src/graph_notebook/{ml => neptune/gremlin}/__init__.py (100%)
 rename src/graph_notebook/{gremlin/client_provider => neptune/gremlin}/graphsonV3d0_MapType_objectify_patch.py (100%)
 delete mode 100644 src/graph_notebook/request_param_generator/call_and_get_response.py
 delete mode 100644 src/graph_notebook/request_param_generator/default_request_generator.py
 delete mode 100644 src/graph_notebook/request_param_generator/factory.py
 delete mode 100644 src/graph_notebook/request_param_generator/iam_request_generator.py
 delete mode 100644 src/graph_notebook/request_param_generator/sparql_request_generator.py
 delete mode 100644 src/graph_notebook/sparql/query.py
 delete mode 100644 src/graph_notebook/sparql/status.py
 delete mode 100644 src/graph_notebook/status/get_status.py
 rename src/graph_notebook/{sparql/table.py => visualization/sparql_rows_and_columns.py} (89%)
 delete mode 100644 test/integration/gremlin/__init__.py
 delete mode 100644 test/integration/gremlin/client_provider/__init__.py
 delete mode 100644 test/integration/gremlin/client_provider/client_provider_factory.py
 delete mode 100644 test/integration/gremlin/gremlin_query_with_iam.py
 delete mode 100644 test/integration/gremlin/gremlin_query_without_iam.py
 delete mode 100644 test/integration/gremlin/gremlin_status_with_iam.py
 delete mode 100644 test/integration/gremlin/gremlin_status_without_iam.py
 rename test/integration/{network/gremlin => iam}/__init__.py (99%)
 rename {src/graph_notebook => test/integration/iam}/gremlin/__init__.py (100%)
 create mode 100644 test/integration/iam/gremlin/test_gremlin_status_with_iam.py
 create mode 100644 test/integration/iam/gremlin/test_gremlin_with_iam.py
 rename test/integration/{system => iam/load}/__init__.py (100%)
 create mode 100644 test/integration/iam/load/test_load_with_iam.py
 create mode 100644 test/integration/iam/ml/__init__.py
 create mode 100644 test/integration/iam/ml/test_neptune_client_with_iam.py
 create mode 100644 test/integration/iam/ml/test_neptune_ml_with_iam.py
 create mode 100644 test/integration/iam/sparql/__init__.py
 create mode 100644 test/integration/iam/sparql/test_sparql_query_with_iam.py
 create mode 100644 test/integration/iam/sparql/test_sparql_status_with_iam.py
 create mode 100644 test/integration/iam/status/__init__.py
 create mode 100644 test/integration/iam/status/test_status_with_iam.py
 create mode 100644 test/integration/iam/system/__init__.py
 create mode 100644 test/integration/iam/system/test_system_with_iam.py
 delete mode 100644 test/integration/network/__init__.py
 delete mode 100644 test/integration/notebook/__init__.py
 delete mode 100644 test/integration/sparql/__init__.py
 delete mode 100644 test/integration/sparql/sparql_query_with_iam.py
 delete mode 100644 test/integration/sparql/sparql_status_with_iam.py
 delete mode 100644 test/integration/sparql/sparql_status_without_iam.py
 delete mode 100644 test/integration/status/__init__.py
 delete mode 100644 test/integration/status/status_with_iam.py
 delete mode 100644 test/integration/system/system_with_iam.py
 delete mode 100644 test/integration/system/system_without_iam.py
 create mode 100644 test/integration/without_iam/__init__.py
 create mode 100644 test/integration/without_iam/gremlin/__init__.py
 rename test/integration/{gremlin/bug_fix_tests_without_iam.py => without_iam/gremlin/test_gremlin_patching.py} (71%)
 create mode 100644 test/integration/without_iam/gremlin/test_gremlin_query.py
 create mode 100644 test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py
 rename {src/graph_notebook/gremlin/client_provider => test/integration/without_iam/network}/__init__.py (100%)
 rename {src/graph_notebook/loader => test/integration/without_iam/network/gremlin}/__init__.py (100%)
 rename test/integration/{network/gremlin/gremlin_network_from_queries.py => without_iam/network/gremlin/test_gremlin_network_from_queries.py} (81%)
 rename test/integration/{network/gremlin/gremlin_network_with_pattern.py => without_iam/network/gremlin/test_gremlin_network_with_pattern.py} (83%)
 rename test/integration/{ => without_iam}/notebook/GraphNotebookIntegrationTest.py (100%)
 rename {src/graph_notebook/request_param_generator => test/integration/without_iam/notebook}/__init__.py (100%)
 rename test/integration/{ => without_iam}/notebook/test_gremlin_graph_notebook.py (85%)
 rename test/integration/{ => without_iam}/notebook/test_sparql_graph_notebook.py (64%)
 rename test/integration/{ => without_iam}/notebook/test_status_graph_notebook.py (51%)
 rename {src/graph_notebook => test/integration/without_iam}/sparql/__init__.py (100%)
 rename test/integration/{sparql/sparql_query_without_iam.py => without_iam/sparql/test_sparql_query_without_iam.py} (59%)
 create mode 100644 test/integration/without_iam/sparql/test_sparql_status_without_iam.py
 rename {src/graph_notebook => test/integration/without_iam}/status/__init__.py (100%)
 rename test/integration/{status/status_without_iam.py => without_iam/status/test_status_without_iam.py} (77%)
 create mode 100644 test/integration/without_iam/system/__init__.py
 create mode 100644 test/integration/without_iam/system/test_system_without_iam.py
 delete mode 100644 test/unit/gremlin/__init__.py
 delete mode 100644 test/unit/request_param_generator/__init__.py
 delete mode 100644 test/unit/request_param_generator/test_default_request_generator.py
 delete mode 100644 test/unit/request_param_generator/test_factory_generator.py
 delete mode 100644 test/unit/request_param_generator/test_sparql_request_generator.py

diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index de3fb15e..4827bd4a 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -82,36 +82,15 @@ jobs:
       run: |
         sudo chmod 777 /etc/hosts
     - name: Run Basic Tests
+      env:
+        GRAPH_NOTEBOOK_CONFIG: /tmp/graph_notebook_config_integration_test.json
       run: |
-        python test/integration/NeptuneIntegrationWorkflowSteps.py \
-        run-tests \
-        --pattern "*without_iam.py" \
-        --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-        --aws-region ${{ secrets.AWS_REGION }}
-    - name: Run Networkx Tests
-      run: |
-        python test/integration/NeptuneIntegrationWorkflowSteps.py \
-        run-tests \
-        --pattern "*network*.py" \
-        --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-        --aws-region ${{ secrets.AWS_REGION }}
-    - name: Run Notebook Tests
-      run: |
-        python test/integration/NeptuneIntegrationWorkflowSteps.py \
-        run-tests \
-        --pattern "*graph_notebook.py" \
-        --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-        --aws-region ${{ secrets.AWS_REGION }}
+        pytest test/integration/without_iam
     - name: Run IAM Tests
+      env:
+        GRAPH_NOTEBOOK_CONFIG: /tmp/graph_notebook_config_integration_test.json
       run: |
-        python test/integration/NeptuneIntegrationWorkflowSteps.py \
-        run-tests \
-        --pattern "*with_iam.py" \
-        --iam \
-        --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \
-        --aws-region ${{ secrets.AWS_REGION }}
+        pytest test/integration/iam
     - name: Cleanup
       run: |
         python test/integration/NeptuneIntegrationWorkflowSteps.py \
diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml
index 8f7b5e7b..5f044413 100644
--- a/.github/workflows/unit.yml
+++ b/.github/workflows/unit.yml
@@ -35,4 +35,4 @@ jobs:
         python -m graph_notebook.notebooks.install
     - name: Test with pytest
       run: |
-        pytest
\ No newline at end of file
+        pytest test/unit
\ No newline at end of file
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..4fee5c5f
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,11 @@
+[pytest]
+markers =
+    neptune: tests which have to run against neptune
+    iam: tests which require iam authentication
+    gremlin: tests which run against a gremlin endpoint
+    sparql: tests which run against SPARQL1.1 endpoint
+    neptuneml: tests which run Neptune ML workloads
+    jupyter: tests which run against ipython/jupyter frameworks
+    reset: test which performs a fast reset against Neptune, running this will wipe your database!
+
+
diff --git a/requirements.txt b/requirements.txt
index 04e0b7e6..1852a511 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,8 +9,9 @@ jupyter-contrib-nbextensions
 widgetsnbextension
 gremlinpython
 requests==2.24.0
+ipython==7.16.1
 
 # requirements for testing
 boto3==1.15.15
 botocore==1.18.18
-ipython==7.16.1
\ No newline at end of file
+pytest==6.2.2
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f7de0895..ee43ee23 100644
--- a/setup.py
+++ b/setup.py
@@ -55,8 +55,8 @@ def get_version():
 setup(
     name='graph-notebook',
-    author='amazon-neptune',
-    author_email='amazon-neptune-pypi@amazon.com',
+    author='amazon-neptune',
+    author_email='amazon-neptune-pypi@amazon.com',
     description='jupyter notebook extension to connect to graph databases',
     long_description=long_description,
     long_description_content_type='text/markdown',
@@ -92,5 +92,8 @@ def get_version():
         'Programming Language :: Python :: 3.6',
         'License :: OSI Approved :: Apache Software License'
     ],
-    keywords='jupyter neptune gremlin sparql',
+    keywords='jupyter neptune gremlin sparql',
+    tests_require=[
+        'pytest'
+    ]
 )
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py b/src/graph_notebook/authentication/iam_credentials_provider/__init__.py
deleted file mode 100644
index fa84f3bc..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
\ No newline at end of file
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py
deleted file mode 100644
index c849b595..00000000
--- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_factory.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0 -""" - -from enum import Enum - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase -from graph_notebook.authentication.iam_credentials_provider.env_credentials_provider import EnvCredentialsProvider -from graph_notebook.authentication.iam_credentials_provider.ec2_metadata_credentials_provider import MetadataCredentialsProvider - - -class IAMAuthCredentialsProvider(Enum): - ROLE = "ROLE" - ENV = "ENV" - - -def credentials_provider_factory(mode: IAMAuthCredentialsProvider) -> CredentialsProviderBase: - if mode == IAMAuthCredentialsProvider.ENV: - return EnvCredentialsProvider() - elif mode == IAMAuthCredentialsProvider.ROLE: - return MetadataCredentialsProvider() - else: - raise NotImplementedError(f'the provided mode of {mode} has not been implemented by credentials_provider_factory') diff --git a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py deleted file mode 100644 index 8e6f34cf..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/credentials_provider.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from abc import ABC, abstractmethod - - -class Credentials(object): - def __init__(self, key, secret, region, token=''): - self.key = key - self.secret = secret - self.token = token - self.region = region - - -class CredentialsProviderBase(ABC): - @abstractmethod - def get_iam_credentials(self) -> Credentials: - pass diff --git a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py deleted file mode 100644 index 8dd9bf58..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import botocore.session -import requests - - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \ - Credentials - -region_url = 'http://169.254.169.254/latest/meta-data/placement/availability-zone' - - -class MetadataCredentialsProvider(CredentialsProviderBase): - def __init__(self): - res = requests.get(region_url) - zone = res.content.decode('utf-8') - region = zone[0:len(zone) - 1] - self.region = region - - def get_iam_credentials(self) -> Credentials: - session = botocore.session.get_session() - creds = session.get_credentials() - return Credentials(key=creds.access_key, secret=creds.secret_key, token=creds.token, region=self.region) diff --git a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py deleted file mode 100644 index 0ba39e9d..00000000 --- a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import os - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \ - Credentials - -ACCESS_ENV_KEY = 'AWS_ACCESS_KEY_ID' -SECRET_ENV_KEY = 'AWS_SECRET_ACCESS_KEY' -REGION_ENV_KEY = 'AWS_REGION' -AWS_TOKEN_ENV_KEY = 'AWS_SESSION_TOKEN' - - -class EnvCredentialsProvider(CredentialsProviderBase): - def __init__(self): - self.creds = Credentials(key='', secret='', region='', token='') - self.loaded = False - - def load_iam_credentials(self): - access_key = os.environ.get(ACCESS_ENV_KEY, '') - secret_key = os.environ.get(SECRET_ENV_KEY, '') - region = os.environ.get(REGION_ENV_KEY, '') - token = os.environ.get(AWS_TOKEN_ENV_KEY, '') - self.creds = Credentials(access_key, secret_key, region, token) - self.loaded = True - return - - def get_iam_credentials(self, service=None) -> Credentials: - if not self.loaded: - self.load_iam_credentials() - - return self.creds diff --git a/src/graph_notebook/authentication/iam_headers.py b/src/graph_notebook/authentication/iam_headers.py deleted file mode 100644 index 6c39038a..00000000 --- a/src/graph_notebook/authentication/iam_headers.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import datetime -import hashlib -import hmac -import json -import logging -import urllib - -logging.basicConfig() -logger = logging.getLogger("graph_magic") - - -# Key derivation functions. See: -# https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-python -def sign(key, msg): - return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() - - -def get_signature_key(key, dateStamp, regionName, serviceName): - k_date = sign(('AWS4' + key).encode('utf-8'), dateStamp) - k_region = sign(k_date, regionName) - k_service = sign(k_region, serviceName) - k_signing = sign(k_service, 'aws4_request') - return k_signing - - -def get_canonical_uri_and_payload(query_type, query): - # Set the stack and payload depending on query_type. 
- if query_type == 'sparql': - canonical_uri = '/sparql/' - payload = query - - elif query_type == 'sparqlupdate': - canonical_uri = '/sparql/' - payload = query - - elif query_type == 'sparql/status': - canonical_uri = '/sparql/status/' - payload = query - - elif query_type == 'gremlin': - canonical_uri = '/gremlin' - payload = {} - - elif query_type == 'gremlin/status': - canonical_uri = '/gremlin/status/' - payload = query - - elif query_type == "loader": - canonical_uri = "/loader/" - payload = query - - elif query_type == "status": - canonical_uri = "/status/" - payload = {} - - elif query_type == "gremlin/explain": - canonical_uri = "/gremlin/explain/" - payload = query - - elif query_type == "gremlin/profile": - canonical_uri = "/gremlin/profile/" - payload = query - - elif query_type == "system": - canonical_uri = "/system/" - payload = query - - elif query_type.startswith("ml"): - canonical_uri = f'/{query_type}' - payload = query - - elif query_type.startswith("ml/dataprocessing"): - canonical_uri = f'/{query_type}' - payload = query - - elif query_type.startswith("ml/endpoints"): - canonical_uri = f'/{query_type}' - payload = query - - else: - raise ValueError('query_type %s is not valid' % query_type) - - return canonical_uri, payload - - -def normalize_query_string(query): - kv = (list(map(str.strip, s.split("="))) - for s in query.split('&') - if len(s) > 0) - - normalized = '&'.join('%s=%s' % (p[0], p[1] if len(p) > 1 else '') - for p in sorted(kv)) - return normalized - - -def make_signed_request(method, query_type, query, host, port, signing_access_key, signing_secret, signing_region, - use_ssl=False, signing_token='', additional_headers=None): - if additional_headers is None: - additional_headers = [] - - signing_region = signing_region.lower() - service = 'neptune-db' - - if use_ssl: - protocol = 'https' - else: - protocol = 'http' - - # this is always http right now - endpoint = f'{protocol}://{host}:{port}' - - # get canonical_uri and payload - canonical_uri, payload = get_canonical_uri_and_payload(query_type, query) - - if 'content-type' in additional_headers and additional_headers['content-type'] == 'application/json': - request_parameters = payload if type(payload) is str else json.dumps(payload) - else: - request_parameters = urllib.parse.urlencode(payload, quote_via=urllib.parse.quote) - request_parameters = request_parameters.replace('%27', '%22') - t = datetime.datetime.utcnow() - amz_date = t.strftime('%Y%m%dT%H%M%SZ') - date_stamp = t.strftime('%Y%m%d') # Date w/o time, used in credential scope - - method = method.upper() - if method == 'GET' or method == 'DELETE': - canonical_querystring = normalize_query_string(request_parameters) - elif method == 'POST': - canonical_querystring = '' - else: - raise ValueError('method %s is not valid when creating canonical request' % method) - - # Step 4: Create the canonical headers and signed headers. Header names - # must be trimmed and lowercase, and sorted in code point order from - # low to high. Note that there is a trailing \n. - canonical_headers = f'host:{host}:{port}\nx-amz-date:{amz_date}\n' - - # Step 5: Create the list of signed headers. This lists the headers - # in the canonical_headers list, delimited with ";" and in alpha order. - # Note: The request can include any headers; canonical_headers and - # signed_headers lists those that you want to be included in the - # hash of the request. "Host" and "x-amz-date" are always required. 
- signed_headers = 'host;x-amz-date' - - # Step 6: Create payload hash (hash of the request body content). For GET and DELETE - # requests, the payload is an empty string (""). - if method == 'GET' or method == 'DELETE': - post_payload = '' - elif method == 'POST': - post_payload = request_parameters - else: - raise ValueError('method %s is not supported' % method) - - payload_hash = hashlib.sha256(post_payload.encode('utf-8')).hexdigest() - - # Step 7: Combine elements to create canonical request. - canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + '\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash - - # ************* TASK 2: CREATE THE STRING TO SIGN************* - # Match the algorithm to the hashing algorithm you use, either SHA-1 or - # SHA-256 (recommended) - algorithm = 'AWS4-HMAC-SHA256' - credential_scope = date_stamp + '/' + signing_region + '/' + service + '/' + 'aws4_request' - string_to_sign = algorithm + '\n' + amz_date + '\n' + credential_scope + '\n' + hashlib.sha256( - canonical_request.encode('utf-8')).hexdigest() - - # ************* TASK 3: CALCULATE THE SIGNATURE ************* - # Create the signing key using the function defined above. - signing_key = get_signature_key(signing_secret, date_stamp, signing_region, service) - - # Sign the string_to_sign using the signing_key - signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest() - - # ************* TASK 4: ADD SIGNING INFORMATION TO THE REQUEST ************* - # The signing information can be either in a query string value or in - # a header named Authorization. This code shows how to use a header. - # Create authorization header and add to request headers - authorization_header = algorithm + ' ' + 'Credential=' + signing_access_key + '/' + credential_scope + ', ' + 'SignedHeaders=' + signed_headers + ', ' + 'Signature=' + signature - - # The request can include any headers, but MUST include "host", "x-amz-date", - # and (for this scenario) "Authorization". "host" and "x-amz-date" must - # be included in the canonical_headers and signed_headers, as noted - # earlier. Order here is not significant. - # Python note: The 'host' header is added automatically by the Python 'requests' library. 
- if method == 'GET' or method == 'DELETE': - headers = { - 'x-amz-date': amz_date, - 'Authorization': authorization_header - } - elif method == 'POST': - headers = { - 'content-type': 'application/x-www-form-urlencoded', - 'x-amz-date': amz_date, - 'Authorization': authorization_header, - } - else: - raise ValueError('method %s is not valid while creating request headers' % method) - - if additional_headers is not None: - for key in additional_headers: - headers[key] = additional_headers[key] - - if signing_token != '': - headers['X-Amz-Security-Token'] = signing_token - - # ************* SEND THE REQUEST ************* - request_url = endpoint + canonical_uri - - return { - 'url': request_url, - 'headers': headers, - 'params': request_parameters - } \ No newline at end of file diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index 496455f0..f57aaa11 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -8,10 +8,7 @@ import os from enum import Enum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import SPARQL_ACTION - -DEFAULT_IAM_CREDENTIALS_PROVIDER = IAMAuthCredentialsProvider.ROLE +from graph_notebook.neptune.client import SPARQL_ACTION DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') @@ -38,6 +35,9 @@ def __init__(self, path: str = SPARQL_ACTION, endpoint_prefix: str = ''): print('endpoint_prefix has been deprecated and will be removed in version 2.0.20 or greater.') if path == '': path = f'{endpoint_prefix}/sparql' + elif path == '': + path = SPARQL_ACTION + self.path = path def to_dict(self): @@ -47,13 +47,11 @@ def to_dict(self): class Configuration(object): def __init__(self, host: str, port: int, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, - iam_credentials_provider_type: IAMAuthCredentialsProvider = DEFAULT_IAM_CREDENTIALS_PROVIDER, load_from_s3_arn='', ssl: bool = True, aws_region: str = 'us-east-1', sparql_section: SparqlSection = None): self.host = host self.port = port self.auth_mode = auth_mode - self.iam_credentials_provider_type = iam_credentials_provider_type self.load_from_s3_arn = load_from_s3_arn self.ssl = ssl self.aws_region = aws_region @@ -64,7 +62,6 @@ def to_dict(self) -> dict: 'host': self.host, 'port': self.port, 'auth_mode': self.auth_mode.value, - 'iam_credentials_provider_type': self.iam_credentials_provider_type.value, 'load_from_s3_arn': self.load_from_s3_arn, 'ssl': self.ssl, 'aws_region': self.aws_region, @@ -79,21 +76,15 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): return -def generate_config(host, port, auth_mode, ssl, iam_credentials_provider_type, load_from_s3_arn, aws_region): +def generate_config(host, port, auth_mode, ssl, load_from_s3_arn, aws_region): use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True - - if iam_credentials_provider_type not in [IAMAuthCredentialsProvider.ENV, - IAMAuthCredentialsProvider.ROLE]: - iam_credentials_provider_type = DEFAULT_IAM_CREDENTIALS_PROVIDER - - config = Configuration(host, port, auth_mode, iam_credentials_provider_type, load_from_s3_arn, use_ssl, aws_region) - return config + c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region) + return c def generate_default_config(): - config = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, 
DEFAULT_IAM_CREDENTIALS_PROVIDER, '',
-                             'us-east-1')
-    return config
+    c = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, '', 'us-east-1')
+    return c
 
 
 if __name__ == "__main__":
@@ -102,6 +93,8 @@ def generate_default_config():
     parser.add_argument("--port", help="the port to use when creating a connection", default="8182")
     parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value,
                         help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM")
+
+    # TODO: this can now be removed.
     parser.add_argument("--iam_credentials_provider", default='ROLE',
                         help="The mode used to obtain credentials for IAM Authentication. Can be ROLE or ENV")
     parser.add_argument("--ssl",
@@ -110,14 +103,11 @@
     parser.add_argument("--config_destination", help="location to put generated config",
                         default=DEFAULT_CONFIG_LOCATION)
     parser.add_argument("--load_from_s3_arn", help="arn of role to use for bulk loader", default='')
-    parser.add_argument("--aws_region", help="aws region your neptune cluster is in.", default='us-east-1')
+    parser.add_argument("--aws_region", help="aws region your neptune cluster is in.", default='us-east-1')
     args = parser.parse_args()
 
     auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
-    iam_credentials_provider_arg = args.iam_credentials_provider if args.iam_credentials_provider != '' else IAMAuthCredentialsProvider.ROLE.value
-
-    config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
-                             IAMAuthCredentialsProvider(iam_credentials_provider_arg),
+    config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
                              args.load_from_s3_arn, args.aws_region)
 
     config.write_to_file(args.config_destination)
diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py
index 35698391..72ab829a 100644
--- a/src/graph_notebook/configuration/get_config.py
+++ b/src/graph_notebook/configuration/get_config.py
@@ -5,7 +5,6 @@
 
 import json
 
-from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider
 from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
     SparqlSection
 
@@ -13,8 +12,6 @@ def get_config_from_dict(data: dict) -> Configuration:
     sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
     config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
-                           iam_credentials_provider_type=IAMAuthCredentialsProvider(
-                               data['iam_credentials_provider_type']),
                            ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
                            aws_region=data['aws_region'], sparql_section=sparql_section)
     return config
diff --git a/src/graph_notebook/gremlin/client_provider/default_client.py b/src/graph_notebook/gremlin/client_provider/default_client.py
deleted file mode 100644
index 0d6bd4b8..00000000
--- a/src/graph_notebook/gremlin/client_provider/default_client.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0 -""" - -import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch # noqa F401 -import logging - -from gremlin_python.driver import client - -logging.basicConfig() -logger = logging.getLogger("default_client") - - -class ClientProvider(object): - @staticmethod - def get_client(host, port, use_ssl): - protocol = 'wss' if use_ssl else 'ws' - url = f'{protocol}://{host}:{port}/gremlin' - c = client.Client(url, 'g') - return c diff --git a/src/graph_notebook/gremlin/client_provider/factory.py b/src/graph_notebook/gremlin/client_provider/factory.py deleted file mode 100644 index 0f1a459e..00000000 --- a/src/graph_notebook/gremlin/client_provider/factory.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory, \ - IAMAuthCredentialsProvider - - -def create_client_provider(mode: AuthModeEnum, - credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE): - if mode == AuthModeEnum.DEFAULT: - return ClientProvider() - elif mode == AuthModeEnum.IAM: - credentials_provider = credentials_provider_factory(credentials_provider_mode) - return IamClientProvider(credentials_provider) - else: - raise NotImplementedError(f"invalid client mode {mode} provided") diff --git a/src/graph_notebook/gremlin/client_provider/iam_client.py b/src/graph_notebook/gremlin/client_provider/iam_client.py deleted file mode 100644 index aca981ab..00000000 --- a/src/graph_notebook/gremlin/client_provider/iam_client.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import graph_notebook.gremlin.client_provider.graphsonV3d0_MapType_objectify_patch # noqa F401 -import hashlib -import hmac -import logging - -from gremlin_python.driver import client -from gremlin_python.driver.client import Client -from tornado import httpclient - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase -from graph_notebook.authentication.iam_headers import make_signed_request - -logging.basicConfig() -logger = logging.getLogger("iam_client") - - -def sign(key, msg): - return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() - - -def get_signature_key(key, date_stamp, region_name, service_name): - k_date = sign(('AWS4' + key).encode('utf-8'), date_stamp) - k_region = sign(k_date, region_name) - k_service = sign(k_region, service_name) - k_signing = sign(k_service, 'aws4_request') - return k_signing - - -class IamClientProvider(object): - def __init__(self, credentials_provider: CredentialsProviderBase): - self.credentials_provider = credentials_provider - - def get_client(self, host, port, use_ssl) -> Client: - credentials = self.credentials_provider.get_iam_credentials() - request_params = make_signed_request('get', 'gremlin', '', host, port, credentials.key, - credentials.secret, credentials.region, use_ssl, - credentials.token) - ws_url = request_params['url'].strip('/').replace('http', 'ws') - signed_ws_request = httpclient.HTTPRequest(ws_url, headers=request_params['headers']) - - try: - c = client.Client(signed_ws_request, 'g') - return c - # TODO: handle exception explicitly - except Exception as e: - logger.error(f'error while creating client {e}') - raise e diff --git a/src/graph_notebook/gremlin/query.py b/src/graph_notebook/gremlin/query.py deleted file mode 100644 index 00e60010..00000000 --- a/src/graph_notebook/gremlin/query.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import logging - -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -logging.basicConfig() -logger = logging.getLogger("gremlin") - - -def do_gremlin_query(query_str, host, port, use_ssl, client_provider=ClientProvider()): - c = client_provider.get_client(host, port, use_ssl) - - try: - result = c.submit(query_str) - future_results = result.all() - results = future_results.result() - except Exception as e: - raise e # let the upstream decide what to do with this error. 
- finally: - c.close() # no matter the outcome we need to close the websocket connection - - return results - - -def do_gremlin_explain(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'gremlin': query_str - } - - action = 'gremlin/explain' - res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data) - content = res.content.decode('utf-8') - result = { - 'explain': content - } - return result - - -def do_gremlin_profile(query_str, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'gremlin': query_str - } - - action = 'gremlin/profile' - res = call_and_get_response('get', action, host, port, request_param_generator, use_ssl, data) - content = res.content.decode('utf-8').strip(' ') - result = { - 'profile': content - } - - return result diff --git a/src/graph_notebook/gremlin/status.py b/src/graph_notebook/gremlin/status.py deleted file mode 100644 index 4817206d..00000000 --- a/src/graph_notebook/gremlin/status.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -GREMLIN_STATUS_ACTION = 'gremlin/status' - - -def do_gremlin_status(host, port, use_ssl, mode, request_param_generator, query_id: str, include_waiting: bool): - data = {'includeWaiting': include_waiting} - if query_id != '': - data['queryId'] = query_id - - headers = {} - if mode == AuthModeEnum.DEFAULT: - """Add correct content-type header for the request. - This is needed because call_and_get_response requires custom headers to be set. - """ - headers['Content-Type'] = 'application/x-www-form-urlencoded' - res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - content = res.json() - return content - - -def do_gremlin_cancel(host, port, use_ssl, mode, request_param_generator, query_id): - if type(query_id) != str or query_id == '': - raise ValueError("query id must be a non-empty string") - - data = {'cancelQuery': True, 'queryId': query_id} - - headers = {} - if mode == AuthModeEnum.DEFAULT: - """Add correct content-type header for the request. - This is needed because call_and_get_response requires custom headers to be set. - """ - headers['Content-Type'] = 'application/x-www-form-urlencoded' - res = call_and_get_response('post', GREMLIN_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - content = res.json() - return content diff --git a/src/graph_notebook/loader/load.py b/src/graph_notebook/loader/load.py deleted file mode 100644 index 3eaa8dbe..00000000 --- a/src/graph_notebook/loader/load.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import json -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -FORMAT_CSV = 'csv' -FORMAT_NTRIPLE = 'ntriples' -FORMAT_NQUADS = 'nquads' -FORMAT_RDFXML = 'rdfxml' -FORMAT_TURTLE = 'turtle' - -PARALLELISM_LOW = 'LOW' -PARALLELISM_MEDIUM = 'MEDIUM' -PARALLELISM_HIGH = 'HIGH' -PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' - -MODE_RESUME = 'RESUME' -MODE_NEW = 'NEW' -MODE_AUTO = 'AUTO' - -LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] -VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] -PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] -LOADER_ACTION = 'loader' - -FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', - 'LOAD_COMMITTED_W_WRITE_CONFLICTS', - 'LOAD_CANCELLED_BY_USER', - 'LOAD_CANCELLED_DUE_TO_ERRORS', - 'LOAD_FAILED', - 'LOAD_UNEXPECTED_ERROR', - 'LOAD_DATA_DEADLOCK', - 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', - 'LOAD_S3_READ_ERROR', - 'LOAD_S3_ACCESS_DENIED_ERROR', - 'LOAD_IN_QUEUE', - 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', - 'LOAD_FAILED_INVALID_REQUEST', ] - - -def do_load(host, port, load_format, use_ssl, source, region, arn, fail_on_error, request_param_generator, mode="AUTO", - parallelism="HIGH", update_single_cardinality="FALSE", queue_request="FALSE", dependencies=[]): - payload = { - 'source': source, - 'format': load_format, - 'mode': mode, - 'region': region, - 'failOnError': fail_on_error, - 'parallelism': parallelism, - 'updateSingleCardinalityProperties': update_single_cardinality, - 'queueRequest': queue_request - } - - if arn != '': - payload['iamRoleArn'] = arn - - if dependencies: - payload['dependencies'] = json.dumps(dependencies) - - res = call_and_get_response('post', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.json() - - -def get_loader_jobs(host, port, use_ssl, request_param_generator): - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl) - return res.json() - - -def get_load_status(host, port, use_ssl, request_param_generator, id, loader_details="FALSE", loader_errors="FALSE", loader_page=1, loader_epp=10): - payload = { - 'loadId': id, - 'details': loader_details, - 'errors': loader_errors, - 'page': loader_page, - 'errorsPerPage': loader_epp - } - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.json() - - -def cancel_load(host, port, use_ssl, request_param_generator, load_id): - payload = { - 'loadId': load_id - } - - res = call_and_get_response('get', LOADER_ACTION, host, port, request_param_generator, use_ssl, payload) - return res.status_code == 200 diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 90b30a16..478f0c3e 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -14,6 +14,8 @@ from enum import Enum import ipywidgets as widgets +from SPARQLWrapper import SPARQLWrapper +from botocore.session import get_session from gremlin_python.driver.protocol import GremlinServerError from IPython.core.display import HTML, display_html, display from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope) @@ -21,25 +23,18 @@ from requests import HTTPError import graph_notebook -from graph_notebook.configuration.generate_config import generate_default_config, 
DEFAULT_CONFIG_LOCATION +from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, AuthModeEnum, \ + Configuration from graph_notebook.decorators.decorators import display_exceptions from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser +from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \ + LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION from graph_notebook.network import SPARQLNetwork from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork -from graph_notebook.sparql.table import get_rows_and_columns -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.sparql.query import get_query_type, do_sparql_query, do_sparql_explain, SPARQL_ACTION -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.system.database_reset import perform_database_reset, initiate_database_reset +from graph_notebook.visualization.sparql_rows_and_columns import get_rows_and_columns from graph_notebook.visualization.template_retriever import retrieve_template -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.loader.load import do_load, get_loader_jobs, get_load_status, cancel_load, VALID_FORMATS, \ - PARALLELISM_OPTIONS, PARALLELISM_HIGH, FINAL_LOAD_STATUSES, LOAD_JOB_MODES, MODE_AUTO from graph_notebook.configuration.get_config import get_config, get_config_from_dict from graph_notebook.seed.load_query import get_data_sets, get_queries -from graph_notebook.status.get_status import get_status from graph_notebook.widgets import Force from graph_notebook.options import OPTIONS_DEFAULT_DIRECTED, vis_options_merge @@ -91,6 +86,28 @@ def str_to_query_mode(s: str) -> QueryMode: return QueryMode.DEFAULT +ACTION_TO_QUERY_TYPE = { + 'sparql': 'application/sparql-query', + 'sparqlupdate': 'application/sparql-update' +} + + +def get_query_type(query): + s = SPARQLWrapper('') + s.setQuery(query) + return s.queryType + + +def query_type_to_action(query_type): + query_type = query_type.upper() + if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + return 'sparql' + else: + # TODO: check explicitly for all query types, raise exception for invalid query + return 'sparqlupdate' + + +# TODO: refactor large magic commands into their own modules like what we do with %neptune_ml # noinspection PyTypeChecker @magics_class class Graph(Magics): @@ -100,7 +117,12 @@ def __init__(self, shell): try: self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION) - self.graph_notebook_config = get_config(self.config_location) + config = get_config(self.config_location) + self.graph_notebook_config = config + + self.client: Client = None + self._generate_client_from_config(config) + except FileNotFoundError: self.graph_notebook_config = generate_default_config() print( @@ -109,6 +131,21 @@ def __init__(self, shell): self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED logger.setLevel(logging.ERROR) + def _generate_client_from_config(self, config: Configuration): + if self.client: + self.client.close() + + builder = ClientBuilder() \ + .with_host(config.host) \ + .with_port(config.port) 
\ + .with_region(config.aws_region) \ + .with_tls(config.ssl) \ + .with_sparql_path(config.sparql.path) + if config.auth_mode == AuthModeEnum.IAM: + builder = builder.with_iam(get_session()) + + self.client = builder.build() + @line_cell_magic @display_exceptions def graph_notebook_config(self, line='', cell=''): @@ -116,6 +153,7 @@ def graph_notebook_config(self, line='', cell=''): data = json.loads(cell) config = get_config_from_dict(data) self.graph_notebook_config = config + self._generate_client_from_config(config) print('set notebook config to:') print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) elif line == 'reset': @@ -143,6 +181,7 @@ def graph_notebook_host(self, line): # TODO: we should attempt to make a status call to this host before we set the config to this value. self.graph_notebook_config.host = line + self._generate_client_from_config(self.graph_notebook_config) print(f'set host to {line}') @cell_magic @@ -155,26 +194,24 @@ def sparql(self, line='', cell='', local_ns: dict = None): parser.add_argument('--path', '-p', default='', help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, the endpoint called would be host:port/foo/bar') parser.add_argument('--expand-all', action='store_true') - - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type, - command='sparql') + parser.add_argument('--explain-type', default='dynamic', + help='explain mode to use when using the explain query mode', + choices=['dynamic', 'static', 'details']) + parser.add_argument('--explain-format', default='text/html', help='response format for explain query mode', + choices=['text/csv', 'text/html']) parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') args = parser.parse_args(line.split()) mode = str_to_query_mode(args.query_mode) tab = widgets.Tab() path = args.path if args.path != '' else self.graph_notebook_config.sparql.path - logger.debug(f'using mode={mode}') if mode == QueryMode.EXPLAIN: - res = do_sparql_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, path=path) - store_to_ns(args.store_to, res, local_ns) - if 'error' in res: - html = error_template.render(error=json.dumps(res['error'], indent=2)) - else: - html = sparql_explain_template.render(table=res) + res = self.client.sparql_explain(cell, args.explain_type, args.explain_format, path=path) + res.raise_for_status() + explain = res.content.decode('utf-8') + store_to_ns(args.store_to, explain, local_ns) + html = sparql_explain_template.render(table=explain) explain_output = widgets.Output(layout=DEFAULT_LAYOUT) with explain_output: display(HTML(html)) @@ -186,8 +223,10 @@ def sparql(self, line='', cell='', local_ns: dict = None): query_type = get_query_type(cell) headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else { 'Accept': 'application/sparql-results+json'} - res = do_sparql_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, headers, path=path) + + query_res = self.client.sparql(cell, path=path, headers=headers) + query_res.raise_for_status() + res = query_res.json() store_to_ns(args.store_to, res, local_ns) titles = [] children = [] @@ -207,8 +246,7 @@ def sparql(self, line='', cell='', local_ns: dict = None): titles.append('Table') children.append(hbox) - expand_all = line == 
'--expand-all'
-            sn = SPARQLNetwork(expand_all=expand_all)
+            sn = SPARQLNetwork(expand_all=args.expand_all)
             sn.extract_prefix_declarations_from_query(cell)
             try:
                 sn.add_results(res)
@@ -268,20 +306,18 @@ def sparql_status(self, line='', local_ns: dict = None):
         parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
         args = parser.parse_args(line.split())
 
-        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-
         if not args.cancelQuery:
-            res = do_sparql_status(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                   self.graph_notebook_config.ssl, request_generator, args.queryId)
-
+            status_res = self.client.sparql_status(args.queryId)
+            status_res.raise_for_status()
+            res = status_res.json()
         else:
             if args.queryId == '':
                 print(SPARQL_CANCEL_HINT_MSG)
                 return
             else:
-                res = do_sparql_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                       self.graph_notebook_config.ssl, request_generator, args.queryId, args.silent)
+                cancel_res = self.client.sparql_cancel(args.queryId, args.silent)
+                cancel_res.raise_for_status()
+                res = cancel_res.json()
 
         store_to_ns(args.store_to, res, local_ns)
         print(json.dumps(res, indent=2))
@@ -304,13 +340,11 @@ def gremlin(self, line, cell, local_ns: dict = None):
         tab = widgets.Tab()
         if mode == QueryMode.EXPLAIN:
-            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                         self.graph_notebook_config.iam_credentials_provider_type)
-
-            query_res = do_gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                           self.graph_notebook_config.ssl, request_generator)
-            if 'explain' in query_res:
-                html = pre_container_template.render(content=query_res['explain'])
+            res = self.client.gremlin_explain(cell)
+            res.raise_for_status()
+            query_res = res.content.decode('utf-8')
+            if 'Neptune Gremlin Explain' in query_res:
+                html = pre_container_template.render(content=query_res)
             else:
                 html = pre_container_template.render(content='No explain found')
 
             tab.set_title(0, 'Explain')
             display(tab)
         elif mode == QueryMode.PROFILE:
-            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
-                                                         self.graph_notebook_config.iam_credentials_provider_type)
-
-            query_res = do_gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                           self.graph_notebook_config.ssl, request_generator)
-            if 'profile' in query_res:
-                html = pre_container_template.render(content=query_res['profile'])
+            res = self.client.gremlin_profile(cell)
+            res.raise_for_status()
+            query_res = res.content.decode('utf-8')
+            if 'Neptune Gremlin Profile' in query_res:
+                html = pre_container_template.render(content=query_res)
             else:
                 html = pre_container_template.render(content='No profile found')
             profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
             tab.set_title(0, 'Profile')
             display(tab)
         else:
-            client_provider = create_client_provider(self.graph_notebook_config.auth_mode,
-                                                     self.graph_notebook_config.iam_credentials_provider_type)
-            query_res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
-                                         self.graph_notebook_config.ssl, client_provider)
+            query_res = self.client.gremlin_query(cell)
 
             children = []
             titles = []
@@ -393,24 +422,18 @@ def gremlin_status(self, line='', local_ns: dict =
None): parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') args = parser.parse_args(line.split()) - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) - if not args.cancelQuery: - res = do_gremlin_status(self.graph_notebook_config.host, - self.graph_notebook_config.port, - self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode, - request_generator, args.queryId, args.includeWaiting) - + status_res = self.client.gremlin_status(args.queryId) + status_res.raise_for_status() + res = status_res.json() else: if args.queryId == '': print(GREMLIN_CANCEL_HINT_MSG) return else: - res = do_gremlin_cancel(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, self.graph_notebook_config.auth_mode, - request_generator, args.queryId) - + cancel_res = self.client.gremlin_cancel(args.queryId) + cancel_res.raise_for_status() + res = cancel_res.json() print(json.dumps(res, indent=2)) store_to_ns(args.store_to, res, local_ns) @@ -418,41 +441,38 @@ def gremlin_status(self, line='', local_ns: dict = None): @display_exceptions def status(self, line): logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}') - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) logger.info( f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make status request') - - res = get_status(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) + status_res = self.client.status() + status_res.raise_for_status() + res = status_res.json() logger.info(f'got the response {res}') return res @line_magic @display_exceptions def db_reset(self, line): - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - ssl = self.graph_notebook_config.ssl - - logger.info(f'calling system endpoint {host}') + logger.info(f'calling system endpoint {self.client.host}') parser = argparse.ArgumentParser() parser.add_argument('-g', '--generate-token', action='store_true', help='generate token for database reset') - parser.add_argument('-t', '--token', nargs=1, default='', help='perform database reset with given token') + parser.add_argument('-t', '--token', default='', help='perform database reset with given token') parser.add_argument('-y', '--yes', action='store_true', help='skip the prompt and perform database reset') args = parser.parse_args(line.split()) generate_token = args.generate_token skip_prompt = args.yes - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) logger.info( f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request') if generate_token is False and args.token == '': if skip_prompt: - res = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + res = initiate_res.json() token = res['payload']['token'] - res = perform_database_reset(token, host, port, ssl, request_generator) + + perform_reset_res = self.client.perform_reset(token) + 
perform_reset_res.raise_for_status() logger.info(f'got the response {res}') + res = perform_reset_res.json() return res output = widgets.Output() @@ -473,7 +493,9 @@ def db_reset(self, line): display(text_hbox, check_box, button_hbox, output) def on_button_delete_clicked(b): - result = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + result = initiate_res.json() text_hbox.close() check_box.close() @@ -492,7 +514,9 @@ def on_button_delete_clicked(b): print(result) return - result = perform_database_reset(token, host, port, ssl, request_generator) + perform_reset_res = self.client.perform_reset(token) + perform_reset_res.raise_for_status() + result = perform_reset_res.json() if 'status' not in result or result['status'] != '200 OK': with output: @@ -522,7 +546,9 @@ def on_button_delete_clicked(b): display_html(HTML(loading_wheel_html)) try: retry -= 1 - interval_check_response = get_status(host, port, ssl, request_generator) + status_res = self.client.status() + status_res.raise_for_status() + interval_check_response = status_res.json() except Exception as e: # Exception is expected when database is resetting, continue waiting with job_status_output: @@ -560,10 +586,14 @@ def on_button_cancel_clicked(b): button_cancel.on_click(on_button_cancel_clicked) return elif generate_token: - res = initiate_database_reset(host, port, ssl, request_generator) + initiate_res = self.client.initiate_reset() + initiate_res.raise_for_status() + res = initiate_res.json() else: # args.token is an array of a single string, e.g., args.token=['ade-23-c23'], use index 0 to take the string - res = perform_database_reset(args.token[0], host, port, ssl, request_generator) + perform_res = self.client.perform_reset(args.token) + perform_res.raise_for_status() + res = perform_res.json() logger.info(f'got the response {res}') return res @@ -572,6 +602,7 @@ def on_button_cancel_clicked(b): @needs_local_scope @display_exceptions def load(self, line='', local_ns: dict = None): + # TODO: change widgets to let any arbitrary inputs be added by users parser = argparse.ArgumentParser() parser.add_argument('-s', '--source', default='s3://') parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn) @@ -588,16 +619,7 @@ def load(self, line='', local_ns: dict = None): args = parser.parse_args(line.split()) - # since this can be a long-running task, freezing variables in the case - # that a user alters them in another command. - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - ssl = self.graph_notebook_config.ssl - - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) region = self.graph_notebook_config.aws_region - button = widgets.Button(description="Submit") output = widgets.Output() source = widgets.Text( @@ -708,7 +730,8 @@ def on_button_clicked(b): if not len(dependencies_list) < 64: validated = False - dep_validation_label = widgets.HTML('
A maximum of 64 jobs may be queued at once.') +                dep_validation_label = widgets.HTML( +                    'A maximum of 64 jobs may be queued at once.
') dep_hbox.children += (dep_validation_label,) if not validated: @@ -718,11 +741,19 @@ def on_button_clicked(b): source.value) # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION} logger.info(f'using source_exp: {source_exp}') try: - load_result = do_load(host, port, source_format.value, ssl, str(source_exp), region_box.value, - arn.value, fail_on_error.value, request_generator, mode=mode.value, - parallelism=parallelism.value, update_single_cardinality=update_single_cardinality.value, - queue_request=queue_request.value, dependencies=dependencies_list) - + kwargs = { + 'failOnError': fail_on_error.value, + 'parallelism': parallelism.value, + 'updateSingleCardinalityProperties': update_single_cardinality.value, + 'queueRequest': queue_request.value + } + + if dependencies: + kwargs['dependencies'] = dependencies_list + + load_res = self.client.load(source.value, source_format.value, arn.value, region_box.value, **kwargs) + load_res.raise_for_status() + load_result = load_res.json() store_to_ns(args.store_to, load_result, local_ns) source_hbox.close() @@ -767,8 +798,9 @@ def on_button_clicked(b): with job_status_output: display_html(HTML(loading_wheel_html)) try: - interval_check_response = get_load_status(host, port, ssl, request_generator, - load_result['payload']['loadId']) + load_status_res = self.client.load_status(load_result['payload']['loadId']) + load_status_res.raise_for_status() + interval_check_response = load_status_res.json() except Exception as e: logger.error(e) with job_status_output: @@ -806,10 +838,9 @@ def load_ids(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = get_loader_jobs(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator) + load_status = self.client.load_status() + load_status.raise_for_status() + res = load_status.json() ids = [] if 'payload' in res and 'loadIds' in res['payload']: ids = res['payload']['loadIds'] @@ -834,15 +865,21 @@ def load_status(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') parser.add_argument('--details', action='store_true', default=False) parser.add_argument('--errors', action='store_true', default=False) - parser.add_argument('--page', '-p', default='1', help='The error page number. Only valid when the --errors option is set.') - parser.add_argument('--errorsPerPage', '-e', default='10', help='The number of errors per each page. Only valid when the --errors option is set.') + parser.add_argument('--page', '-p', default='1', + help='The error page number. Only valid when the --errors option is set.') + parser.add_argument('--errorsPerPage', '-e', default='10', + help='The number of errors per each page. 
Only valid when the --errors option is set.') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = get_load_status(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.load_id, args.details, args.errors, - args.page, args.errorsPerPage) + payload = { + 'details': args.details, + 'errors': args.errors, + 'page': args.page, + 'errorsPerPage': args.errorsPerPage + } + load_status_res = self.client.load_status(args.load_id, **payload) + load_status_res.raise_for_status() + res = load_status_res.json() print(json.dumps(res, indent=2)) if args.store_to != '' and local_ns is not None: @@ -857,10 +894,9 @@ def cancel_load(self, line, local_ns: dict = None): parser.add_argument('--store-to', type=str, default='') args = parser.parse_args(line.split()) - credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode) - res = cancel_load(self.graph_notebook_config.host, self.graph_notebook_config.port, - self.graph_notebook_config.ssl, request_generator, args.load_id) + cancel_res = self.client.cancel_load(args.load_id) + cancel_res.raise_for_status() + res = cancel_res.json() if res: print('Cancelled successfully.') else: @@ -875,7 +911,7 @@ def seed(self, line): parser = argparse.ArgumentParser() parser.add_argument('--language', type=str, default='', choices=SEED_LANGUAGE_OPTIONS) parser.add_argument('--dataset', type=str, default='') - # TODO: Gremlin paths are not yet supported. + # TODO: Gremlin api paths are not yet supported. parser.add_argument('--path', '-p', default=SPARQL_ACTION, help='prefix path to query endpoint. For example, "foo/bar". The queried path would then be host:port/foo/bar for sparql seed commands') parser.add_argument('--run', action='store_true') @@ -938,21 +974,11 @@ def on_button_clicked(b=None): for q in queries: with output: print(f'{progress.value}/{len(queries)}:\t{q["name"]}') - # Just like with the load command, seed is long-running - # as such, we want to obtain the values of host, port, etc. in case they - # change during execution. - host = self.graph_notebook_config.host - port = self.graph_notebook_config.port - auth_mode = self.graph_notebook_config.auth_mode - ssl = self.graph_notebook_config.ssl - if language == 'gremlin': - client_provider = create_client_provider(auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) # IMPORTANT: We treat each line as its own query! for line in q['content'].splitlines(): try: - do_gremlin_query(line, host, port, ssl, client_provider) + self.client.gremlin_query(line) except GremlinServerError as gremlinEx: try: error = json.loads(gremlinEx.args[0][5:]) # remove the leading error code. 
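As an aside, the per-line seeding above relies on `gremlin_query` raising `GremlinServerError` on failure. A minimal sketch of that error-handling pattern outside of the magics — the endpoint and query here are placeholder assumptions, and the `[5:]` slice strips the leading error-code prefix noted in the comment above:

```python
import json

from gremlin_python.driver.protocol import GremlinServerError

from graph_notebook.neptune.client import Client

c = Client(host='your-neptune-endpoint')  # placeholder host
try:
    results = c.gremlin_query("g.addV('person').property('name', 'dan')")
    print(results)
except GremlinServerError as gremlin_ex:
    # The server message begins with an error code (e.g. "597: {...}"),
    # so drop the first five characters before parsing the JSON body.
    error = json.loads(gremlin_ex.args[0][5:])
    print(error.get('detailedMessage', error))
```

The SPARQL branch below follows the same raise-and-parse shape, except it goes through `requests` and raises `HTTPError` instead.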
@@ -975,10 +1001,8 @@ def on_button_clicked(b=None): progress.close() return else: - request_generator = create_request_generator(auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) try: - do_sparql_query(q['content'], host, port, ssl, request_generator, path=args.path) + self.client.sparql(q['content'], path=args.path) except HTTPError as httpEx: # attempt to turn response into json try: @@ -1053,11 +1077,9 @@ def neptune_ml(self, line, cell='', local_ns: dict = None): parser = generate_neptune_ml_parser() args = parser.parse_args(line.split()) logger.info(f'received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}') - request_generator = create_request_generator(self.graph_notebook_config.auth_mode, - self.graph_notebook_config.iam_credentials_provider_type) main_output = widgets.Output() display(main_output) - res = neptune_ml_magic_handler(args, request_generator, self.graph_notebook_config, main_output, cell, local_ns) + res = neptune_ml_magic_handler(args, self.client, main_output, cell, local_ns) message = json.dumps(res, indent=2) if type(res) is dict else res store_to_ns(args.store_to, res, local_ns) with main_output: diff --git a/src/graph_notebook/magics/ml.py b/src/graph_notebook/magics/ml.py index b40db040..91e42474 100644 --- a/src/graph_notebook/magics/ml.py +++ b/src/graph_notebook/magics/ml.py @@ -6,12 +6,8 @@ from IPython.core.display import display from ipywidgets import widgets -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.magics.parsing import str_to_namespace_var -from graph_notebook.ml.sagemaker import start_export, get_export_status, start_processing_job, get_processing_status, \ - start_training, get_training_status, start_create_endpoint, get_endpoint_status, EXPORT_SERVICE_NAME +from graph_notebook.neptune.client import Client logger = logging.getLogger("neptune_ml_magic_handler") @@ -146,17 +142,26 @@ def generate_neptune_ml_parser(): return parser -def neptune_ml_export_start(params, export_url: str, export_ssl: bool = True, creds: Credentials = None): +def neptune_ml_export_start(client: Client, params, export_url: str, export_ssl: bool = True): if type(params) is str: params = json.loads(params) - job = start_export(export_url, params, export_ssl, creds) + export_res = client.export(export_url, params, export_ssl) + export_res.raise_for_status() + job = export_res.json() return job -def wait_for_export(export_url: str, job_id: str, output: widgets.Output, +def neptune_ml_export_status(client: Client, export_url: str, job_id: str, export_ssl: bool = True): + res = client.export_status(export_url, job_id, export_ssl) + res.raise_for_status() + job = res.json() + return job + + +def wait_for_export(client: Client, export_url: str, job_id: str, output: widgets.Output, export_ssl: bool = True, wait_interval: int = DEFAULT_WAIT_INTERVAL, - wait_timeout: int = DEFAULT_WAIT_TIMEOUT, creds: Credentials = None): + wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_widget_output = widgets.Output() with output: @@ -170,7 +175,9 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): 
update_widget_output.clear_output() print('Checking for latest status...') - export_status = get_export_status(export_url, export_ssl, job_id, creds) + status_res = client.export_status(export_url, job_id, export_ssl) + status_res.raise_for_status() + export_status = status_res.json() if export_status['status'] in ['succeeded', 'failed']: print('Export is finished') return export_status @@ -180,32 +187,30 @@ def wait_for_export(export_url: str, job_id: str, output: widgets.Output, time.sleep(wait_interval) -def neptune_ml_export(args: argparse.Namespace, config: Configuration, output: widgets.Output, cell: str): - auth_mode = AuthModeEnum.IAM if args.export_iam else AuthModeEnum.DEFAULT - creds = None - if auth_mode == AuthModeEnum.IAM: - creds = credentials_provider_factory(config.iam_credentials_provider_type).get_iam_credentials() - +def neptune_ml_export(args: argparse.Namespace, client: Client, output: widgets.Output, + cell: str): export_ssl = not args.export_no_ssl if args.which_sub == 'start': if cell == '': return 'Cell body must have json payload or reference notebook variable using syntax ${payload_var}' - export_job = neptune_ml_export_start(cell, args.export_url, export_ssl, creds) + export_job = neptune_ml_export_start(client, cell, args.export_url, export_ssl) if args.wait: - return wait_for_export(args.export_url, export_job['jobId'], - output, export_ssl, args.wait_interval, args.wait_timeout, creds) + return wait_for_export(client, args.export_url, export_job['jobId'], + output, export_ssl, args.wait_interval, args.wait_timeout) else: return export_job elif args.which_sub == 'status': if args.wait: - status = wait_for_export(args.export_url, args.job_id, output, export_ssl, args.wait_interval, - args.wait_timeout, creds) + status = wait_for_export(client, args.export_url, args.job_id, output, export_ssl, args.wait_interval, + args.wait_timeout) else: - status = get_export_status(args.export_url, export_ssl, args.job_id, creds) + status_res = client.export_status(args.export_url, args.job_id, export_ssl) + status_res.raise_for_status() + status = status_res.json() return status -def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_dataprocessing(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -219,7 +224,9 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + status_res = client.dataprocessing_job_status(job_id) + status_res.raise_for_status() + status = status_res.json() if status['status'] in ['Completed', 'Failed']: print('Data processing is finished') return status @@ -229,37 +236,34 @@ def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_ge time.sleep(wait_interval) -def neptune_ml_dataprocessing(args: argparse.Namespace, request_param_generator, output: widgets.Output, - config: Configuration, params: dict = None): +def neptune_ml_dataprocessing(args: argparse.Namespace, client, output: widgets.Output, params: dict = None): if args.which_sub == 'start': if params is None or params == '' or 
params == {}: params = { - 'inputDataS3Location': args.s3_input_uri, - 'processedDataS3Location': args.s3_processed_uri, 'id': args.job_id, 'configFileName': args.config_file_name } - processing_job = start_processing_job(config.host, str(config.port), config.ssl, - request_param_generator, params) + processing_job_res = client.dataprocessing_start(args.s3_input_uri, args.s3_processed_uri, **params) + processing_job_res.raise_for_status() + processing_job = processing_job_res.json() job_id = params['id'] if args.wait: - return wait_for_dataprocessing(job_id, config, request_param_generator, - output, args.wait_interval, args.wait_timeout) + return wait_for_dataprocessing(job_id, client, output, args.wait_interval, args.wait_timeout) else: return processing_job elif args.which_sub == 'status': if args.wait: - return wait_for_dataprocessing(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_dataprocessing(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + processing_status = client.dataprocessing_job_status(args.job_id) + processing_status.raise_for_status() + return processing_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_training(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_training(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -273,7 +277,9 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_training_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + training_status_res = client.modeltraining_job_status(job_id) + training_status_res.raise_for_status() + status = training_status_res.json() if status['status'] in ['Completed', 'Failed']: print('Training is finished') return status @@ -283,35 +289,34 @@ def wait_for_training(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_training(args: argparse.Namespace, request_param_generator, config: Configuration, - output: widgets.Output, params): +def neptune_ml_training(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'start': if params is None or params == '' or params == {}: params = { "id": args.job_id, "dataProcessingJobId": args.data_processing_id, "trainingInstanceType": args.instance_type, - "trainModelS3Location": args.s3_output_uri } - training_job = start_training(config.host, str(config.port), config.ssl, request_param_generator, params) + start_training_res = client.modeltraining_start(args.job_id, args.s3_output_uri, **params) + start_training_res.raise_for_status() + training_job = start_training_res.json() if args.wait: - return wait_for_training(training_job['id'], config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(training_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return training_job elif args.which_sub == 'status': if 
args.wait: - return wait_for_training(args.job_id, config, request_param_generator, output, args.wait_interval, - args.wait_timeout) + return wait_for_training(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_training_status(config.host, str(config.port), config.ssl, request_param_generator, - args.job_id) + training_status_res = client.modeltraining_job_status(args.job_id) + training_status_res.raise_for_status() + return training_status_res.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def wait_for_endpoint(job_id: str, config: Configuration, request_param_generator, output: widgets.Output, +def wait_for_endpoint(job_id: str, client: Client, output: widgets.Output, wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT): job_id_output = widgets.Output() update_status_output = widgets.Output() @@ -325,7 +330,9 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato beginning_time = datetime.datetime.utcnow() while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)): update_status_output.clear_output() - status = get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, job_id) + endpoint_status_res = client.endpoints_status(job_id) + endpoint_status_res.raise_for_status() + status = endpoint_status_res.json() if status['status'] in ['InService', 'Failed']: print('Endpoint creation is finished') return status @@ -335,47 +342,44 @@ def wait_for_endpoint(job_id: str, config: Configuration, request_param_generato time.sleep(wait_interval) -def neptune_ml_endpoint(args: argparse.Namespace, request_param_generator, - config: Configuration, output: widgets.Output, params): +def neptune_ml_endpoint(args: argparse.Namespace, client: Client, output: widgets.Output, params): if args.which_sub == 'create': if params is None or params == '' or params == {}: params = { "id": args.job_id, - "mlModelTrainingJobId": args.model_job_id, 'instanceType': args.instance_type } - create_endpoint_job = start_create_endpoint(config.host, str(config.port), config.ssl, - request_param_generator, params) - + create_endpoint_res = client.endpoints_create(args.model_job_id, **params) + create_endpoint_res.raise_for_status() + create_endpoint_job = create_endpoint_res.json() if args.wait: - return wait_for_endpoint(create_endpoint_job['id'], config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(create_endpoint_job['id'], client, output, args.wait_interval, args.wait_timeout) else: return create_endpoint_job elif args.which_sub == 'status': if args.wait: - return wait_for_endpoint(args.job_id, config, request_param_generator, output, - args.wait_interval, args.wait_timeout) + return wait_for_endpoint(args.job_id, client, output, args.wait_interval, args.wait_timeout) else: - return get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, args.job_id) + endpoint_status = client.endpoints_status(args.job_id) + endpoint_status.raise_for_status() + return endpoint_status.json() else: return f'Sub parser "{args.which} {args.which_sub}" was not recognized' -def neptune_ml_magic_handler(args, request_param_generator, config: Configuration, output: widgets.Output, - cell: str = '', local_ns: dict = None) -> any: +def neptune_ml_magic_handler(args, client: Client, output: widgets.Output, cell: str = '', local_ns: dict = None): if local_ns is 
None: local_ns = {} cell = str_to_namespace_var(cell, local_ns) if args.which == 'export': - return neptune_ml_export(args, config, output, cell) + return neptune_ml_export(args, client, output, cell) elif args.which == 'dataprocessing': - return neptune_ml_dataprocessing(args, request_param_generator, output, config, cell) + return neptune_ml_dataprocessing(args, client, output, cell) elif args.which == 'training': - return neptune_ml_training(args, request_param_generator, config, output, cell) + return neptune_ml_training(args, client, output, cell) elif args.which == 'endpoint': - return neptune_ml_endpoint(args, request_param_generator, config, output, cell) + return neptune_ml_endpoint(args, client, output, cell) else: return f'sub parser {args.which} was not recognized' diff --git a/src/graph_notebook/ml/sagemaker.py b/src/graph_notebook/ml/sagemaker.py deleted file mode 100644 index 71c4a59b..00000000 --- a/src/graph_notebook/ml/sagemaker.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json -import requests -from requests_aws4auth import AWS4Auth - -from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -EXPORT_SERVICE_NAME = 'execute-api' -EXPORT_ACTION = 'neptune-export' -EXTRA_HEADERS = {'content-type': 'application/json'} -UPDATE_DELAY_SECONDS = 60 - - -def start_export(export_host: str, export_params: dict, use_ssl: bool, - creds: Credentials = None) -> dict: - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}' - res = requests.post(url, json=export_params, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_export_status(export_host: str, use_ssl: bool, job_id: str, creds: Credentials = None): - auth = None - if creds is not None: - auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME, - session_token=creds.token) - - protocol = 'https' if use_ssl else 'http' - url = f'{protocol}://{export_host}/{EXPORT_ACTION}/{job_id}' - res = requests.get(url, headers=EXTRA_HEADERS, auth=auth) - res.raise_for_status() - job = res.json() - return job - - -def get_processing_status(host: str, port: str, use_ssl: bool, request_param_generator, job_name: str): - res = call_and_get_response('get', f'ml/dataprocessing/{job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - status = res.json() - return status - - -def start_processing_job(host: str, port: str, use_ssl: bool, request_param_generator, params: dict): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/dataprocessing', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - job = res.json() - return job - - -def start_training(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/modeltraining', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_training_status(host: str, port: str, use_ssl: bool, request_param_generator, 
training_job_name: str): - res = call_and_get_response('get', f'ml/modeltraining/{training_job_name}', host, port, - request_param_generator, use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() - - -def start_create_endpoint(host: str, port: str, use_ssl: bool, request_param_generator, params): - params_raw = json.dumps(params) if type(params) is dict else params - res = call_and_get_response('post', 'ml/endpoints', host, port, request_param_generator, use_ssl, params_raw, - EXTRA_HEADERS) - return res.json() - - -def get_endpoint_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str): - res = call_and_get_response('get', f'ml/endpoints/{training_job_name}', host, port, request_param_generator, - use_ssl, extra_headers=EXTRA_HEADERS) - return res.json() diff --git a/src/graph_notebook/nbextensions/neptune_menu/static/main.js b/src/graph_notebook/nbextensions/neptune_menu/static/main.js index 52ba48fe..7b08ee84 100644 --- a/src/graph_notebook/nbextensions/neptune_menu/static/main.js +++ b/src/graph_notebook/nbextensions/neptune_menu/static/main.js @@ -105,7 +105,7 @@ define([ let host='', port='', iam='off', ssl='on'; function configCallback(data){ - console.log('neptune menu callback...'); + console.log('ml menu callback...'); const raw = data['content']['text']; const config = JSON.parse(raw); host = config['host']; @@ -127,10 +127,10 @@ define([ const createMenu = (kernel) => { console.log("kernel type is ", kernel.name); if(kernel.name === "gremlin_kernel" || kernel.name === "sparql_kernel"){ - console.log("skipping neptune menu creation"); + console.log("skipping ml menu creation"); return; } - console.log('creating neptune menu from config...'); + console.log('creating ml menu from config...'); kernel.execute('%load_ext graph_notebook.magics'); kernel.execute( "%graph_notebook_config silent", diff --git a/src/graph_notebook/authentication/__init__.py b/src/graph_notebook/neptune/__init__.py similarity index 100% rename from src/graph_notebook/authentication/__init__.py rename to src/graph_notebook/neptune/__init__.py diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py new file mode 100644 index 00000000..d102476c --- /dev/null +++ b/src/graph_notebook/neptune/client.py @@ -0,0 +1,498 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" + +import json + +import botocore +import requests +from SPARQLWrapper import SPARQLWrapper +from boto3 import Session +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from gremlin_python.driver import client +from tornado import httpclient + +import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 + +DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded' +DEFAULT_PORT = 8182 +DEFAULT_REGION = 'us-east-1' + +NEPTUNE_SERVICE_NAME = 'neptune-db' + +# TODO: Constants for states of each long-running job +# TODO: add doc links to each command + +FORMAT_CSV = 'csv' +FORMAT_NTRIPLE = 'ntriples' +FORMAT_NQUADS = 'nquads' +FORMAT_RDFXML = 'rdfxml' +FORMAT_TURTLE = 'turtle' + +PARALLELISM_LOW = 'LOW' +PARALLELISM_MEDIUM = 'MEDIUM' +PARALLELISM_HIGH = 'HIGH' +PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' + +MODE_RESUME = 'RESUME' +MODE_NEW = 'NEW' +MODE_AUTO = 'AUTO' + +LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] +VALID_FORMATS = [FORMAT_CSV, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] +PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] +LOADER_ACTION = 'loader' + +FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', + 'LOAD_COMMITTED_W_WRITE_CONFLICTS', + 'LOAD_CANCELLED_BY_USER', + 'LOAD_CANCELLED_DUE_TO_ERRORS', + 'LOAD_FAILED', + 'LOAD_UNEXPECTED_ERROR', + 'LOAD_DATA_DEADLOCK', + 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', + 'LOAD_S3_READ_ERROR', + 'LOAD_S3_ACCESS_DENIED_ERROR', + 'LOAD_IN_QUEUE', + 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', + 'LOAD_FAILED_INVALID_REQUEST', ] + +EXPORT_SERVICE_NAME = 'execute-api' +EXPORT_ACTION = 'ml-export' +EXTRA_HEADERS = {'content-type': 'application/json'} +SPARQL_ACTION = 'sparql' + + +class Client(object): + def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION, + sparql_path: str = '/sparql', auth=None, session: Session = None): + self.host = host + self.port = port + self.ssl = ssl + self.sparql_path = sparql_path + self.region = region + self._auth = auth + self._session = session + + self._http_protocol = 'https' if self.ssl else 'http' + self._ws_protocol = 'wss' if self.ssl else 'ws' + + self._http_session = None + + def sparql_query(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'query': query} + return self.do_sparql_request(data, headers, explain, path=path) + + def sparql_update(self, update: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + data = {'update': update} + return self.do_sparql_request(data, headers, explain, path=path) + + def do_sparql_request(self, data: dict, headers=None, explain: str = '', path: str = ''): + if 'content-type' not in headers: + headers['content-type'] = DEFAULT_SPARQL_CONTENT_TYPE + + explain = explain.lower() + if explain != '': + if explain not in ['static', 'dynamic', 'details']: + raise ValueError('explain mode not valid, must be one of "static", "dynamic", or "details"') + else: + data['explain'] = explain + + sparql_path = path if path != '' else self.sparql_path + uri = f'{self._http_protocol}://{self.host}:{self.port}/{sparql_path}' + req = self._prepare_request('POST', uri, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def sparql(self, query: str, headers=None, explain: 
str = '', path: str = '') -> requests.Response: + if headers is None: + headers = {} + + s = SPARQLWrapper('') + s.setQuery(query) + query_type = s.queryType.upper() + if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + return self.sparql_query(query, headers, explain, path=path) + else: + return self.sparql_update(query, headers, explain, path=path) + + # TODO: enum/constants for supported types + def sparql_explain(self, query: str, explain: str = 'dynamic', output_format: str = 'text/html', + headers=None, path: str = '') -> requests.Response: + if headers is None: + headers = {} + + if 'Accept' not in headers: + headers['Accept'] = output_format + + return self.sparql(query, headers, explain, path=path) + + def sparql_status(self, query_id: str = ''): + return self._query_status('sparql', query_id=query_id) + + def sparql_cancel(self, query_id: str, silent: bool = False): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('sparql', query_id=query_id, silent=silent, cancelQuery=True) + + def get_gremlin_connection(self) -> client.Client: + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + request = self._prepare_request('GET', uri) + + ws_url = f'{self._ws_protocol}://{self.host}:{self.port}/gremlin' + ws_request = httpclient.HTTPRequest(ws_url, headers=dict(request.headers)) + return client.Client(ws_request, 'g') + + def gremlin_query(self, query, bindings=None): + c = self.get_gremlin_connection() + try: + result = c.submit(query, bindings) + future_results = result.all() + results = future_results.result() + c.close() + return results + except Exception as e: + c.close() + raise e + + def gremlin_http_query(self, query, headers=None) -> requests.Response: + if headers is None: + headers = {} + + uri = f'{self._http_protocol}://{self.host}:{self.port}/gremlin' + data = {'gremlin': query} + req = self._prepare_request('POST', uri, data=json.dumps(data), headers=headers) + res = self._http_session.send(req) + return res + + def gremlin_status(self, query_id: str = '', include_waiting: bool = False): + kwargs = {} + if include_waiting: + kwargs['includeWaiting'] = True + return self._query_status('gremlin', query_id=query_id, **kwargs) + + def gremlin_cancel(self, query_id: str): + if type(query_id) is not str or query_id == '': + raise ValueError('query_id must be a non-empty string') + return self._query_status('gremlin', query_id=query_id, cancelQuery=True) + + def gremlin_explain(self, query: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/explain' + data = {'gremlin': query} + req = self._prepare_request('POST', url, data=json.dumps(data)) + res = self._http_session.send(req) + return res + + def gremlin_profile(self, query: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/profile' + data = {'gremlin': query} + req = self._prepare_request('POST', url, data=json.dumps(data)) + res = self._http_session.send(req) + return res + + def status(self) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/status' + req = self._prepare_request('GET', url, data='') + res = self._http_session.send(req) + return res + + def load(self, source: str, source_format: str, iam_role_arn: str, region: str, **kwargs) -> requests.Response: + """ + For a full list of allow parameters, see aws documentation on the Neptune loader + endpoint: 
https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html + """ + payload = { + 'source': source, + 'format': source_format, + 'region': self.region, + 'iamRoleArn': iam_role_arn + } + + for key, value in kwargs.items(): + payload[key] = value + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + raw = json.dumps(payload) + req = self._prepare_request('POST', url, data=raw, headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def load_status(self, load_id: str = '', **kwargs) -> requests.Response: + params = {} + for k, v in kwargs.items(): + params[k] = v + + if load_id != '': + params['loadId'] = load_id + + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + req = self._prepare_request('GET', url, params=params) + res = self._http_session.send(req) + return res + + def cancel_load(self, load_id: str) -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/loader' + params = {'loadId': load_id} + req = self._prepare_request('DELETE', url, params=params) + res = self._http_session.send(req) + return res + + def initiate_reset(self) -> requests.Response: + data = { + 'action': 'initiateDatabaseReset' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def perform_reset(self, token: str) -> requests.Response: + data = { + 'action': 'performDatabaseReset', + 'token': token + } + url = f'{self._http_protocol}://{self.host}:{self.port}/system' + req = self._prepare_request('POST', url, data=data) + res = self._http_session.send(req) + return res + + def dataprocessing_start(self, s3_input_uri: str, s3_output_uri: str, **kwargs) -> requests.Response: + data = { + 'inputDataS3Location': s3_input_uri, + 'processedDataS3Location': s3_output_uri, + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def dataprocessing_job_status(self, job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = {} + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def dataprocessing_stop(self, job_id: str, clean=False, neptune_iam_role_arn: str = '') -> requests.Response: + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + data = { + 'clean': clean + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_start(self, data_processing_job_id: str, train_model_s3_location: str, + **kwargs) -> 
requests.Response: + """ + for a full list of supported parameters, see: + https://docs.aws.amazon.com/neptune/latest/userguide/machine-learning-api-modeltraining.html + """ + data = { + 'dataProcessingJobId': data_processing_job_id, + 'trainModelS3Location': train_model_s3_location + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def modeltraining_status(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: + data = { + 'maxItems': max_items + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_job_status(self, training_job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def modeltraining_stop(self, training_job_id: str, neptune_iam_role_arn: str = '', + clean: bool = False) -> requests.Response: + data = { + 'clean': "TRUE" if clean else "FALSE", + } + + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_create(self, training_job_id: str, **kwargs) -> requests.Response: + data = { + 'mlModelTrainingJobId': training_job_id + } + + for k, v in kwargs.items(): + data[k] = v + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + res = self._http_session.send(req) + return res + + def endpoints_status(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints_delete(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: + data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' + req = self._prepare_request('DELETE', url, params=data) + res = self._http_session.send(req) + return res + + def endpoints(self, max_items: int = 10, neptune_iam_role_arn: str = ''): + data = { + 'maxItems': max_items + } + if neptune_iam_role_arn != '': + data['neptuneIamRoleArn'] = neptune_iam_role_arn + + url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' + req = self._prepare_request('GET', url, params=data) + res = self._http_session.send(req) + return res + + def export(self, host: str, params: dict, ssl: bool = True) -> requests.Response: + protocol = 
'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}' + req = self._prepare_request('POST', url, data=json.dumps(params), service="execute-api") + res = self._http_session.send(req) + return res + + def export_status(self, host, job_id, ssl: bool = True) -> requests.Response: + protocol = 'https' if ssl else 'http' + url = f'{protocol}://{host}/{EXPORT_ACTION}/{job_id}' + req = self._prepare_request('GET', url, service="execute-api") + res = self._http_session.send(req) + return res + + def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response: + data = {} + if query_id != '': + data['queryId'] = query_id + + for k, v in kwargs.items(): + data[k] = v + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/status' + req = self._prepare_request('POST', url, data=data, headers=headers) + res = self._http_session.send(req) + return res + + def _prepare_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME): + self._ensure_http_session() + request = requests.Request(method=method, url=url, data=data, params=params, headers=headers, auth=self._auth) + if self._session is not None: + credentials = self._session.get_credentials() + frozen_creds = credentials.get_frozen_credentials() + + req = AWSRequest(method=method, url=url, data=data, params=params, headers=headers) + SigV4Auth(frozen_creds, service, self.region).add_auth(req) + prepared_iam_req = req.prepare() + request.headers = dict(prepared_iam_req.headers) + + return request.prepare() + + def _ensure_http_session(self): + if not self._http_session: + self._http_session = requests.Session() + + def set_session(self, session: Session): + self._session = session + + def close(self): + if self._http_session: + self._http_session.close() + self._http_session = None + + @property + def iam_enabled(self): + return type(self._session) is botocore.session.Session + + +class ClientBuilder(object): + def __init__(self, args: dict = None): + if args is None: + args = {} + self.args = args + + def with_host(self, host: str): + self.args['host'] = host + return ClientBuilder(self.args) + + def with_port(self, port: int): + self.args['port'] = port + return ClientBuilder(self.args) + + def with_sparql_path(self, path: str): + self.args['sparql_path'] = path + return ClientBuilder(self.args) + + def with_tls(self, tls: bool): + self.args['ssl'] = tls + return ClientBuilder(self.args) + + def with_region(self, region: str): + self.args['region'] = region + return ClientBuilder(self.args) + + def with_iam(self, session: Session): + self.args['session'] = session + return ClientBuilder(self.args) + + def build(self) -> Client: + return Client(**self.args) diff --git a/src/graph_notebook/ml/__init__.py b/src/graph_notebook/neptune/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/ml/__init__.py rename to src/graph_notebook/neptune/gremlin/__init__.py diff --git a/src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/graphsonV3d0_MapType_objectify_patch.py rename to src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py diff --git a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb 
b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb index efbd443e..208830e1 100644 --- a/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb +++ b/src/graph_notebook/notebooks/03-Sample-Applications/00-Sample-Applications-Overview.ipynb @@ -56,4 +56,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py index 7517ab6d..c9b30a3f 100644 --- a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py +++ b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py @@ -15,6 +15,8 @@ # How often to check the status UPDATE_DELAY_SECONDS = 15 +HOME_DIRECTORY = os.path.expanduser("~") + def signed_request(method, url, data=None, params=None, headers=None, service=None): creds = boto3.Session().get_credentials().get_frozen_credentials() @@ -24,7 +26,7 @@ def signed_request(method, url, data=None, params=None, headers=None, service=No def load_configuration(): - with open('/home/ec2-user/graph_notebook_config.json') as f: + with open(f'{HOME_DIRECTORY}/graph_notebook_config.json') as f: data = json.load(f) host = data['host'] port = data['port'] @@ -34,62 +36,69 @@ def load_configuration(): iam = False return host, port, iam + def get_host(): host, port, iam = load_configuration() return host + def get_iam(): host, port, iam = load_configuration() return iam + def get_training_job_name(prefix: str): return f'{prefix}-{int(time.time())}' + def check_ml_enabled(): - host, port, use_iam = load_configuration() - response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='neptune-db') + host, port, use_iam = load_configuration() + response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='ml-db') if response.status_code != 200: print('''This Neptune cluster \033[1mis not\033[0m configured to use Neptune ML. Please configure the cluster according to the Amazpnm Neptune ML documentation before proceeding.''') else: print("This Neptune cluster is configured to use Neptune ML") - + + def get_export_service_host(): - with open('/home/ec2-user/.bashrc') as f: + with open(f'{HOME_DIRECTORY}/.bashrc') as f: data = f.readlines() for d in data: if str.startswith(d, 'export NEPTUNE_EXPORT_API_URI'): parts = d.split('=') - if len(parts)==2: - path=urlparse(parts[1].rstrip()) + if len(parts) == 2: + path = urlparse(parts[1].rstrip()) return path.hostname + "/v1" - logging.error("Unable to determine the Neptune Export Service Endpoint. You will need to enter this assign this manually.") + logging.error( + "Unable to determine the Neptune Export Service Endpoint. 
You will need to enter this assign this manually.") return None + def delete_pretrained_data(setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): - host, port, use_iam = load_configuration() if setup_node_classification: - response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) - + response = signed_request("POST", service='ml-db', + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps( + {'gremlin': "g.V('movie_1', 'movie_7', 'movie_15').properties('genre').drop()"})) + if response.status_code != 200: print(response.content.decode('utf-8')) if setup_node_regression: - response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) + response = signed_request("POST", service='ml-db', + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) if setup_link_prediction: - response = signed_request("POST", service='neptune-db', - url=f'https://{host}:{port}/gremlin', - headers={'content-type': 'application/json'}, - data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) + response = signed_request("POST", service='ml-db', + url=f'https://{host}:{port}/gremlin', + headers={'content-type': 'application/json'}, + data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"})) if response.status_code != 200: print(response.content.decode('utf-8')) @@ -113,8 +122,9 @@ def delete_endpoint(training_job_name: str, neptune_iam_role_arn=None): if neptune_iam_role_arn: query_string = f'?neptuneIamRoleArn={neptune_iam_role_arn}' host, port, use_iam = load_configuration() - response = signed_request("DELETE", service='neptune-db', - url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', headers={'content-type': 'application/json'}) + response = signed_request("DELETE", service='ml-db', + url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}', + headers={'content-type': 'application/json'}) if response.status_code != 200: print(response.content.decode('utf-8')) else: @@ -129,28 +139,28 @@ def prepare_movielens_data(s3_bucket_uri: str): logging.error(e) - def setup_pretrained_endpoints(s3_bucket_uri: str, setup_node_classification: bool, setup_node_regression: bool, setup_link_prediction: bool): delete_pretrained_data(setup_node_classification, setup_node_regression, setup_link_prediction) try: - return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, setup_node_regression, setup_link_prediction) + return PretrainedModels().setup_pretrained_endpoints(s3_bucket_uri, setup_node_classification, + setup_node_regression, setup_link_prediction) except Exception as e: logging.error(e) class MovieLensProcessor: - raw_directory = r'/home/ec2-user/data/raw' - formatted_directory = r'/home/ec2-user/data/formatted' + raw_directory = fr'{HOME_DIRECTORY}/data/raw' + formatted_directory = fr'{HOME_DIRECTORY}/data/formatted' def __download_and_unzip(self): - if not 
@@ -163,7 +173,7 @@ def __process_movies_genres(self):
         # process the movies_vertex.csv
         print('Processing Movies', end='\r')
         movies_df = pd.read_csv(os.path.join(
-            self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1',
+            self.raw_directory, 'ml-100k/u.item'), sep='|', encoding='ISO-8859-1',
             names=['~id', 'title', 'release_date', 'video_release_date', 'imdb_url', 'unknown', 'Action',
                    'Adventure', 'Animation', 'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',
                    'Film-Noir', 'Horror', 'Musical',
@@ -218,7 +228,7 @@ def __process_ratings_users(self):
         # Create ratings vertices and add edges on both sides
         print('Processing Ratings', end='\r')
         ratings_vertices = pd.read_csv(os.path.join(
-            self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1',
+            self.raw_directory, 'ml-100k/u.data'), sep='\t', encoding='ISO-8859-1',
             names=['~from', '~to', 'score:Int', 'timestamp'])
         ratings_vertices['~from'] = ratings_vertices['~from'].apply(
             lambda x: f'user_{x}')
@@ -232,10 +242,10 @@
 
         dict = {}
         for index, row in ratings_vertices.iterrows():
-            dict[index*2] = {'~id': uuid.uuid4(), '~label': 'wrote',
-                             '~from': row['~from'], '~to': row['~id']}
-            dict[index*2 + 1] = {'~id': uuid.uuid4(), '~label': 'about',
-                                 '~from': row['~id'], '~to': row['~to']}
+            dict[index * 2] = {'~id': uuid.uuid4(), '~label': 'wrote',
+                               '~from': row['~from'], '~to': row['~id']}
+            dict[index * 2 + 1] = {'~id': uuid.uuid4(), '~label': 'about',
+                                   '~from': row['~id'], '~to': row['~to']}
         rating_edges_df = pd.DataFrame.from_dict(dict, "index")
 
         # Remove the from and to columns and write this out as a vertex now
@@ -259,7 +269,7 @@ def __process_users(self):
 
         # User Vertices - Load, rename column with type, and save
         user_df = pd.read_csv(os.path.join(
-            self.raw_directory, 'ml-100k/u.user'), sep='|', encoding='ISO-8859-1',
+            self.raw_directory, 'ml-100k/u.user'), sep='|', encoding='ISO-8859-1',
             names=['~id', 'age:Int', 'gender', 'occupation', 'zip_code'])
         user_df['~id'] = user_df['~id'].apply(
             lambda x: f'user_{x}')
@@ -279,7 +289,7 @@ def __upload_to_s3(self, bucketname: str):
             self.formatted_directory, file), bucket, f'{file_path}/{file}')
 
     def prepare_movielens_data(self, s3_bucket: str):
-        bucket_name = f'{s3_bucket}/neptune-formatted/movielens-100k'
+        bucket_name = f'{s3_bucket}/neptune-formatted/movielens-100k'
         self.__download_and_unzip()
         self.__process_movies_genres()
         self.__process_users()
@@ -372,12 +382,12 @@ def __create_model(self, name: str, model_s3_location: str):
         return name
 
     def __get_neptune_ml_role(self):
-        with open('/home/ec2-user/.bashrc') as f:
+        with open(f'{HOME_DIRECTORY}/.bashrc') as f:
             data = f.readlines()
         for d in data:
             if str.startswith(d, 'export NEPTUNE_ML_ROLE_ARN'):
                 parts = d.split('=')
-                if len(parts)==2:
+                if len(parts) == 2:
                     return parts[1].rstrip()
         logging.error("Unable to determine the Neptune ML IAM Role.")
         return None
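For reference, the helpers above compose like this; a minimal sketch, illustrative only and not part of this patch (assumes a `graph_notebook_config.json` in the home directory and default boto3 credentials):

```python
# 'neptune-db' is the SigV4 service name used when signing Neptune requests;
# /status is the Neptune instance status endpoint
host, port, use_iam = load_configuration()
response = signed_request("GET", url=f'https://{host}:{port}/status', service='neptune-db')
print(response.status_code, response.json())
```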
diff --git a/src/graph_notebook/request_param_generator/call_and_get_response.py b/src/graph_notebook/request_param_generator/call_and_get_response.py
deleted file mode 100644
index f0bc3e84..00000000
--- a/src/graph_notebook/request_param_generator/call_and_get_response.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-import requests
-
-
-def call_and_get_response(method: str, action: str, host: str, port: str, request_param_generator, use_ssl: bool,
-                          query='', extra_headers=None):
-    if extra_headers is None:
-        extra_headers = {}
-
-    method = method.upper()
-    protocol = 'https' if use_ssl else 'http'
-
-    request_params = request_param_generator.generate_request_params(method=method, action=action, query=query,
-                                                                     host=host, port=port, protocol=protocol,
-                                                                     headers=extra_headers)
-    headers = request_params['headers'] if request_params['headers'] is not None else {}
-
-    if method == 'GET':
-        res = requests.get(url=request_params['url'], params=request_params['params'], headers=headers)
-    elif method == 'DELETE':
-        res = requests.delete(url=request_params['url'], params=request_params['params'], headers=headers)
-    elif method == 'POST':
-        res = requests.post(url=request_params['url'], data=request_params['params'], headers=headers)
-    else:
-        raise NotImplementedError(f'Use of method {method} has not been implemented in call_and_get_response')
-
-    res.raise_for_status()
-    return res
diff --git a/src/graph_notebook/request_param_generator/default_request_generator.py b/src/graph_notebook/request_param_generator/default_request_generator.py
deleted file mode 100644
index 0fb2b8cf..00000000
--- a/src/graph_notebook/request_param_generator/default_request_generator.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
-
-
-class DefaultRequestGenerator(object):
-    @staticmethod
-    def generate_request_params(method, action, query, host, port, protocol, headers=None):
-        url = f'{protocol}://{host}:{port}/{action}' if port != '' else f'{protocol}://{host}/{action}'
-        params = {
-            'method': method,
-            'url': url,
-            'headers': headers,
-            'params': query,
-        }
-
-        return params
diff --git a/src/graph_notebook/request_param_generator/factory.py b/src/graph_notebook/request_param_generator/factory.py
deleted file mode 100644
index 2ad79dae..00000000
--- a/src/graph_notebook/request_param_generator/factory.py
+++ /dev/null
@@ -1,23 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.iam_request_generator import IamRequestGenerator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider, credentials_provider_factory - - -def create_request_generator(mode: AuthModeEnum, - credentials_provider_mode: IAMAuthCredentialsProvider = IAMAuthCredentialsProvider.ROLE, - command: str = ''): - - if mode == AuthModeEnum.DEFAULT and command == 'sparql': - return SPARQLRequestGenerator() - elif mode == AuthModeEnum.IAM: - credentials_provider_mode = credentials_provider_factory(credentials_provider_mode) - return IamRequestGenerator(credentials_provider_mode) - else: - return DefaultRequestGenerator() diff --git a/src/graph_notebook/request_param_generator/iam_request_generator.py b/src/graph_notebook/request_param_generator/iam_request_generator.py deleted file mode 100644 index fc88e809..00000000 --- a/src/graph_notebook/request_param_generator/iam_request_generator.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_headers import make_signed_request - - -class IamRequestGenerator(object): - def __init__(self, credentials_provider): - self.credentials_provider = credentials_provider - - def generate_request_params(self, method, action, query, host, port, protocol, headers=None): - credentials = self.credentials_provider.get_iam_credentials() - if protocol in ['https', 'wss']: - use_ssl = True - else: - use_ssl = False - - return make_signed_request(method, action, query, host, port, credentials.key, credentials.secret, - credentials.region, use_ssl, credentials.token, additional_headers=headers) diff --git a/src/graph_notebook/request_param_generator/sparql_request_generator.py b/src/graph_notebook/request_param_generator/sparql_request_generator.py deleted file mode 100644 index f63ae479..00000000 --- a/src/graph_notebook/request_param_generator/sparql_request_generator.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - - -class SPARQLRequestGenerator(object): - @staticmethod - def generate_request_params(method, action, query, host, port, protocol, headers=None): - if headers is None: - headers = {} - - if 'Content-Type' not in headers: - headers['Content-Type'] = "application/x-www-form-urlencoded" - - url = f'{protocol}://{host}:{port}/{action}' - return { - 'method': method, - 'url': url, - 'headers': headers, - 'params': query, - } diff --git a/src/graph_notebook/sparql/query.py b/src/graph_notebook/sparql/query.py deleted file mode 100644 index 6d104600..00000000 --- a/src/graph_notebook/sparql/query.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import logging - -from SPARQLWrapper import SPARQLWrapper -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -logging.basicConfig() -logger = logging.getLogger("sparql") - -ACTION_TO_QUERY_TYPE = { - 'sparql': 'application/sparql-query', - 'sparqlupdate': 'application/sparql-update' -} - -SPARQL_ACTION = 'sparql' - - -def get_query_type(query): - s = SPARQLWrapper('') - s.setQuery(query) - return s.queryType - - -def query_type_to_action(query_type): - query_type = query_type.upper() - if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: - return 'sparql' - else: - # TODO: check explicitly for all query types, raise exception for invalid query - return 'sparqlupdate' - - -def do_sparql_query(query, host, port, use_ssl, request_param_generator, extra_headers=None, path: str = SPARQL_ACTION): - path = SPARQL_ACTION if path == '' else path - - if extra_headers is None: - extra_headers = {} - logger.debug(f'query={query}, endpoint={host}, port={port}') - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = {} - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, extra_headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - content = res.content.decode('utf-8') - return content - - -def do_sparql_explain(query: str, host: str, port: str, use_ssl: bool, request_param_generator, - accept_type='text/html', path: str = ''): - path = SPARQL_ACTION if path == '' else path - - query_type = get_query_type(query) - action = query_type_to_action(query_type) - - data = { - 'explain': 'dynamic', - } - - if action == 'sparql': - data['query'] = query - elif action == 'sparqlupdate': - data['update'] = query - - extra_headers = { - 'Accept': accept_type - } - - res = call_and_get_response('post', path, host, port, request_param_generator, use_ssl, data, - extra_headers) - return res.content.decode('utf-8') diff --git a/src/graph_notebook/sparql/status.py b/src/graph_notebook/sparql/status.py deleted file mode 100644 index 737d61b7..00000000 --- a/src/graph_notebook/sparql/status.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response - -SPARQL_STATUS_ACTION = 'sparql/status' - - -def do_sparql_status(host, port, use_ssl, request_param_generator, query_id=None): - data = {} - if query_id != '' and query_id is not None: - data['queryId'] = query_id - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() # attempt to return json, otherwise we will return the content string. - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' 
- return content - - -def do_sparql_cancel(host, port, use_ssl, request_param_generator, query_id, silent=False): - if type(query_id) is not str or query_id == '': - raise ValueError("query id must be a non-empty string") - - data = {'cancelQuery': True, 'queryId': query_id, 'silent': silent} - - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - res = call_and_get_response('post', SPARQL_STATUS_ACTION, host, port, request_param_generator, use_ssl, data, - headers) - try: - content = res.json() - except Exception: - """When a invalid UUID is supplied, status servlet returns an empty string. - See https://sim.amazon.com/issues/NEPTUNE-16137 - """ - content = 'UUID is invalid.' - return content diff --git a/src/graph_notebook/status/get_status.py b/src/graph_notebook/status/get_status.py deleted file mode 100644 index bd86128c..00000000 --- a/src/graph_notebook/status/get_status.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" -from json import JSONDecodeError - -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - - -def get_status(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - res = call_and_get_response('get', 'status', host, port, request_param_generator, use_ssl) - try: - js = res.json() - except JSONDecodeError: - js = res.content - return js diff --git a/src/graph_notebook/system/database_reset.py b/src/graph_notebook/system/database_reset.py index 7e6fb46e..e69de29b 100644 --- a/src/graph_notebook/system/database_reset.py +++ b/src/graph_notebook/system/database_reset.py @@ -1,21 +0,0 @@ -from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -SYSTEM_ACTION = 'system' - - -def initiate_database_reset(host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'initiateDatabaseReset' - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() - - -def perform_database_reset(token, host, port, use_ssl, request_param_generator=DefaultRequestGenerator()): - data = { - 'action': 'performDatabaseReset', - 'token': token - } - res = call_and_get_response('post', SYSTEM_ACTION, host, port, request_param_generator, use_ssl, data) - return res.json() diff --git a/src/graph_notebook/sparql/table.py b/src/graph_notebook/visualization/sparql_rows_and_columns.py similarity index 89% rename from src/graph_notebook/sparql/table.py rename to src/graph_notebook/visualization/sparql_rows_and_columns.py index 0029a7ad..14ac1548 100644 --- a/src/graph_notebook/sparql/table.py +++ b/src/graph_notebook/visualization/sparql_rows_and_columns.py @@ -3,11 +3,13 @@ SPDX-License-Identifier: Apache-2.0 """ + def get_rows_and_columns(sparql_results): if type(sparql_results) is not dict: return None - if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in sparql_results['results']: + if 'head' in sparql_results and 'vars' in sparql_results['head'] and 'results' in sparql_results and 'bindings' in \ + sparql_results['results']: columns = [] for v in sparql_results['head']['vars']: columns.append(v) diff --git 
a/test/integration/DataDrivenGremlinTest.py b/test/integration/DataDrivenGremlinTest.py index 2a794091..3ada56e5 100644 --- a/test/integration/DataDrivenGremlinTest.py +++ b/test/integration/DataDrivenGremlinTest.py @@ -5,9 +5,7 @@ import logging -from graph_notebook.gremlin.client_provider.factory import create_client_provider from graph_notebook.seed.load_query import get_queries -from graph_notebook.gremlin.query import do_gremlin_query from test.integration import IntegrationTest @@ -16,9 +14,9 @@ class DataDrivenGremlinTest(IntegrationTest): def setUp(self): super().setUp() - self.client_provider = create_client_provider(self.auth_mode, self.iam_credentials_provider_type) + self.client = self.client_builder.build() query_check_for_airports = "g.V('3684').outE().inV().has(id, '3444')" - res = do_gremlin_query(query_check_for_airports, self.host, self.port, self.ssl, self.client_provider) + res = self.client.gremlin_query(query_check_for_airports) if len(res) < 1: logging.info('did not find final airports edge, seeding database now...') airport_queries = get_queries('gremlin', 'airports') @@ -30,7 +28,7 @@ def setUp(self): # we are deciding to try except because we do not know if the database # we are connecting to has a partially complete set of airports data or not. try: - do_gremlin_query(line, self.host, self.port, self.ssl, self.client_provider) + self.client.gremlin_query(line) except Exception as e: logging.error(f'query {q} failed due to {e}') continue diff --git a/test/integration/DataDrivenSparqlTest.py b/test/integration/DataDrivenSparqlTest.py index 5ce77861..a013c01e 100644 --- a/test/integration/DataDrivenSparqlTest.py +++ b/test/integration/DataDrivenSparqlTest.py @@ -5,11 +5,7 @@ import logging -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider from graph_notebook.seed.load_query import get_queries -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.sparql.query import do_sparql_query - from test.integration import IntegrationTest logger = logging.getLogger('DataDrivenSparqlTest') @@ -17,15 +13,14 @@ class DataDrivenSparqlTest(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.request_generator = create_request_generator(cls.auth_mode, IAMAuthCredentialsProvider.ENV) + def setUp(self) -> None: + super().setUp() airport_queries = get_queries('sparql', 'epl') for q in airport_queries: try: # we are deciding to try except because we do not know if the database we are connecting to has a partially complete set of airports data or not. 
-            do_sparql_query(q['content'], cls.host, cls.port, cls.ssl, cls.request_generator)
+                res = self.client.sparql(q['content'].strip())
+                print(res)
             except Exception as e:
                 logger.error(f'query {q["content"]} failed due to {e}')
                 continue
diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py
index 49a87ad9..7f04b80d 100644
--- a/test/integration/IntegrationTest.py
+++ b/test/integration/IntegrationTest.py
@@ -5,19 +5,34 @@
 
 import unittest
 
+from botocore.session import get_session
+
+from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum
 from graph_notebook.configuration.get_config import get_config
+from graph_notebook.neptune.client import Client, ClientBuilder
 from test.integration.NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH
 
 
+def setup_client_builder(config: Configuration) -> ClientBuilder:
+    builder = ClientBuilder() \
+        .with_host(config.host) \
+        .with_port(config.port) \
+        .with_region(config.aws_region) \
+        .with_tls(config.ssl) \
+        .with_sparql_path(config.sparql.path)
+
+    if config.auth_mode == AuthModeEnum.IAM:
+        builder = builder.with_iam(get_session())
+
+    return builder
+
+
 class IntegrationTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
-        config = get_config(TEST_CONFIG_PATH)
-        cls.config = config
-        cls.host = config.host
-        cls.port = config.port
-        cls.auth_mode = config.auth_mode
-        cls.ssl = config.ssl
-        cls.iam_credentials_provider_type = config.iam_credentials_provider_type
-        cls.load_from_s3_arn = config.load_from_s3_arn
+        cls.config = get_config(TEST_CONFIG_PATH)
+        cls.client_builder = setup_client_builder(cls.config)
+
+    def setUp(self) -> None:
+        self.client = self.client_builder.build()
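The IAM test classes later in this patch layer SigV4 signing onto this shared builder rather than constructing a client from scratch; the recurring pattern in their `setUp` methods looks like this (sketch):

```python
from botocore.session import get_session

# inside a test case's setUp: rebuild the shared client with IAM auth enabled
self.client = self.client_builder.with_iam(get_session()).build()
```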
diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py
index 5149330a..a91f8f19 100644
--- a/test/integration/NeptuneIntegrationWorkflowSteps.py
+++ b/test/integration/NeptuneIntegrationWorkflowSteps.py
@@ -15,7 +15,6 @@
 import boto3 as boto3
 import requests
 
-from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider
 from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration
 
 SUBPARSER_CREATE_CFN = 'create-cfn-stack'
@@ -63,7 +62,7 @@ def set_iam_auth_on_neptune_cluster(cluster_identifier: str, iam_value: bool, ne
     response = neptune_client.modify_db_cluster(DBClusterIdentifier=cluster_identifier,
                                                 EnableIAMDatabaseAuthentication=iam_value,
                                                 ApplyImmediately=True)
-    logging.info(f'modified neptune cluster {cluster_identifier} to set iam auth to {iam_value}: {response}')
+    logging.info(f'modified neptune cluster {cluster_identifier} to set iam auth to {iam_value}: {response}')
 
     # wait for authentication setting to show as changed:
     while cluster['IAMDatabaseAuthenticationEnabled'] != iam_value:
@@ -177,8 +176,8 @@
                     ip = network_interface['PrivateIpAddresses'][0]['Association']['PublicIp']
                     logging.info(f'checking if ip {ip} can be used ')
+                    url = f'https://{ip}:80/status'
                     try:
-                        url = f'https://{ip}:80/status'
                         logging.info(f'checking ip address {ip}, url={url}')
                         # hard-coded to port 80 since that's what this CFN stack uses for its load balancer
                         requests.get(url, verify=False, timeout=5)  # an exception is thrown if the host cannot be reached.
@@ -268,8 +267,7 @@ def generate_config_from_stack(stack: dict, region: str, iam: bool) -> Configura
         file.writelines(new_lines)
 
     auth = AuthModeEnum.IAM if iam else AuthModeEnum.DEFAULT
-    conf = Configuration(details['endpoint'], 80, auth, IAMAuthCredentialsProvider.ENV, details['loader_arn'],
-                         ssl=True, aws_region=region)
+    conf = Configuration(details['endpoint'], 80, auth, details['loader_arn'], ssl=True, aws_region=region)
     logging.info(f'generated configuration for test run: {conf.to_dict()}')
     return conf
 
@@ -323,16 +321,6 @@ def main():
     delete_parser.add_argument('--cfn-stack-name', type=str, default='')
     delete_parser.add_argument('--aws-region', type=str, default='us-east-1')
 
-    # sub parser for running tests
-    parser_run_tests = subparsers.add_parser(SUBPARSER_RUN_TESTS,
-                                             help='run tests with the pattern *_without_iam.py')
-    parser_run_tests.add_argument('--pattern', type=str),
-    parser_run_tests.add_argument('--iam', action='store_true')
-    parser_run_tests.add_argument('--cfn-stack-name', type=str, default='')
-    parser_run_tests.add_argument('--aws-region', type=str, default='us-east-1')
-    parser_run_tests.add_argument('--skip-config-generation', action='store_true',
-                                  help=f'skips config generation for testing, using the one found under {TEST_CONFIG_PATH}')
-
     args = parser.parse_args()
 
     cfn_client = boto3.client('cloudformation', region_name=args.aws_region)
@@ -342,15 +330,6 @@
         handle_create_cfn_stack(stack_name, args.cfn_template_url, args.cfn_s3_bucket, cfn_client, args.cfn_runner_role)
     elif args.which == SUBPARSER_DELETE_CFN:
         delete_stack(args.cfn_stack_name, cfn_client)
-    elif args.which == SUBPARSER_RUN_TESTS:
-        if not args.skip_config_generation:
-            loop_until_stack_is_complete(args.cfn_stack_name, cfn_client)
-            stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client)
-            cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client)
-            set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client)
-            config = generate_config_from_stack(stack, args.aws_region, args.iam)
-            config.write_to_file(TEST_CONFIG_PATH)
-        run_integration_tests(args.pattern)
     elif args.which == SUBPARSER_ENABLE_IAM:
         cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client)
         set_iam_auth_on_neptune_cluster(cluster_identifier, True, neptune_client)
diff --git a/test/integration/__init__.py b/test/integration/__init__.py
index d0f15c29..34da10ca 100644
--- a/test/integration/__init__.py
+++ b/test/integration/__init__.py
@@ -6,4 +6,4 @@
 from .IntegrationTest import IntegrationTest  # noqa F401
 from .DataDrivenGremlinTest import DataDrivenGremlinTest  # noqa F401
 from .DataDrivenSparqlTest import DataDrivenSparqlTest  # noqa F401
-from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH  # noqa F401
+from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH
diff --git a/test/integration/gremlin/__init__.py b/test/integration/gremlin/__init__.py
deleted file mode 100644
index 9049dd04..00000000
--- a/test/integration/gremlin/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
diff --git a/test/integration/gremlin/client_provider/__init__.py b/test/integration/gremlin/client_provider/__init__.py
deleted file mode 100644
index 9049dd04..00000000
--- a/test/integration/gremlin/client_provider/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/gremlin/client_provider/client_provider_factory.py b/test/integration/gremlin/client_provider/client_provider_factory.py deleted file mode 100644 index 5b3d6b31..00000000 --- a/test/integration/gremlin/client_provider/client_provider_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.gremlin.client_provider.iam_client import IamClientProvider - - -class TestClientProviderFactory(unittest.TestCase): - def test_create_default_client(self): - client_provider = create_client_provider(AuthModeEnum.DEFAULT) - self.assertEqual(ClientProvider, type(client_provider)) - - def test_create_iam_client_from_env(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - self.assertEqual(IamClientProvider, type(client_provider)) diff --git a/test/integration/gremlin/gremlin_query_with_iam.py b/test/integration/gremlin/gremlin_query_with_iam.py deleted file mode 100644 index a3cac15e..00000000 --- a/test/integration/gremlin/gremlin_query_with_iam.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import IntegrationTest - - -class TestGremlinWithIam(IntegrationTest): - def test_do_gremlin_query_with_iam(self): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile_with_iam(self): - query = 'g.V().limit(1)' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_query_without_iam.py b/test/integration/gremlin/gremlin_query_without_iam.py deleted file mode 100644 index 7f61d020..00000000 --- a/test/integration/gremlin/gremlin_query_without_iam.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.query import do_gremlin_query, do_gremlin_explain, do_gremlin_profile -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - -from test.integration import IntegrationTest - - -class TestGremlin(IntegrationTest): - def test_do_gremlin_query(self): - client_provider = ClientProvider() - query = 'g.V().limit(1)' - results = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - - self.assertEqual(type(results), list) - - def test_do_gremlin_explain(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_explain(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('explain' in results) - - def test_do_gremlin_profile(self): - query = 'g.V().limit(1)' - request_generator = DefaultRequestGenerator() - results = do_gremlin_profile(query, self.host, self.port, self.ssl, request_generator) - - self.assertEqual(type(results), dict) - self.assertTrue('profile' in results) diff --git a/test/integration/gremlin/gremlin_status_with_iam.py b/test/integration/gremlin/gremlin_status_with_iam.py deleted file mode 100644 index 0066db90..00000000 --- a/test/integration/gremlin/gremlin_status_with_iam.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.gremlin.client_provider.factory import create_client_provider -from graph_notebook.request_param_generator.factory import create_request_generator -from gremlin_python.driver.protocol import GremlinServerError - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - client_provider = create_client_provider(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "some-guid-here" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - 
do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - query_results = [] - for x in range(0, num_threads): - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - threads.append(gremlin_query_thread) - query_results.append(query_res) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_gremlin_status(self.host, self.port, self.ssl, AuthModeEnum.IAM, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - 
gremlin_query_thread.join() diff --git a/test/integration/gremlin/gremlin_status_without_iam.py b/test/integration/gremlin/gremlin_status_without_iam.py deleted file mode 100644 index 03f6d7b1..00000000 --- a/test/integration/gremlin/gremlin_status_without_iam.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests -from os import cpu_count - -from gremlin_python.driver.protocol import GremlinServerError - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.gremlin.query import do_gremlin_query -from graph_notebook.gremlin.status import do_gremlin_status, do_gremlin_cancel -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import DataDrivenGremlinTest - - -class TestGremlinStatusWithoutIam(DataDrivenGremlinTest): - def do_gremlin_query_save_results(self, query, res): - try: - res['result'] = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) - except GremlinServerError as exception: - res['error'] = str(exception) - - def test_do_gremlin_status_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id, False) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_nonexistent(self): - with self.assertRaises(requests.HTTPError): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - try: - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - except requests.HTTPError as exception: - content = exception.response.json() - self.assertTrue('requestId' in content) - self.assertTrue('code' in content) - self.assertTrue('detailedMessage' in content) - self.assertEqual('InvalidParameterException', content['code']) - raise exception - - def test_do_gremlin_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - - def test_do_gremlin_status_and_cancel(self): - query = "g.V().out().out().out().out()" - query_res = {} - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,)) - gremlin_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, False) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - 
self.assertTrue('runningQueryCount' in status_res) - self.assertTrue(status_res['runningQueryCount'] == 1) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('status' in cancel_res) - self.assertTrue('payload' in cancel_res) - self.assertEqual('200 OK', cancel_res['status']) - - gremlin_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertTrue('TimeLimitExceededException' in query_res['error']) - - def test_do_gremlin_status_include_waiting(self): - query = "g.V().out().out().out().out()" - num_threads = 4 * cpu_count() - threads = [] - for x in range(0, num_threads): - gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, {})) - threads.append(gremlin_query_thread) - gremlin_query_thread.start() - - time.sleep(5) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, - request_generator, query_id, True) - - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) - - for gremlin_query_thread in threads: - gremlin_query_thread.join() diff --git a/test/integration/network/gremlin/__init__.py b/test/integration/iam/__init__.py similarity index 99% rename from test/integration/network/gremlin/__init__.py rename to test/integration/iam/__init__.py index 9049dd04..023f3c82 100644 --- a/test/integration/network/gremlin/__init__.py +++ b/test/integration/iam/__init__.py @@ -2,3 +2,4 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ + diff --git a/src/graph_notebook/gremlin/__init__.py b/test/integration/iam/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/__init__.py rename to test/integration/iam/gremlin/__init__.py diff --git a/test/integration/iam/gremlin/test_gremlin_status_with_iam.py b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py new file mode 100644 index 00000000..502b7c9a --- /dev/null +++ b/test/integration/iam/gremlin/test_gremlin_status_with_iam.py @@ -0,0 +1,122 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" +import concurrent.futures +import threading +import time + +import pytest +from os import cpu_count + +from botocore.session import get_session +from gremlin_python.driver.protocol import GremlinServerError + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenGremlinTest + + +def long_running_gremlin_query(c: Client, query: str): + with pytest.raises(GremlinServerError): + c.gremlin_query(query) + return + + +class TestGremlinStatusWithIam(DataDrivenGremlinTest): + def setUp(self) -> None: + super().setUp() + if not self.client.iam_enabled: + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_status(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_nonexistent(self): + query_id = "some-guid-here" + res = self.client.gremlin_cancel(query_id) + assert res.status_code == 400 + js = res.json() + assert js['code'] == 'InvalidParameterException' + assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_empty_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel('') + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_cancel_non_str_query_id(self): + with self.assertRaises(ValueError): + self.client.gremlin_cancel(42) + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_and_cancel(self): + long_running_query = "g.V().out().out().out().out().out().out().out().out()" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_gremlin_query, self.client, long_running_query) + + time.sleep(1) + status_res = self.client.gremlin_status() + assert status_res.status_code == 200 + + status_js = status_res.json() + query_id = '' + for q in status_js['queries']: + if q['queryString'] == long_running_query: + query_id = q['queryId'] + + assert query_id != '' + + cancel_res = self.client.gremlin_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.json()['status'] == '200 OK' + + time.sleep(1) + status_after_cancel = self.client.gremlin_status(query_id) + assert status_after_cancel.status_code == 400 # check that the query is no longer valid + assert status_after_cancel.json()['code'] == 'InvalidParameterException' + + future.result() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_gremlin_status_include_waiting(self): + query = "g.V().out().out().out().out()" + num_threads = cpu_count() * 4 + threads = [] + for x in range(0, num_threads): + thread = threading.Thread(target=long_running_gremlin_query, args=(self.client, query)) + thread.start() + threads.append(thread) + + time.sleep(5) + + res = self.client.gremlin_status(include_waiting=True) + assert res.status_code == 200 + status_res = res.json() + + self.assertEqual(type(status_res), dict) + self.assertTrue('acceptedQueryCount' in status_res) + self.assertTrue('runningQueryCount' in status_res) + self.assertTrue('queries' in status_res) + self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries'])) + + for q in status_res['queries']: + # cancel all the queries we 
executed since they can take a very long time.
+            if q['queryString'] == query:
+                self.client.gremlin_cancel(q['queryId'])
+
+        for t in threads:
+            t.join()
diff --git a/test/integration/iam/gremlin/test_gremlin_with_iam.py b/test/integration/iam/gremlin/test_gremlin_with_iam.py
new file mode 100644
index 00000000..375e2b7d
--- /dev/null
+++ b/test/integration/iam/gremlin/test_gremlin_with_iam.py
@@ -0,0 +1,55 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+import pytest
+from botocore.session import get_session
+from gremlin_python.structure.graph import Vertex
+
+from test.integration import IntegrationTest
+
+
+class TestGremlinWithIam(IntegrationTest):
+    def setUp(self) -> None:
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_query_with_iam(self):
+        query = 'g.V().limit(1)'
+        results = self.client.gremlin_query(query)
+        assert type(results) is list
+        for r in results:
+            assert type(r) is Vertex
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_explain_with_iam(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_explain(query)
+        assert res.status_code == 200
+        results = res.content.decode('utf-8')
+        self.assertTrue('Explain' in results)
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_do_gremlin_profile_with_iam(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_profile(query)
+        assert res.status_code == 200
+
+        results = res.content.decode('utf-8')
+        self.assertTrue('Profile' in results)
+
+    @pytest.mark.iam
+    @pytest.mark.gremlin
+    def test_iam_gremlin_http_query(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_http_query(query)
+        assert res.status_code == 200
+        assert 'result' in res.json()
+
+    def test_iam_gremlin_connection(self):
+        conn = self.client.get_gremlin_connection()
+        result = conn.submit('g.V().limit(1)')
+        assert result is not None
diff --git a/test/integration/system/__init__.py b/test/integration/iam/load/__init__.py
similarity index 100%
rename from test/integration/system/__init__.py
rename to test/integration/iam/load/__init__.py
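Because the new tests carry pytest markers (`iam`, `gremlin`, `sparql`, `neptune`, `neptuneml`), subsets can now be selected at run time. For example (a sketch, assuming the markers are registered in the new pytest.ini):

```python
import pytest

# run only the IAM-signed Gremlin tests; equivalent to
# `pytest test/integration/iam/gremlin -m "iam and gremlin"` on the CLI
pytest.main(['test/integration/iam/gremlin', '-m', 'iam and gremlin'])
```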
diff --git a/test/integration/iam/load/test_load_with_iam.py b/test/integration/iam/load/test_load_with_iam.py
new file mode 100644
index 00000000..b34cbb48
--- /dev/null
+++ b/test/integration/iam/load/test_load_with_iam.py
@@ -0,0 +1,58 @@
+import time
+
+import pytest
+import unittest
+
+from botocore.session import get_session
+
+from test.integration import IntegrationTest
+
+TEST_BULKLOAD_SOURCE = 's3://aws-neptune-customer-samples-%s/bulkload-datasets/%s/airroutes/v01'
+
+
+@unittest.skip
+class TestLoadWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        assert self.config.load_from_s3_arn != ''
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    def test_iam_load(self):
+        load_format = 'turtle'
+        source = TEST_BULKLOAD_SOURCE % (self.config.aws_region, 'turtle')
+
+        # for a full list of options, see https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-data.html
+        kwargs = {
+            'failOnError': "TRUE",
+        }
+        res = self.client.load(source, load_format, self.config.load_from_s3_arn, **kwargs)
+        assert res.status_code == 200
+
+        load_js = res.json()
+        assert 'loadId' in load_js['payload']
+        load_id = load_js['payload']['loadId']
+
+        time.sleep(1)  # brief wait to ensure the load job can be obtained
+
+        res = self.client.load_status(load_id, details="TRUE")
+        assert res.status_code == 200
+
+        load_status = res.json()
+        assert 'overallStatus' in load_status['payload']
+        status = load_status['payload']['overallStatus']
+        assert status['fullUri'] == source
+
+        res = self.client.cancel_load(load_id)
+        assert res.status_code == 200
+
+        time.sleep(5)
+        res = self.client.load_status(load_id, details="TRUE")
+        cancelled_status = res.json()
+        assert 'LOAD_CANCELLED_BY_USER' in cancelled_status['payload']['feedCount'][-1]
+
+    def test_iam_load_status(self):
+        res = self.client.load_status()  # This should only give a list of load ids
+        assert res.status_code == 200
+
+        js = res.json()
+        assert 'loadIds' in js['payload']
+        assert len(js['payload'].keys()) == 1
diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py
new file mode 100644
index 00000000..70385e35
--- /dev/null
+++ b/test/integration/iam/ml/__init__.py
@@ -0,0 +1,27 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+
+from botocore.session import get_session
+
+from graph_notebook.configuration.generate_config import Configuration
+from graph_notebook.neptune.client import Client, ClientBuilder
+
+
+def setup_iam_client(config: Configuration) -> Client:
+    client = ClientBuilder() \
+        .with_host(config.host) \
+        .with_port(config.port) \
+        .with_region(config.aws_region) \
+        .with_tls(config.ssl) \
+        .with_sparql_path(config.sparql.path) \
+        .with_iam(get_session()) \
+        .build()
+
+    assert client.host == config.host
+    assert client.port == config.port
+    assert client.region == config.aws_region
+    assert client.sparql_path == config.sparql.path
+    assert client.ssl is config.ssl
+    return client
diff --git a/test/integration/iam/ml/test_neptune_client_with_iam.py b/test/integration/iam/ml/test_neptune_client_with_iam.py
new file mode 100644
index 00000000..26d7d3ae
--- /dev/null
+++ b/test/integration/iam/ml/test_neptune_client_with_iam.py
@@ -0,0 +1,25 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+import datetime
+import concurrent.futures
+import json
+import os
+import time
+
+import pytest
+from botocore.session import get_session
+from gremlin_python.structure.graph import Edge
+
+from graph_notebook.configuration.generate_config import Configuration
+from graph_notebook.configuration.get_config import get_config
+from graph_notebook.neptune.client import ClientBuilder, Client
+
+client: Client
+config: Configuration
+
+TEST_BULKLOAD_SOURCE = 's3://aws-neptune-customer-samples-%s/bulkload-datasets/%s/airroutes/v01'
+GREMLIN_TEST_LABEL = 'graph-notebook-test'
+SPARQL_TEST_PREDICATE = ''
+
diff --git a/test/integration/iam/ml/test_neptune_ml_with_iam.py b/test/integration/iam/ml/test_neptune_ml_with_iam.py
new file mode 100644
index 00000000..36d1c05e
--- /dev/null
+++ b/test/integration/iam/ml/test_neptune_ml_with_iam.py
@@ -0,0 +1,179 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0 +""" + +import logging +import os +import threading +import time +import unittest + +import pytest +from botocore.session import get_session + +from graph_notebook.configuration.get_config import get_config +from test.integration import IntegrationTest +from test.integration.iam.ml import setup_iam_client + +logger = logging.getLogger() + + +@unittest.skip +class TestNeptuneMLWithIAM(IntegrationTest): + def setUp(self) -> None: + self.client = self.client_builder.with_iam(get_session()).build() + + def tearDown(self) -> None: + endpoint_ids = client.endpoints().json()['ids'] + for endpoint_id in endpoint_ids: + self.client.endpoints_delete(endpoint_id) + + client.close() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_e2e(self): + s3_input_uri = os.getenv('NEPTUNE_ML_DATAPROCESSING_S3_INPUT', '') + s3_processed_uri = os.getenv('NEPTUNE_ML_DATAPROCESSING_S3_PROCESSED', '') + train_model_s3_location = os.getenv('NEPTUNE_ML_TRAINING_S3_LOCATION', '') + + assert s3_input_uri != '' + assert s3_processed_uri != '' + assert train_model_s3_location != '' + + logger.info("dataprocessing...") + dataprocessing_job = do_dataprocessing(s3_input_uri, s3_processed_uri) + dataprocessing_id = dataprocessing_job['id'] + + p = threading.Thread(target=wait_for_dataprocessing_complete, args=(dataprocessing_id,)) + p.start() + p.join(3600) + + logger.info("model training...") + training_job = do_modeltraining(dataprocessing_id, train_model_s3_location) + training_job_id = training_job['id'] + + p = threading.Thread(target=wait_for_modeltraining_complete, args=(training_job_id,)) + p.start() + p.join(3600) + + logger.info("endpoint...") + endpoint_job = do_create_endpoint(training_job_id) + endpoint_job_id = endpoint_job['id'] + p = threading.Thread(target=wait_for_endpoint_complete, args=(endpoint_job_id,)) + p.start() + p.join(3600) + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_dataprocessing_status(self): + status = client.dataprocessing_status() + + assert status.status_code == 200 + assert 'ids' in status.json() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_modeltraining_status(self): + status = client.modeltraining_status() + assert status.status_code == 200 + assert 'ids' in status.json() + + @pytest.mark.neptuneml + @pytest.mark.iam + def test_neptune_ml_training(self): + dataprocessing_id = os.getenv('NEPTUNE_ML_DATAPROCESSING_ID', '') + train_model_s3_location = os.getenv('NEPTUNE_ML_TRAINING_S3_LOCATION', '') + + assert dataprocessing_id != '' + assert train_model_s3_location != '' + + dataprocessing_status = client.dataprocessing_job_status(dataprocessing_id) + assert dataprocessing_status.status_code == 200 + + job_start_res = client.modeltraining_start(dataprocessing_id, train_model_s3_location) + assert job_start_res.status_code == 200 + + job_id = job_start_res.json()['id'] + training_status_res = client.modeltraining_job_status(job_id) + assert training_status_res.status_code == 200 + + job_stop_res = client.modeltraining_stop(job_id, clean=True) + assert job_stop_res.status_code == 200 + + +def setup_module(): + global client + client = setup_iam_client(get_config()) + + +def teardown_module(): + endpoint_ids = client.endpoints().json()['ids'] + for endpoint_id in endpoint_ids: + client.endpoints_delete(endpoint_id) + + client.close() + + +def do_dataprocessing(s3_input, s3_processed) -> dict: + logger.info(f"starting dataprocessing job with input={s3_input} and processed={s3_processed}") + 
dataprocessing_res = client.dataprocessing_start(s3_input, s3_processed) + assert dataprocessing_res.status_code == 200 + return dataprocessing_res.json() + + +def wait_for_dataprocessing_complete(dataprocessing_id: str): + logger.info(f"waiting for dataprocessing job {dataprocessing_id} to complete") + while True: + status = client.dataprocessing_job_status(dataprocessing_id) + assert status.status_code == 200 + raw = status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'InProgress': + assert raw['status'] == 'Completed' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) + + +def do_modeltraining(dataprocessing_id, train_model_s3_location): + logger.info( + f"starting training job from dataprocessing_job_id={dataprocessing_id} and training_model_s3_location={train_model_s3_location}") + training_start = client.modeltraining_start(dataprocessing_id, train_model_s3_location) + assert training_start.status_code == 200 + return training_start.json() + + +def wait_for_modeltraining_complete(training_job: str) -> dict: + logger.info(f"waiting for modeltraining job {training_job} to complete") + while True: + status = client.modeltraining_job_status(training_job) + assert status.status_code == 200 + raw = status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'InProgress': + assert raw['status'] == 'Completed' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) + + +def do_create_endpoint(training_job_id: str) -> dict: + endpoint_res = client.endpoints_create(training_job_id) + assert endpoint_res.status_code == 200 + return endpoint_res.json() + + +def wait_for_endpoint_complete(endpoint_job_id): + logger.info(f"waiting for endpoint creation job {endpoint_job_id} to complete") + while True: + endpoint_status = client.endpoints_status(endpoint_job_id) + assert endpoint_status.status_code == 200 + raw = endpoint_status.json() + logger.info(f"status is {raw['status']}") + if raw['status'] != 'Creating': + assert raw['status'] == 'InService' + return raw + logger.info("waiting for 10 seconds then checking again") + time.sleep(10) diff --git a/test/integration/iam/sparql/__init__.py b/test/integration/iam/sparql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/iam/sparql/test_sparql_query_with_iam.py b/test/integration/iam/sparql/test_sparql_query_with_iam.py new file mode 100644 index 00000000..6b4d9c68 --- /dev/null +++ b/test/integration/iam/sparql/test_sparql_query_with_iam.py @@ -0,0 +1,54 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0 +""" +from json import JSONDecodeError + +import pytest +from botocore.session import get_session + +from test.integration import DataDrivenSparqlTest + + +class TestSparqlQueryWithIam(DataDrivenSparqlTest): + def setUp(self) -> None: + super().setUp() + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.sparql + def test_do_sparql_query(self): + query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" + query_res = self.client.sparql(query) + assert query_res.status_code == 200 + res = query_res.json() + + self.assertEqual(type(res), dict) + self.assertTrue('s' in res['head']['vars']) + self.assertTrue('p' in res['head']['vars']) + self.assertTrue('o' in res['head']['vars']) + + @pytest.mark.iam + @pytest.mark.sparql + def test_do_sparql_explain(self): + query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" + query_res = self.client.sparql_explain(query) + assert query_res.status_code == 200 + res = query_res.content.decode('utf-8') + self.assertEqual(type(res), str) + self.assertTrue(res.startswith('')) + + @pytest.mark.iam + @pytest.mark.sparql + def test_iam_describe(self): + query = '''PREFIX soccer: + DESCRIBE soccer:Arsenal''' + res = self.client.sparql(query) + assert res.status_code == 200 + + # test that we do not get back json + with pytest.raises(JSONDecodeError): + res.json() + + content = res.content.decode('utf-8') + assert len(content.splitlines()) == 6 diff --git a/test/integration/iam/sparql/test_sparql_status_with_iam.py b/test/integration/iam/sparql/test_sparql_status_with_iam.py new file mode 100644 index 00000000..c6038099 --- /dev/null +++ b/test/integration/iam/sparql/test_sparql_status_with_iam.py @@ -0,0 +1,128 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import concurrent.futures +import time + +import pytest +from botocore.session import get_session + +from graph_notebook.neptune.client import Client + +from test.integration import DataDrivenSparqlTest + + +def long_running_sparql_query(c: Client, query: str): + res = c.sparql(query) + return res + + +class TestSparqlStatusWithIam(DataDrivenSparqlTest): + def setUp(self) -> None: + super().setUp() + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_nonexistent(self): + query_id = "invalid-guid" + status_res = self.client.sparql_status(query_id) + assert status_res.status_code == 200 + assert status_res.content == b'' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_nonexistent(self): + query_id = "invalid-guid" + cancel_res = self.client.sparql_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.content == b'' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_empty_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel('') + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_cancel_non_str_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel(42) + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_and_cancel(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . 
?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + self.assertNotEqual(query_id, '') + + cancel = self.client.sparql_cancel(query_id, False) + cancel_res = cancel.json() + + assert 'acceptedQueryCount' in cancel_res + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + assert res.status_code == 500 + raw = res.json() + assert raw['code'] == 'CancelledByUserException' + assert raw['detailedMessage'] == 'Operation terminated (cancelled by user)' + + @pytest.mark.iam + @pytest.mark.neptune + def test_do_sparql_status_and_cancel_silently(self): + + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . ?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + assert query_id != '' + + cancel = self.client.sparql_cancel(query_id, True) + cancel_res = cancel.json() + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + query_res = res.json() + assert type(query_res) is dict + assert 's3' in query_res['head']['vars'] + assert 'p3' in query_res['head']['vars'] + assert 'o3' in query_res['head']['vars'] + assert [] == query_res['results']['bindings'] diff --git a/test/integration/iam/status/__init__.py b/test/integration/iam/status/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/iam/status/test_status_with_iam.py b/test/integration/iam/status/test_status_with_iam.py new file mode 100644 index 00000000..0cd83446 --- /dev/null +++ b/test/integration/iam/status/test_status_with_iam.py @@ -0,0 +1,29 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+SPDX-License-Identifier: Apache-2.0
+"""
+import pytest
+from botocore.session import get_session
+
+from test.integration import IntegrationTest
+
+
+class TestStatusWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        super().setUp()
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.neptune
+    @pytest.mark.iam
+    def test_do_status_with_iam_credentials(self):
+        res = self.client.status()
+        assert res.status_code == 200
+        status = res.json()
+        self.assertEqual(status['status'], 'healthy')
+
+    @pytest.mark.neptune
+    @pytest.mark.iam
+    def test_do_status_without_iam_credentials(self):
+        client = self.client_builder.with_iam(None).build()
+        res = client.status()
+        assert res.status_code != 200
diff --git a/test/integration/iam/system/__init__.py b/test/integration/iam/system/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/integration/iam/system/test_system_with_iam.py b/test/integration/iam/system/test_system_with_iam.py
new file mode 100644
index 00000000..688f5834
--- /dev/null
+++ b/test/integration/iam/system/test_system_with_iam.py
@@ -0,0 +1,60 @@
+import datetime
+import time
+
+import pytest
+from botocore.session import get_session
+from test.integration import IntegrationTest
+
+
+class TestStatusWithIAM(IntegrationTest):
+    def setUp(self) -> None:
+        self.client = self.client_builder.with_iam(get_session()).build()
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_initiate_with_iam_credentials(self):
+        res = self.client.initiate_reset()
+        result = res.json()
+        self.assertNotEqual(result['payload']['token'], '')
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_perform_with_wrong_token_with_iam_credentials(self):
+        res = self.client.perform_reset('invalid')
+        assert res.status_code == 400
+
+        expected_message = "System command parameter 'token' : 'invalid' does not match database reset token"
+        assert expected_message == res.json()['detailedMessage']
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    def test_do_db_reset_initiate_without_iam_credentials(self):
+        client = self.client_builder.with_iam(None).build()
+        res = client.initiate_reset()
+        assert res.status_code == 403
+
+    @pytest.mark.iam
+    @pytest.mark.neptune
+    @pytest.mark.reset
+    def test_iam_fast_reset(self):
+        initiate_reset_res = self.client.initiate_reset()
+        assert initiate_reset_res.status_code == 200
+
+        token = initiate_reset_res.json()['payload']['token']
+        reset_res = self.client.perform_reset(token)
+        assert reset_res.json()['status'] == '200 OK'
+
+        # check for status for 5 minutes while reset is performed
+        end_time = datetime.datetime.now() + datetime.timedelta(minutes=5)
+        status = None
+        while end_time >= datetime.datetime.now():
+            try:
+                status = self.client.status()
+                if status.status_code != 200:
+                    time.sleep(5)  # wait momentarily until we obtain the status again
+                else:
+                    break
+            except Exception:
+                time.sleep(5)
+
+        assert status is not None and status.status_code == 200
diff --git a/test/integration/network/__init__.py b/test/integration/network/__init__.py
deleted file mode 100644
index 9049dd04..00000000
--- a/test/integration/network/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-SPDX-License-Identifier: Apache-2.0
-"""
diff --git a/test/integration/notebook/__init__.py b/test/integration/notebook/__init__.py
deleted file mode 100644
index 9049dd04..00000000
--- a/test/integration/notebook/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""
-Copyright Amazon.com, Inc.
or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/sparql/__init__.py b/test/integration/sparql/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/sparql/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/sparql/sparql_query_with_iam.py b/test/integration/sparql/sparql_query_with_iam.py deleted file mode 100644 index 8764b182..00000000 --- a/test/integration/sparql/sparql_query_with_iam.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import do_sparql_query, do_sparql_explain -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import IntegrationTest - - -class TestSparqlQueryWithIam(IntegrationTest): - def test_do_sparql_query(self): - query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) - self.assertEqual(type(res), dict) - self.assertTrue('s' in res['head']['vars']) - self.assertTrue('p' in res['head']['vars']) - self.assertTrue('o' in res['head']['vars']) - - def test_do_sparql_explain(self): - query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - - res = do_sparql_explain(query, self.host, self.port, self.ssl, request_generator) - self.assertEqual(type(res), str) - self.assertTrue(res.startswith('')) diff --git a/test/integration/sparql/sparql_status_with_iam.py b/test/integration/sparql/sparql_status_with_iam.py deleted file mode 100644 index 60ee3fb3..00000000 --- a/test/integration/sparql/sparql_status_with_iam.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading -import time -import requests - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.sparql.query import do_sparql_query -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.request_param_generator.factory import create_request_generator - -from test.integration import DataDrivenSparqlTest - - -class TestSparqlStatusWithIam(DataDrivenSparqlTest): - def do_sparql_query_save_result(self, query, res): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - try: - res['result'] = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) - except requests.HTTPError as exception: - res['error'] = exception.response.json() - - def setUp(self) -> None: - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator) - for q in res['queries']: - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, q['queryId'], False) - - def test_do_sparql_status_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - - def test_do_sparql_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - - def test_do_sparql_status_and_cancel(self): - query = "SELECT * WHERE { ?s ?p ?o . 
?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(1) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_sparql_status(self.host, self.port, self.ssl, self.request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - - time.sleep(1) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertEqual('CancelledByUserException', query_res['error']['code']) - - def test_do_sparql_status_and_cancel_silently(self): - query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(1) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, True) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertEqual(type(query_res['result']), dict) - self.assertTrue('s3' in query_res['result']['head']['vars']) - self.assertTrue('p3' in query_res['result']['head']['vars']) - self.assertTrue('o3' in query_res['result']['head']['vars']) - self.assertEqual([], query_res['result']['results']['bindings']) diff --git a/test/integration/sparql/sparql_status_without_iam.py b/test/integration/sparql/sparql_status_without_iam.py deleted file mode 100644 index 1d600b34..00000000 --- a/test/integration/sparql/sparql_status_without_iam.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import threading - -import logging -import time -import requests - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.sparql.query import do_sparql_query -from graph_notebook.sparql.status import do_sparql_status, do_sparql_cancel -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - -from test.integration import DataDrivenSparqlTest - -logger = logging.getLogger('TestSparqlStatusWithoutIam') - - -class TestSparqlStatusWithoutIam(DataDrivenSparqlTest): - def do_sparql_query_save_result(self, query, res): - try: - res['result'] = do_sparql_query(query, self.host, self.port, self.ssl, SPARQLRequestGenerator()) - except requests.HTTPError as exception: - res['error'] = exception.response.json() - - def setUp(self) -> None: - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator) - for q in res['queries']: - do_sparql_cancel(self.host, self.port, self.ssl, request_generator, q['queryId'], False) - - def test_do_sparql_status_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_nonexistent(self): - query_id = "ac7d5a03-00cf-4280-b464-edbcbf51ffce" - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(res), dict) - self.assertTrue('acceptedQueryCount' in res) - self.assertTrue('runningQueryCount' in res) - self.assertTrue('queries' in res) - - def test_do_sparql_cancel_empty_query_id(self): - with self.assertRaises(ValueError): - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_sparql_cancel(query_id, False, self.host, self.port, self.ssl, request_generator) - - def test_do_sparql_cancel_non_str_query_id(self): - with self.assertRaises(ValueError): - query_id = 42 - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - do_sparql_cancel(query_id, False, self.host, self.port, self.ssl, request_generator) - - def test_do_sparql_status_and_cancel(self): - query = "SELECT * WHERE { ?s ?p ?o . 
?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertEqual(1, status_res['runningQueryCount']) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertFalse('result' in query_res) - self.assertTrue('error' in query_res) - self.assertTrue('code' in query_res['error']) - self.assertTrue('requestId' in query_res['error']) - self.assertTrue('detailedMessage' in query_res['error']) - self.assertEqual('CancelledByUserException', query_res['error']['code']) - - def test_do_sparql_status_and_cancel_silently(self): - query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100" - query_res = {} - sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,)) - sparql_query_thread.start() - time.sleep(3) - - query_id = '' - request_generator = create_request_generator(AuthModeEnum.DEFAULT) - status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id) - self.assertEqual(type(status_res), dict) - self.assertTrue('acceptedQueryCount' in status_res) - self.assertTrue('runningQueryCount' in status_res) - self.assertEqual(1, status_res['runningQueryCount']) - self.assertTrue('queries' in status_res) - - query_id = '' - for q in status_res['queries']: - if query in q['queryString']: - query_id = q['queryId'] - - self.assertNotEqual(query_id, '') - - cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, True) - self.assertEqual(type(cancel_res), dict) - self.assertTrue('acceptedQueryCount' in cancel_res) - self.assertTrue('runningQueryCount' in cancel_res) - self.assertTrue('queries' in cancel_res) - - sparql_query_thread.join() - self.assertEqual(type(query_res['result']), dict) - self.assertTrue('s3' in query_res['result']['head']['vars']) - self.assertTrue('p3' in query_res['result']['head']['vars']) - self.assertTrue('o3' in query_res['result']['head']['vars']) - self.assertEqual([], query_res['result']['results']['bindings']) diff --git a/test/integration/status/__init__.py b/test/integration/status/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/integration/status/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/integration/status/status_with_iam.py b/test/integration/status/status_with_iam.py deleted file mode 100644 index 5878cede..00000000 --- a/test/integration/status/status_with_iam.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -from requests.exceptions import HTTPError - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.status.get_status import get_status - -from test.integration import IntegrationTest - - -class TestStatusWithIAM(IntegrationTest): - def test_do_status_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - status = get_status(self.host, self.port, self.ssl, request_generator) - self.assertEqual(status['status'], 'healthy') - - def test_do_status_without_iam_credentials(self): - with self.assertRaises(HTTPError): - get_status(self.host, self.port, self.ssl) diff --git a/test/integration/system/system_with_iam.py b/test/integration/system/system_with_iam.py deleted file mode 100644 index a3298630..00000000 --- a/test/integration/system/system_with_iam.py +++ /dev/null @@ -1,25 +0,0 @@ -from requests.exceptions import HTTPError - -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.system.database_reset import initiate_database_reset, perform_database_reset -from test.integration import IntegrationTest - - -class TestStatusWithIAM(IntegrationTest): - def test_do_db_reset_initiate_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - result = initiate_database_reset(self.host, self.port, self.ssl, request_generator) - self.assertNotEqual(result['payload']['token'], '') - - def test_do_db_reset_perform_with_wrong_token_with_iam_credentials(self): - request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV) - with self.assertRaises(HTTPError) as cm: - perform_database_reset('x', self.host, self.port, self.ssl, request_generator) - expected_message = "System command parameter 'token' : 'x' does not match database reset token" - self.assertEqual(expected_message, str(cm.exception.response.json()['detailedMessage'])) - - def test_do_db_reset_initiate_without_iam_credentials(self): - with self.assertRaises(HTTPError): - initiate_database_reset(self.host, self.port, self.ssl) diff --git a/test/integration/system/system_without_iam.py b/test/integration/system/system_without_iam.py deleted file mode 100644 index 566f0a12..00000000 --- a/test/integration/system/system_without_iam.py +++ /dev/null @@ -1,15 +0,0 @@ -from requests.exceptions import HTTPError -from graph_notebook.system.database_reset import initiate_database_reset, perform_database_reset -from test.integration import IntegrationTest - - -class TestStatusWithoutIAM(IntegrationTest): - def test_do_database_reset_initiate(self): - result = initiate_database_reset(self.host, self.port, self.ssl) - 
self.assertNotEqual(result['payload']['token'], '') - - def test_do_database_reset_perform_with_wrong_token(self): - with self.assertRaises(HTTPError) as cm: - perform_database_reset('x', self.host, self.port, self.ssl) - expected_message = "System command parameter 'token' : 'x' does not match database reset token" - self.assertEqual(expected_message, str(cm.exception.response.json()['detailedMessage'])) diff --git a/test/integration/without_iam/__init__.py b/test/integration/without_iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/without_iam/gremlin/__init__.py b/test/integration/without_iam/gremlin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/gremlin/bug_fix_tests_without_iam.py b/test/integration/without_iam/gremlin/test_gremlin_patching.py similarity index 71% rename from test/integration/gremlin/bug_fix_tests_without_iam.py rename to test/integration/without_iam/gremlin/test_gremlin_patching.py index 5a7c63d4..37557faa 100644 --- a/test/integration/gremlin/bug_fix_tests_without_iam.py +++ b/test/integration/without_iam/gremlin/test_gremlin_patching.py @@ -5,13 +5,11 @@ import logging -from graph_notebook.gremlin.client_provider.default_client import ClientProvider -from graph_notebook.gremlin.query import do_gremlin_query +import pytest from test.integration import IntegrationTest - -logger = logging.getLogger('TestUnhashableTypeDict') +logger = logging.getLogger('test_bug_fixes') class TestBugFixes(IntegrationTest): @@ -21,12 +19,8 @@ class TestBugFixes(IntegrationTest): is not. We ran into this in a Data Lab a while back but we worked around it there by monkey patching the Gremlin Python client. We may want to do the same for the version of Gremlin Python used by the workbench.""" - @classmethod - def setUpClass(cls): - super(TestBugFixes, cls).setUpClass() - - cls.client_provider = ClientProvider() - + def setUp(self) -> None: + self.client = self.client_builder.build() queries = [ "g.addV('Interest').property(id,'i1').property('value', 4)", "g.addV('Priority').property(id, 'p1').property('name', 'P1')", @@ -34,25 +28,19 @@ def setUpClass(cls): "g.V('m1').addE('interested').to(g.V('i1'))", "g.V('m1').addE('prioritized').to(g.V('p1'))" ] - cls.runQueries(queries) + for q in queries: + self.client.gremlin_query(q) - @classmethod - def tearDownClass(cls): + def tearDown(self) -> None: queries = [ "g.V('i1').drop()", "g.V('p1').drop()", "g.V('m1').drop()" ] - cls.runQueries(queries) - - @classmethod - def runQueries(cls, queries): - for query in queries: - try: - do_gremlin_query(query, cls.host, cls.port, cls.ssl, cls.client_provider) - except Exception as e: - logger.error(f'query {query} failed due to {e}') + for q in queries: + self.client.gremlin_query(q) + @pytest.mark.gremlin def test_do_gremlin_query_with_map_as_key(self): query = """ g.V().hasLabel("Interest").as("int") @@ -63,7 +51,7 @@ def test_do_gremlin_query_with_map_as_key(self): .by("name") .groupCount().unfold() """ - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) keys_are_hashable = True for key in results[0].keys(): try: @@ -73,13 +61,14 @@ def test_do_gremlin_query_with_map_as_key(self): break self.assertEqual(keys_are_hashable, True) + @pytest.mark.gremlin def test_do_gremlin_query_with_list_as_key(self): query = """ g.V('m1').group() .by(out().fold()) .by(out().count()) """ - results = do_gremlin_query(query, self.host, self.port, self.ssl, 
self.client_provider)
+        results = self.client.gremlin_query(query)
         keys_are_hashable = True
         for key in results[0].keys():
             try:
diff --git a/test/integration/without_iam/gremlin/test_gremlin_query.py b/test/integration/without_iam/gremlin/test_gremlin_query.py
new file mode 100644
index 00000000..79ee575b
--- /dev/null
+++ b/test/integration/without_iam/gremlin/test_gremlin_query.py
@@ -0,0 +1,35 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+import pytest
+from gremlin_python.structure.graph import Vertex
+
+from test.integration import IntegrationTest
+
+
+class TestGremlin(IntegrationTest):
+    @pytest.mark.gremlin
+    def test_do_gremlin_query(self):
+        query = 'g.V().limit(1)'
+        results = self.client.gremlin_query(query)
+        assert type(results) is list
+        for r in results:
+            assert type(r) is Vertex
+
+    @pytest.mark.gremlin
+    def test_do_gremlin_explain(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_explain(query)
+        assert res.status_code == 200
+        results = res.content.decode('utf-8')
+        self.assertTrue('Explain' in results)
+
+    @pytest.mark.gremlin
+    def test_do_gremlin_profile(self):
+        query = 'g.V().limit(1)'
+        res = self.client.gremlin_profile(query)
+        assert res.status_code == 200
+
+        results = res.content.decode('utf-8')
+        self.assertTrue('Profile' in results)
diff --git a/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py b/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py
new file mode 100644
index 00000000..37337ee0
--- /dev/null
+++ b/test/integration/without_iam/gremlin/test_gremlin_status_without_iam.py
@@ -0,0 +1,115 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+
+import threading
+import time
+
+import pytest
+import concurrent.futures
+from os import cpu_count
+
+from gremlin_python.driver.protocol import GremlinServerError
+
+from graph_notebook.neptune.client import Client
+from test.integration import DataDrivenGremlinTest
+
+
+def long_running_gremlin_query(c: Client, query: str):
+    res = c.gremlin_query(query)
+    return res
+
+
+class TestGremlinStatusWithoutIam(DataDrivenGremlinTest):
+    @pytest.mark.neptune
+    def test_do_gremlin_status_nonexistent(self):
+        query_id = "some-guid-here"
+        res = self.client.gremlin_status(query_id)
+        assert res.status_code == 400
+        js = res.json()
+        assert js['code'] == 'InvalidParameterException'
+        assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid'
+
+    @pytest.mark.neptune
+    def test_do_gremlin_cancel_nonexistent(self):
+        query_id = "some-guid-here"
+        res = self.client.gremlin_cancel(query_id)
+        assert res.status_code == 400
+        js = res.json()
+        assert js['code'] == 'InvalidParameterException'
+        assert js['detailedMessage'] == f'Supplied queryId {query_id} is invalid'
+
+    @pytest.mark.neptune
+    def test_do_gremlin_cancel_empty_query_id(self):
+        with self.assertRaises(ValueError):
+            self.client.gremlin_cancel('')
+
+    @pytest.mark.neptune
+    def test_do_gremlin_cancel_non_str_query_id(self):
+        with self.assertRaises(ValueError):
+            self.client.gremlin_cancel(42)
+
+    @pytest.mark.neptune
+    def test_do_gremlin_status_and_cancel(self):
+        long_running_query = "g.V().out().out().out().out().out().out().out().out()"
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(long_running_gremlin_query, self.client, long_running_query)
+
+            time.sleep(1)
+            status_res = self.client.gremlin_status()
+            assert status_res.status_code == 200
+
+            status_js = status_res.json()
+            query_id = ''
+            for q in status_js['queries']:
+                if q['queryString'] == long_running_query:
+                    query_id = q['queryId']
+
+            assert query_id != ''
+
+            cancel_res = self.client.gremlin_cancel(query_id)
+            assert cancel_res.status_code == 200
+            assert cancel_res.json()['status'] == '200 OK'
+
+            time.sleep(1)
+            status_after_cancel = self.client.gremlin_status(query_id)
+            assert status_after_cancel.status_code == 400  # check that the query is no longer valid
+            assert status_after_cancel.json()['code'] == 'InvalidParameterException'
+
+            with pytest.raises(GremlinServerError):
+                # this result corresponds to the cancelled query, so our gremlin client will raise an exception
+                future.result()
+
+    @pytest.mark.gremlin
+    @pytest.mark.neptune
+    def test_do_gremlin_status_include_waiting(self):
+        query = "g.V().out().out().out().out()"
+        num_threads = (cpu_count() or 2) * 4  # cpu_count() can return None; fall back to a small default
+        threads = []
+        for x in range(0, num_threads):
+            thread = threading.Thread(target=long_running_gremlin_query, args=(self.client, query))
+            thread.start()
+            threads.append(thread)
+
+        time.sleep(5)
+
+        res = self.client.gremlin_status(include_waiting=True)
+        assert res.status_code == 200
+        status_res = res.json()
+
+        self.assertEqual(type(status_res), dict)
+        self.assertTrue('acceptedQueryCount' in status_res)
+        self.assertTrue('runningQueryCount' in status_res)
+        self.assertTrue('queries' in status_res)
+        self.assertEqual(status_res['acceptedQueryCount'], len(status_res['queries']))
+
+        for q in status_res['queries']:
+            # cancel all the queries we executed since they can take a very long time.
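+            # (no assertion on the cancel responses here: some of the queries may
+            # already have finished on their own before the cancel request arrives)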
+ if q['queryString'] == query: + self.client.gremlin_cancel(q['queryId']) + + for t in threads: + t.join() diff --git a/src/graph_notebook/gremlin/client_provider/__init__.py b/test/integration/without_iam/network/__init__.py similarity index 100% rename from src/graph_notebook/gremlin/client_provider/__init__.py rename to test/integration/without_iam/network/__init__.py diff --git a/src/graph_notebook/loader/__init__.py b/test/integration/without_iam/network/gremlin/__init__.py similarity index 100% rename from src/graph_notebook/loader/__init__.py rename to test/integration/without_iam/network/gremlin/__init__.py diff --git a/test/integration/network/gremlin/gremlin_network_from_queries.py b/test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py similarity index 81% rename from test/integration/network/gremlin/gremlin_network_from_queries.py rename to test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py index e96f08f9..51b81748 100644 --- a/test/integration/network/gremlin/gremlin_network_from_queries.py +++ b/test/integration/without_iam/network/gremlin/test_gremlin_network_from_queries.py @@ -2,18 +2,19 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest -from graph_notebook.gremlin.query import do_gremlin_query from graph_notebook.network.gremlin.GremlinNetwork import GremlinNetwork from test.integration import DataDrivenGremlinTest class TestGremlinNetwork(DataDrivenGremlinTest): + + @pytest.mark.gremlin def test_add_paths_to_network(self): airports_path_query = "g.V().has('code', 'SEA').outE().inV().path()" - results = do_gremlin_query(airports_path_query, self.host, self.port, self.ssl, self.client_provider) - + results = self.client.gremlin_query(airports_path_query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) sea_code = '22' @@ -23,10 +24,10 @@ def test_add_paths_to_network(self): actual_label = gremlin_network.graph[sea_code][aus_code][edge_id]['label'] self.assertEqual(expected_label, actual_label) + @pytest.mark.gremlin def test_add_value_map_to_network(self): airports_path_query = "g.V().has('code', 'SEA').outE().inV().path().by(valueMap(true))" - results = do_gremlin_query(airports_path_query, self.host, self.port, self.ssl, self.client_provider) - + results = self.client.gremlin_query(airports_path_query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge_id = '4406' @@ -34,9 +35,10 @@ def test_add_value_map_to_network(self): actual_label = gremlin_network.graph.nodes.get(edge_id)['label'] self.assertEqual(expected_label, actual_label) + @pytest.mark.gremlin def test_add_entire_path(self): sea_to_bmi = "g.V().has('code', 'SEA').outE().inV().has('code', 'ORD').outE().inV().has('code', 'BMI').path()" - results = do_gremlin_query(sea_to_bmi, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(sea_to_bmi) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) @@ -48,9 +50,10 @@ def test_add_entire_path(self): self.assertTrue(gremlin_network.graph.has_edge('22', '18', '4420')) self.assertTrue(gremlin_network.graph.has_edge('18', '359', '7126')) + @pytest.mark.gremlin def test_add_paths_with_bad_pattern(self): query = "g.V().out().out().path().limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() 
gremlin_network.add_results(results) @@ -61,17 +64,19 @@ def test_add_paths_with_bad_pattern(self): self.assertEqual('', edge['label']) self.assertFalse(edge['arrows']['to']['enabled']) + @pytest.mark.gremlin def test_add_path_with_repeat(self): query = "g.V().has('airport', 'code', 'ANC').repeat(outE().inV().simplePath()).times(2).path().by('code').by()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) self.assertEqual('route', gremlin_network.graph.edges[('ANC', 'BLI', '5276')]['label']) + @pytest.mark.gremlin def test_valuemap_without_ids(self): query = "g.V().has('code', 'ANC').out().path().by(valueMap()).limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) @@ -79,25 +84,28 @@ def test_valuemap_without_ids(self): node = gremlin_network.graph.nodes.get(n) self.assertEqual(gremlin_network.label_max_length, len(node['label'])) + @pytest.mark.gremlin def test_path_without_by_nodes_have_ids(self): query = "g.V().has('code', 'AUS').outE().inV().outE().inV().has('code', 'SEA').path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) node = gremlin_network.graph.nodes.get('9') self.assertIsNotNone(node) + @pytest.mark.gremlin def test_path_without_by_oute_has_arrows(self): query = "g.V().hasLabel('airport').has('code', 'SEA').outE().inV().path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge = gremlin_network.graph.edges[('22', '151', '7389')] self.assertTrue('arrows' not in edge) + @pytest.mark.gremlin def test_path_without_by_ine_has_arrows(self): query = "g.V().hasLabel('airport').has('code', 'SEA').inE().outV().path()" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gremlin_network = GremlinNetwork() gremlin_network.add_results(results) edge = gremlin_network.graph.edges[('3670', '22', '53637')] diff --git a/test/integration/network/gremlin/gremlin_network_with_pattern.py b/test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py similarity index 83% rename from test/integration/network/gremlin/gremlin_network_with_pattern.py rename to test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py index 43f3fce0..1a154f88 100644 --- a/test/integration/network/gremlin/gremlin_network_with_pattern.py +++ b/test/integration/without_iam/network/gremlin/test_gremlin_network_with_pattern.py @@ -3,7 +3,6 @@ SPDX-License-Identifier: Apache-2.0 """ -from graph_notebook.gremlin.query import do_gremlin_query from graph_notebook.network.gremlin.GremlinNetwork import GremlinNetwork, PathPattern from test.integration import DataDrivenGremlinTest @@ -12,7 +11,7 @@ class TestGremlinNetwork(DataDrivenGremlinTest): def test_add_path_with_edge_object(self): query = "g.V().has('airport','code','AUS').outE().inV().path().by('code').by().limit(10)" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = 
self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_E, PathPattern.IN_V] gn.add_results_with_pattern(results, pattern) @@ -27,7 +26,7 @@ def test_add_path_by_dist(self): path(). by('code'). by('dist')""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_E, PathPattern.IN_V, PathPattern.OUT_E] gn.add_results_with_pattern(results, pattern) @@ -41,7 +40,7 @@ def test_path_with_dict(self): by(valueMap('code','city','region','desc','lat','lon'). order(local). by(keys))""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.IN_V] gn.add_results_with_pattern(results, pattern) @@ -55,7 +54,7 @@ def test_out_v_unhashable_dict(self): out(). path(). by(valueMap())""" - results = do_gremlin_query(query, self.host, self.port, self.ssl, self.client_provider) + results = self.client.gremlin_query(query) gn = GremlinNetwork() pattern = [PathPattern.V, PathPattern.OUT_V] gn.add_results_with_pattern(results, pattern) diff --git a/test/integration/notebook/GraphNotebookIntegrationTest.py b/test/integration/without_iam/notebook/GraphNotebookIntegrationTest.py similarity index 100% rename from test/integration/notebook/GraphNotebookIntegrationTest.py rename to test/integration/without_iam/notebook/GraphNotebookIntegrationTest.py diff --git a/src/graph_notebook/request_param_generator/__init__.py b/test/integration/without_iam/notebook/__init__.py similarity index 100% rename from src/graph_notebook/request_param_generator/__init__.py rename to test/integration/without_iam/notebook/__init__.py diff --git a/test/integration/notebook/test_gremlin_graph_notebook.py b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py similarity index 85% rename from test/integration/notebook/test_gremlin_graph_notebook.py rename to test/integration/without_iam/notebook/test_gremlin_graph_notebook.py index c8ed1f92..2cac2548 100644 --- a/test/integration/notebook/test_gremlin_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py @@ -2,8 +2,9 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration.without_iam.notebook import GraphNotebookIntegrationTest class TestGraphMagicGremlin(GraphNotebookIntegrationTest): @@ -11,6 +12,8 @@ def tearDown(self) -> None: delete_query = "g.V('graph-notebook-test').drop()" self.ip.run_cell_magic('gremlin', 'query', delete_query) + @pytest.mark.jupyter + @pytest.mark.gremlin def test_gremlin_query(self): label = 'graph-notebook-test' query = f"g.addV('{label}')" diff --git a/test/integration/notebook/test_sparql_graph_notebook.py b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py similarity index 64% rename from test/integration/notebook/test_sparql_graph_notebook.py rename to test/integration/without_iam/notebook/test_sparql_graph_notebook.py index a71c1935..134115b3 100644 --- a/test/integration/notebook/test_sparql_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py @@ -2,12 +2,26 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
SPDX-License-Identifier: Apache-2.0 """ +import pytest -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration.without_iam.notebook import GraphNotebookIntegrationTest class TestGraphMagicGremlin(GraphNotebookIntegrationTest): + + @pytest.mark.jupyter + @pytest.mark.sparql def test_sparql_query(self): + query = 'SELECT * WHERE {?s ?o ?p } LIMIT 1' + store_to_var = 'sparql_res' + self.ip.run_cell_magic('sparql', f'--store-to {store_to_var}', query) + self.assertFalse('graph_notebook_error' in self.ip.user_ns) + sparql_res = self.ip.user_ns[store_to_var] + self.assertEqual(['s', 'o','p'], sparql_res['head']['vars']) + + @pytest.mark.jupyter + @pytest.mark.sparql + def test_sparql_query_explain(self): query = 'SELECT * WHERE {?s ?o ?p } LIMIT 1' store_to_var = 'sparql_res' self.ip.run_cell_magic('sparql', f'explain --store-to {store_to_var}', query) @@ -16,6 +30,7 @@ def test_sparql_query(self): self.assertTrue(sparql_res.startswith('')) self.assertTrue('' in sparql_res) + @pytest.mark.jupyter def test_load_sparql_config(self): config = '''{ "host": "localhost", diff --git a/test/integration/notebook/test_status_graph_notebook.py b/test/integration/without_iam/notebook/test_status_graph_notebook.py similarity index 51% rename from test/integration/notebook/test_status_graph_notebook.py rename to test/integration/without_iam/notebook/test_status_graph_notebook.py index 1b280f64..2af95908 100644 --- a/test/integration/notebook/test_status_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_status_graph_notebook.py @@ -2,11 +2,18 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ +import pytest +from botocore.session import get_session -from test.integration.notebook.GraphNotebookIntegrationTest import GraphNotebookIntegrationTest +from test.integration.without_iam.notebook import GraphNotebookIntegrationTest class TestGraphMagicStatus(GraphNotebookIntegrationTest): + def setUp(self) -> None: + self.client = self.client_builder.with_iam(get_session()).build() + + @pytest.mark.jupyter + @pytest.mark.neptune def test_status(self): res = self.ip.run_line_magic('status', '') self.assertEqual('healthy', res['status']) diff --git a/src/graph_notebook/sparql/__init__.py b/test/integration/without_iam/sparql/__init__.py similarity index 100% rename from src/graph_notebook/sparql/__init__.py rename to test/integration/without_iam/sparql/__init__.py diff --git a/test/integration/sparql/sparql_query_without_iam.py b/test/integration/without_iam/sparql/test_sparql_query_without_iam.py similarity index 59% rename from test/integration/sparql/sparql_query_without_iam.py rename to test/integration/without_iam/sparql/test_sparql_query_without_iam.py index c1bccc27..58cf3045 100644 --- a/test/integration/sparql/sparql_query_without_iam.py +++ b/test/integration/without_iam/sparql/test_sparql_query_without_iam.py @@ -2,27 +2,30 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
SPDX-License-Identifier: Apache-2.0 """ - -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator -from graph_notebook.sparql.query import do_sparql_query, do_sparql_explain +import pytest from test.integration import IntegrationTest class TestSparqlQuery(IntegrationTest): + @pytest.mark.sparql def test_do_sparql_query(self): query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = SPARQLRequestGenerator() - res = do_sparql_query(query, self.host, self.port, self.ssl, request_generator) + sparql_res = self.client.sparql(query) + assert sparql_res.status_code == 200 + res = sparql_res.json() + self.assertEqual(type(res), dict) self.assertTrue('s' in res['head']['vars']) self.assertTrue('p' in res['head']['vars']) self.assertTrue('o' in res['head']['vars']) + @pytest.mark.sparql def test_do_sparql_explain(self): query = "SELECT * WHERE {?s ?p ?o} LIMIT 1" - request_generator = SPARQLRequestGenerator() - res = do_sparql_explain(query, self.host, self.port, self.ssl, request_generator) + query_res = self.client.sparql_explain(query) + assert query_res.status_code == 200 + res = query_res.content.decode('utf-8') self.assertEqual(type(res), str) self.assertTrue(res.startswith('')) diff --git a/test/integration/without_iam/sparql/test_sparql_status_without_iam.py b/test/integration/without_iam/sparql/test_sparql_status_without_iam.py new file mode 100644 index 00000000..9ee39c99 --- /dev/null +++ b/test/integration/without_iam/sparql/test_sparql_status_without_iam.py @@ -0,0 +1,118 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import concurrent.futures + +import logging +import time +import pytest + +from graph_notebook.neptune.client import Client +from test.integration import DataDrivenSparqlTest + +logger = logging.getLogger('TestSparqlStatusWithoutIam') + + +def long_running_sparql_query(c: Client, query: str): + res = c.sparql(query) + return res + + +class TestSparqlStatusWithoutIam(DataDrivenSparqlTest): + @pytest.mark.neptune + def test_do_sparql_status_nonexistent(self): + query_id = "invalid-guid" + status_res = self.client.sparql_status(query_id) + assert status_res.status_code == 200 + assert status_res.content == b'' + + @pytest.mark.neptune + def test_do_sparql_cancel_nonexistent(self): + query_id = "invalid-guid" + cancel_res = self.client.sparql_cancel(query_id) + assert cancel_res.status_code == 200 + assert cancel_res.content == b'' + + @pytest.mark.neptune + def test_do_sparql_cancel_empty_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel('') + + @pytest.mark.neptune + def test_do_sparql_cancel_non_str_query_id(self): + with pytest.raises(ValueError): + self.client.sparql_cancel(42) + + @pytest.mark.neptune + def test_do_sparql_status_and_cancel(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . 
?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + self.assertNotEqual(query_id, '') + + cancel = self.client.sparql_cancel(query_id, False) + cancel_res = cancel.json() + + assert 'acceptedQueryCount' in cancel_res + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + assert res.status_code == 500 + raw = res.json() + assert raw['code'] == 'CancelledByUserException' + assert raw['detailedMessage'] == 'Operation terminated (cancelled by user)' + + @pytest.mark.neptune + def test_do_sparql_status_and_cancel_silently(self): + query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 . ?s4 ?s5 ?s6 .} ORDER BY DESC(?s)" + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(long_running_sparql_query, self.client, query) + time.sleep(1) + + status = self.client.sparql_status() + status_res = status.json() + assert 'acceptedQueryCount' in status_res + assert 'runningQueryCount' in status_res + assert 'queries' in status_res + + time.sleep(1) + + query_id = '' + for q in status_res['queries']: + if query in q['queryString']: + query_id = q['queryId'] + + assert query_id != '' + + cancel = self.client.sparql_cancel(query_id, True) + cancel_res = cancel.json() + assert 'acceptedQueryCount' in cancel_res + assert 'runningQueryCount' in cancel_res + assert 'queries' in cancel_res + + res = future.result() + query_res = res.json() + assert type(query_res) is dict + assert 's3' in query_res['head']['vars'] + assert 'p3' in query_res['head']['vars'] + assert 'o3' in query_res['head']['vars'] + assert [] == query_res['results']['bindings'] diff --git a/src/graph_notebook/status/__init__.py b/test/integration/without_iam/status/__init__.py similarity index 100% rename from src/graph_notebook/status/__init__.py rename to test/integration/without_iam/status/__init__.py diff --git a/test/integration/status/status_without_iam.py b/test/integration/without_iam/status/test_status_without_iam.py similarity index 77% rename from test/integration/status/status_without_iam.py rename to test/integration/without_iam/status/test_status_without_iam.py index d72ce1d3..58b69c4b 100644 --- a/test/integration/status/status_without_iam.py +++ b/test/integration/without_iam/status/test_status_without_iam.py @@ -2,6 +2,7 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
SPDX-License-Identifier: Apache-2.0 """ +import pytest from graph_notebook.status.get_status import get_status @@ -9,6 +10,9 @@ class TestStatusWithoutIAM(IntegrationTest): + + @pytest.mark.neptune def test_do_status(self): - status = get_status(self.host, self.port, self.ssl) + res = self.client.status() + status = res.json() self.assertEqual(status['status'], 'healthy') diff --git a/test/integration/without_iam/system/__init__.py b/test/integration/without_iam/system/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/without_iam/system/test_system_without_iam.py b/test/integration/without_iam/system/test_system_without_iam.py new file mode 100644 index 00000000..c171e96a --- /dev/null +++ b/test/integration/without_iam/system/test_system_without_iam.py @@ -0,0 +1,18 @@ +import pytest +from test.integration import IntegrationTest + + +class TestStatusWithoutIAM(IntegrationTest): + + @pytest.mark.neptune + def test_do_database_reset_initiate(self): + res = self.client.initiate_reset() + result = res.json() + self.assertNotEqual(result['payload']['token'], '') + + @pytest.mark.neptune + def test_do_database_reset_perform_with_wrong_token(self): + res = self.client.perform_reset('invalid') + assert res.status_code == 400 + expected_message = "System command parameter 'token' : 'invalid' does not match database reset token" + assert expected_message == res.json()['detailedMessage'] diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index a845ce4a..0bbb4d50 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -7,9 +7,7 @@ import unittest from graph_notebook.configuration.get_config import get_config -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, \ - DEFAULT_IAM_CREDENTIALS_PROVIDER, AuthModeEnum, generate_config +from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, AuthModeEnum, generate_config class TestGenerateConfiguration(unittest.TestCase): @@ -28,25 +26,21 @@ def test_configuration_default_auth_defaults(self): self.assertEqual(self.host, config.host) self.assertEqual(self.port, config.port) self.assertEqual(DEFAULT_AUTH_MODE, config.auth_mode) - self.assertEqual(DEFAULT_IAM_CREDENTIALS_PROVIDER, config.iam_credentials_provider_type) self.assertEqual(True, config.ssl) self.assertEqual('', config.load_from_s3_arn) def test_configuration_override_defaults(self): auth_mode = AuthModeEnum.IAM - credentials_provider = IAMAuthCredentialsProvider.ENV ssl = False loader_arn = 'foo' - config = Configuration(self.host, self.port, auth_mode, credentials_provider, loader_arn, ssl) + config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl) self.assertEqual(auth_mode, config.auth_mode) - self.assertEqual(credentials_provider, config.iam_credentials_provider_type) self.assertEqual(ssl, config.ssl) self.assertEqual(loader_arn, config.load_from_s3_arn) def test_generate_configuration_with_defaults(self): config = Configuration(self.host, self.port) c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.iam_credentials_provider_type, config.load_from_s3_arn, config.aws_region) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) @@ -54,14 +48,12 @@ def 
test_generate_configuration_with_defaults(self): def test_generate_configuration_override_defaults(self): auth_mode = AuthModeEnum.IAM - credentials_provider = IAMAuthCredentialsProvider.ENV ssl = False loader_arn = 'foo' aws_region = 'us-west-2' - config = Configuration(self.host, self.port, auth_mode, credentials_provider, loader_arn, ssl, aws_region) + config = Configuration(self.host, self.port, auth_mode, loader_arn, ssl, aws_region) c = generate_config(config.host, config.port, config.auth_mode, config.ssl, - config.iam_credentials_provider_type, config.load_from_s3_arn, config.aws_region) c.write_to_file(self.test_file_path) config_from_file = get_config(self.test_file_path) diff --git a/test/unit/configuration/test_configuration_from_main.py b/test/unit/configuration/test_configuration_from_main.py index 5b2d6452..ebfd604c 100644 --- a/test/unit/configuration/test_configuration_from_main.py +++ b/test/unit/configuration/test_configuration_from_main.py @@ -8,7 +8,6 @@ from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration from graph_notebook.configuration.get_config import get_config -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider class TestGenerateConfigurationMain(unittest.TestCase): @@ -24,11 +23,11 @@ def tearDown(self) -> None: os.remove(self.test_file_path) def test_generate_configuration_main_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, IAMAuthCredentialsProvider.ROLE, '', True) + expected_config = Configuration(self.host, self.port, AuthModeEnum.DEFAULT, '', True) self.generate_config_from_main_and_test(expected_config) def test_generate_configuration_main_override_defaults(self): - expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, IAMAuthCredentialsProvider.ROLE, 'loader_arn', False) + expected_config = Configuration(self.host, self.port, AuthModeEnum.IAM, 'loader_arn', False) self.generate_config_from_main_and_test(expected_config) def test_generate_configuration_main_empty_args(self): @@ -42,7 +41,7 @@ def generate_config_from_main_and_test(self, source_config: Configuration): # This will run the main method that our install script runs on a Sagemaker notebook. # The return code should be 0, but more importantly, we need to assert that the # Configuration object we get from the resulting file is what we expect. - result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --iam_credentials_provider "{source_config.iam_credentials_provider_type.value}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ') + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config --host "{source_config.host}" --port "{source_config.port}" --auth_mode "{source_config.auth_mode.value}" --ssl "{source_config.ssl}" --load_from_s3_arn "{source_config.load_from_s3_arn}" --config_destination="{self.test_file_path}" ') self.assertEqual(result, 0) config = get_config(self.test_file_path) self.assertEqual(source_config.to_dict(), config.to_dict()) diff --git a/test/unit/gremlin/__init__.py b/test/unit/gremlin/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/unit/gremlin/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/unit/request_param_generator/__init__.py b/test/unit/request_param_generator/__init__.py deleted file mode 100644 index 9049dd04..00000000 --- a/test/unit/request_param_generator/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" diff --git a/test/unit/request_param_generator/test_default_request_generator.py b/test/unit/request_param_generator/test_default_request_generator.py deleted file mode 100644 index e948d2a6..00000000 --- a/test/unit/request_param_generator/test_default_request_generator.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator - - -class TestDefaultRequestGenerator(unittest.TestCase): - def test_generate_request_params(self): - method = 'post' - action = 'foo' - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - headers = { - 'header1': 'header_value_1' - } - - rpg = DefaultRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers) - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], headers) - self.assertEqual(request_params['params'], query) diff --git a/test/unit/request_param_generator/test_factory_generator.py b/test/unit/request_param_generator/test_factory_generator.py deleted file mode 100644 index 6878cb6d..00000000 --- a/test/unit/request_param_generator/test_factory_generator.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.configuration.generate_config import AuthModeEnum -from graph_notebook.authentication.iam_credentials_provider.credentials_factory import IAMAuthCredentialsProvider -from graph_notebook.authentication.iam_credentials_provider.env_credentials_provider import EnvCredentialsProvider -from graph_notebook.request_param_generator.default_request_generator import DefaultRequestGenerator -from graph_notebook.request_param_generator.factory import create_request_generator -from graph_notebook.request_param_generator.iam_request_generator import IamRequestGenerator -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - - -class TestRequestParamGeneratorFactory(unittest.TestCase): - def test_create_request_generator_sparql(self): - mode = AuthModeEnum.DEFAULT - command = 'sparql' - rpg = create_request_generator(mode, command=command) - self.assertEqual(SPARQLRequestGenerator, type(rpg)) - - def test_create_request_generator_default(self): - mode = AuthModeEnum.DEFAULT - rpg = create_request_generator(mode) - self.assertEqual(DefaultRequestGenerator, type(rpg)) - - def test_create_request_generator_iam_env(self): - mode = AuthModeEnum.IAM - rpg = create_request_generator(mode, IAMAuthCredentialsProvider.ENV) - self.assertEqual(IamRequestGenerator, type(rpg)) - self.assertEqual(EnvCredentialsProvider, type(rpg.credentials_provider)) diff --git a/test/unit/request_param_generator/test_sparql_request_generator.py b/test/unit/request_param_generator/test_sparql_request_generator.py deleted file mode 100644 index 3584a7c5..00000000 --- a/test/unit/request_param_generator/test_sparql_request_generator.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-SPDX-License-Identifier: Apache-2.0 -""" - -import unittest - -from graph_notebook.request_param_generator.sparql_request_generator import SPARQLRequestGenerator - - -class TestSparqlRequestGenerator(unittest.TestCase): - def test_generate_request_params(self): - method = 'post' - action = 'foo' # action is a no-op since we know it is sparql - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - headers = { - 'header1': 'header_value_1' - } - - rpg = SPARQLRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers) - expected_headers = { - 'header1': 'header_value_1', - 'Content-Type': 'application/x-www-form-urlencoded' - } - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], expected_headers) - self.assertEqual(request_params['params'], query) - - def test_generate_request_params_no_headers(self): - method = 'post' - action = 'foo' # action is a no-op since we know it is sparql - query = { - 'bar': 'baz' - } - host = 'host_endpoint' - port = 8182 - protocol = 'https' - - rpg = SPARQLRequestGenerator() - request_params = rpg.generate_request_params(method, action, query, host, port, protocol, headers=None) - expected_headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - - expected_url = f'{protocol}://{host}:{port}/{action}' - self.assertEqual(request_params['method'], method) - self.assertEqual(request_params['url'], expected_url) - self.assertEqual(request_params['headers'], expected_headers) - self.assertEqual(request_params['params'], query) diff --git a/test/unit/sparql/test_sparql.py b/test/unit/sparql/test_sparql.py index 03d2a4b2..a4bb97cc 100644 --- a/test/unit/sparql/test_sparql.py +++ b/test/unit/sparql/test_sparql.py @@ -5,7 +5,7 @@ import unittest -from graph_notebook.sparql.query import get_query_type, query_type_to_action +from graph_notebook.magics.graph_magic import get_query_type, query_type_to_action class TestSparql(unittest.TestCase): From 01ed4213ddcc6393f12290ebf892e259633376c8 Mon Sep 17 00:00:00 2001 From: Kline Date: Wed, 31 Mar 2021 14:09:30 -0700 Subject: [PATCH 02/16] revert change to setup.py email --- setup.py | 4 ++-- test/integration/IntegrationTest.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index ee43ee23..1fd95dda 100644 --- a/setup.py +++ b/setup.py @@ -55,8 +55,8 @@ def get_version(): setup( name='graph-notebook', - author='amazon-ml', - author_email='amazon-ml-pypi@amazon.com', + author='amazon-neptune', + author_email='amazon-neptune-pypi@amazon.com', description='jupyter notebook extension to connect to graph databases', long_description=long_description, long_description_content_type='text/markdown', diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py index 7f04b80d..058c3bc7 100644 --- a/test/integration/IntegrationTest.py +++ b/test/integration/IntegrationTest.py @@ -9,7 +9,7 @@ from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.configuration.get_config import get_config -from graph_notebook.neptune.client import Client, ClientBuilder +from graph_notebook.neptune.client import ClientBuilder from test.integration.NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH From 72a41ffe6737e7d3d070781562e346a80cfd70aa Mon Sep 17 00:00:00 2001 
From: Kline Date: Wed, 31 Mar 2021 14:09:30 -0700 Subject: [PATCH 03/16] revert change to setup.py email --- setup.py | 2 +- src/graph_notebook/nbextensions/neptune_menu/static/main.js | 6 +++--- test/integration/NeptuneIntegrationWorkflowSteps.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 1fd95dda..bb08f9c5 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def get_version(): 'Programming Language :: Python :: 3.6', 'License :: OSI Approved :: Apache Software License' ], - keywords='jupyter ml gremlin sparql', + keywords='jupyter neptune gremlin sparql', tests_require=[ 'pytest' ] diff --git a/src/graph_notebook/nbextensions/neptune_menu/static/main.js b/src/graph_notebook/nbextensions/neptune_menu/static/main.js index 7b08ee84..52ba48fe 100644 --- a/src/graph_notebook/nbextensions/neptune_menu/static/main.js +++ b/src/graph_notebook/nbextensions/neptune_menu/static/main.js @@ -105,7 +105,7 @@ define([ let host='', port='', iam='off', ssl='on'; function configCallback(data){ - console.log('ml menu callback...'); + console.log('neptune menu callback...'); const raw = data['content']['text']; const config = JSON.parse(raw); host = config['host']; @@ -127,10 +127,10 @@ define([ const createMenu = (kernel) => { console.log("kernel type is ", kernel.name); if(kernel.name === "gremlin_kernel" || kernel.name === "sparql_kernel"){ - console.log("skipping ml menu creation"); + console.log("skipping neptune menu creation"); return; } - console.log('creating ml menu from config...'); + console.log('creating neptune menu from config...'); kernel.execute('%load_ext graph_notebook.magics'); kernel.execute( "%graph_notebook_config silent", diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index a91f8f19..076c49d6 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -62,7 +62,7 @@ def set_iam_auth_on_neptune_cluster(cluster_identifier: str, iam_value: bool, ne response = neptune_client.modify_db_cluster(DBClusterIdentifier=cluster_identifier, EnableIAMDatabaseAuthentication=iam_value, ApplyImmediately=True) - logging.info(f'modified ml cluster {cluster_identifier} to set iam auth to {iam_value}: {response}') + logging.info(f'modified neptune cluster {cluster_identifier} to set iam auth to {iam_value}: {response}') # wait for authentication setting to show as changed: while cluster['IAMDatabaseAuthenticationEnabled'] != iam_value: From 8f830bb153f4a5bd10ddc2319d92b088626391df Mon Sep 17 00:00:00 2001 From: Kline Date: Thu, 1 Apr 2021 10:38:45 -0700 Subject: [PATCH 04/16] remove references to iam_credentials_provider_type --- src/graph_notebook/magics/graph_magic.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 478f0c3e..26af3270 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -441,8 +441,6 @@ def gremlin_status(self, line='', local_ns: dict = None): @display_exceptions def status(self, line): logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}') - logger.info( - f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make status request') status_res = self.client.status() status_res.raise_for_status() res = status_res.json() @@ -460,8 
+458,6 @@ def db_reset(self, line): args = parser.parse_args(line.split()) generate_token = args.generate_token skip_prompt = args.yes - logger.info( - f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request') if generate_token is False and args.token == '': if skip_prompt: initiate_res = self.client.initiate_reset() From 0a91da60db84b87af1598125851814efdeb59146 Mon Sep 17 00:00:00 2001 From: Kline Date: Thu, 1 Apr 2021 11:03:43 -0700 Subject: [PATCH 05/16] address flake8 --- test/integration/__init__.py | 2 +- test/integration/iam/__init__.py | 1 - test/integration/iam/load/test_load_with_iam.py | 2 ++ .../iam/ml/test_neptune_client_with_iam.py | 13 +------------ test/integration/iam/system/test_system_with_iam.py | 2 +- .../notebook/test_sparql_graph_notebook.py | 2 +- .../without_iam/status/test_status_without_iam.py | 2 -- 7 files changed, 6 insertions(+), 18 deletions(-) diff --git a/test/integration/__init__.py b/test/integration/__init__.py index 34da10ca..d0f15c29 100644 --- a/test/integration/__init__.py +++ b/test/integration/__init__.py @@ -6,4 +6,4 @@ from .IntegrationTest import IntegrationTest # noqa F401 from .DataDrivenGremlinTest import DataDrivenGremlinTest # noqa F401 from .DataDrivenSparqlTest import DataDrivenSparqlTest # noqa F401 -from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH +from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH # noqa F401 diff --git a/test/integration/iam/__init__.py b/test/integration/iam/__init__.py index 023f3c82..9049dd04 100644 --- a/test/integration/iam/__init__.py +++ b/test/integration/iam/__init__.py @@ -2,4 +2,3 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 """ - diff --git a/test/integration/iam/load/test_load_with_iam.py b/test/integration/iam/load/test_load_with_iam.py index b34cbb48..102cee66 100644 --- a/test/integration/iam/load/test_load_with_iam.py +++ b/test/integration/iam/load/test_load_with_iam.py @@ -16,6 +16,7 @@ def setUp(self) -> None: assert self.config.load_from_s3_arn != '' self.client = self.client_builder.with_iam(get_session()).build() + @pytest.mark.neptune def test_iam_load(self): load_format = 'turtle' source = TEST_BULKLOAD_SOURCE % (self.config.aws_region, 'turtle') @@ -49,6 +50,7 @@ def test_iam_load(self): cancelled_status = res.json() assert 'LOAD_CANCELLED_BY_USER' in cancelled_status['payload']['feedCount'][-1] + @pytest.mark.neptune def test_iam_load_status(self): res = self.client.load_status() # This should only give a list of load ids assert res.status_code == 200 diff --git a/test/integration/iam/ml/test_neptune_client_with_iam.py b/test/integration/iam/ml/test_neptune_client_with_iam.py index 26d7d3ae..d5af29ce 100644 --- a/test/integration/iam/ml/test_neptune_client_with_iam.py +++ b/test/integration/iam/ml/test_neptune_client_with_iam.py @@ -2,19 +2,9 @@ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
SPDX-License-Identifier: Apache-2.0 """ -import datetime -import concurrent.futures -import json -import os -import time - -import pytest -from botocore.session import get_session -from gremlin_python.structure.graph import Edge from graph_notebook.configuration.generate_config import Configuration -from graph_notebook.configuration.get_config import get_config -from graph_notebook.neptune.client import ClientBuilder, Client +from graph_notebook.neptune.client import Client client: Client config: Configuration @@ -22,4 +12,3 @@ TEST_BULKLOAD_SOURCE = 's3://aws-ml-customer-samples-%s/bulkload-datasets/%s/airroutes/v01' GREMLIN_TEST_LABEL = 'graph-notebook-test' SPARQL_TEST_PREDICATE = '' - diff --git a/test/integration/iam/system/test_system_with_iam.py b/test/integration/iam/system/test_system_with_iam.py index 688f5834..a176f404 100644 --- a/test/integration/iam/system/test_system_with_iam.py +++ b/test/integration/iam/system/test_system_with_iam.py @@ -54,7 +54,7 @@ def test_iam_fast_reset(self): time.sleep(5) # wait momentarily until we obtain the status again else: break - except: + except Exception: time.sleep(5) assert status.status_code == 200 diff --git a/test/integration/without_iam/notebook/test_sparql_graph_notebook.py b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py index 134115b3..666c701c 100644 --- a/test/integration/without_iam/notebook/test_sparql_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py @@ -17,7 +17,7 @@ def test_sparql_query(self): self.ip.run_cell_magic('sparql', f'--store-to {store_to_var}', query) self.assertFalse('graph_notebook_error' in self.ip.user_ns) sparql_res = self.ip.user_ns[store_to_var] - self.assertEqual(['s', 'o','p'], sparql_res['head']['vars']) + self.assertEqual(['s', 'o', 'p'], sparql_res['head']['vars']) @pytest.mark.jupyter @pytest.mark.sparql diff --git a/test/integration/without_iam/status/test_status_without_iam.py b/test/integration/without_iam/status/test_status_without_iam.py index 58b69c4b..2b20e6de 100644 --- a/test/integration/without_iam/status/test_status_without_iam.py +++ b/test/integration/without_iam/status/test_status_without_iam.py @@ -4,8 +4,6 @@ """ import pytest -from graph_notebook.status.get_status import get_status - from test.integration import IntegrationTest From 349e87d038f7c5cf4ebf7498a4c92b30b5c37d1c Mon Sep 17 00:00:00 2001 From: Kline Date: Thu, 1 Apr 2021 11:16:31 -0700 Subject: [PATCH 06/16] rebase from main, fix one unittest --- src/graph_notebook/magics/graph_magic.py | 10 ++++------ .../gremlin/graphsonV3d0_MapType_objectify_patch.py | 3 ++- .../gremlin}/hashable_dict_patch.py | 0 3 files changed, 6 insertions(+), 7 deletions(-) rename src/graph_notebook/{gremlin/client_provider => neptune/gremlin}/hashable_dict_patch.py (100%) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 26af3270..12b8a9be 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -115,20 +115,18 @@ def __init__(self, shell): # You must call the parent constructor super(Graph, self).__init__(shell) + self.graph_notebook_config = generate_default_config() try: self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION) - config = get_config(self.config_location) - self.graph_notebook_config = config - self.client: Client = None - self._generate_client_from_config(config) - + self.graph_notebook_config = get_config(self.config_location) except 
FileNotFoundError:
-            self.graph_notebook_config = generate_default_config()
             print(
                 'Could not find a valid configuration. Do not forget to validate your settings using %graph_notebook_config')
+
         self.max_results = DEFAULT_MAX_RESULTS
         self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED
+        self._generate_client_from_config(self.graph_notebook_config)
         logger.setLevel(logging.ERROR)
 
     def _generate_client_from_config(self, config: Configuration):
diff --git a/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py
index ab0b6896..a28b85a7 100644
--- a/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py
+++ b/src/graph_notebook/neptune/gremlin/graphsonV3d0_MapType_objectify_patch.py
@@ -4,7 +4,8 @@
 """
 from gremlin_python.structure.io.graphsonV3d0 import MapType
 
-from graph_notebook.gremlin.client_provider.hashable_dict_patch import HashableDict
+from graph_notebook.neptune.gremlin.hashable_dict_patch import HashableDict
+
 
 # Original code from Tinkerpop 3.4.1
 #
diff --git a/src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py b/src/graph_notebook/neptune/gremlin/hashable_dict_patch.py
similarity index 100%
rename from src/graph_notebook/gremlin/client_provider/hashable_dict_patch.py
rename to src/graph_notebook/neptune/gremlin/hashable_dict_patch.py

From 859bac52eb43ea9c3c2dec102b6ecc5f4826d73f Mon Sep 17 00:00:00 2001
From: Kline
Date: Thu, 1 Apr 2021 12:23:20 -0700
Subject: [PATCH 07/16] upgrade GraphNotebookIntegrationTest due to import conflicts

---
 .../notebook => }/GraphNotebookIntegrationTest.py       | 0
 test/integration/__init__.py                            | 1 +
 .../without_iam/notebook/test_gremlin_graph_notebook.py | 2 +-
 .../without_iam/notebook/test_sparql_graph_notebook.py  | 2 +-
 .../without_iam/notebook/test_status_graph_notebook.py  | 6 +-----
 5 files changed, 4 insertions(+), 7 deletions(-)
 rename test/integration/{without_iam/notebook => }/GraphNotebookIntegrationTest.py (100%)

diff --git a/test/integration/without_iam/notebook/GraphNotebookIntegrationTest.py b/test/integration/GraphNotebookIntegrationTest.py
similarity index 100%
rename from test/integration/without_iam/notebook/GraphNotebookIntegrationTest.py
rename to test/integration/GraphNotebookIntegrationTest.py
diff --git a/test/integration/__init__.py b/test/integration/__init__.py
index d0f15c29..cf393a60 100644
--- a/test/integration/__init__.py
+++ b/test/integration/__init__.py
@@ -6,4 +6,5 @@
 from .IntegrationTest import IntegrationTest # noqa F401
 from .DataDrivenGremlinTest import DataDrivenGremlinTest # noqa F401
 from .DataDrivenSparqlTest import DataDrivenSparqlTest # noqa F401
+from .GraphNotebookIntegrationTest import GraphNotebookIntegrationTest # noqa F401
 from .NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH # noqa F401
diff --git a/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py
index 2cac2548..c456c1b3 100644
--- a/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py
+++ b/test/integration/without_iam/notebook/test_gremlin_graph_notebook.py
@@ -4,7 +4,7 @@
 """
 import pytest
 
-from test.integration.without_iam.notebook import GraphNotebookIntegrationTest
+from test.integration import GraphNotebookIntegrationTest
 
 
 class TestGraphMagicGremlin(GraphNotebookIntegrationTest):
diff --git a/test/integration/without_iam/notebook/test_sparql_graph_notebook.py
b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py index 666c701c..69abd498 100644 --- a/test/integration/without_iam/notebook/test_sparql_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_sparql_graph_notebook.py @@ -4,7 +4,7 @@ """ import pytest -from test.integration.without_iam.notebook import GraphNotebookIntegrationTest +from test.integration import GraphNotebookIntegrationTest class TestGraphMagicGremlin(GraphNotebookIntegrationTest): diff --git a/test/integration/without_iam/notebook/test_status_graph_notebook.py b/test/integration/without_iam/notebook/test_status_graph_notebook.py index 2af95908..23387820 100644 --- a/test/integration/without_iam/notebook/test_status_graph_notebook.py +++ b/test/integration/without_iam/notebook/test_status_graph_notebook.py @@ -3,15 +3,11 @@ SPDX-License-Identifier: Apache-2.0 """ import pytest -from botocore.session import get_session -from test.integration.without_iam.notebook import GraphNotebookIntegrationTest +from test.integration import GraphNotebookIntegrationTest class TestGraphMagicStatus(GraphNotebookIntegrationTest): - def setUp(self) -> None: - self.client = self.client_builder.with_iam(get_session()).build() - @pytest.mark.jupyter @pytest.mark.neptune def test_status(self): From d46adc463ff54d0bf2c621a22f9df75760b746d7 Mon Sep 17 00:00:00 2001 From: Kline Date: Thu, 1 Apr 2021 14:59:37 -0700 Subject: [PATCH 08/16] upload/download generated test configuration --- .github/workflows/integration.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 4827bd4a..246757eb 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -53,6 +53,11 @@ jobs: --cfn-s3-bucket ${{ secrets.INTEG_CFN_S3_BUCKET }} \ --cfn-runner-role ${{ secrets.CFN_ROLE }} \ --aws-region ${{ secrets.AWS_REGION }} + - uses: actions/upload-artifact@master + with: + name: graph_notebook_config_integration_test.json + path: /tmp/graph_notebook_config_integration_test.json + retention-days: 1 run-tests: runs-on: ubuntu-latest needs: [generate-stack-name, create-stack] @@ -81,6 +86,10 @@ jobs: - name: Give user /etc/hosts permission run: | sudo chmod 777 /etc/hosts + - uses: actions/download-artifact@master + with: + name: graph_notebook_config_integration_test.json + path: /tmp/graph_notebook_config_integration_test.json - name: Run Basic Tests env: GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json From c3998cfc52bb703d08a8297df0e72d2d0193abc9 Mon Sep 17 00:00:00 2001 From: Kline Date: Thu, 1 Apr 2021 15:40:15 -0700 Subject: [PATCH 09/16] add section in integ test helper to generate configuration based on a provided stack --- .github/workflows/integration.yml | 15 +++++++++++---- .../NeptuneIntegrationWorkflowSteps.py | 11 +++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 246757eb..ae423180 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -86,15 +86,22 @@ jobs: - name: Give user /etc/hosts permission run: | sudo chmod 777 /etc/hosts - - uses: actions/download-artifact@master - with: - name: graph_notebook_config_integration_test.json - path: /tmp/graph_notebook_config_integration_test.json + - name: Generate test configuration + run: | + python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \ + --cfn-stack-name ${{ 
needs.generate-stack-name.outputs.stack-name }} \ + --aws-region ${{ secrets.AWS_REGION }} - name: Run Basic Tests env: GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json run: | pytest test/integration/without_iam + - name: Generate iam test configuration + run: | + python test/integration/NeptuneIntegrationWorkflowSteps.py generate-config \ + --cfn-stack-name ${{ needs.generate-stack-name.outputs.stack-name }} \ + --aws-region ${{ secrets.AWS_REGION }} \ + --iam - name: Run IAM Tests env: GRAPH_NOTEBOK_CONFIG: /tmp/graph_notebook_config_integration_test.json diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index 076c49d6..79d64869 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -20,6 +20,7 @@ SUBPARSER_CREATE_CFN = 'create-cfn-stack' SUBPARSER_DELETE_CFN = 'delete-cfn-stack' SUBPARSER_RUN_TESTS = 'run-tests' +SUBPARSER_GENERATE_CONFIG = 'generate-config' SUBPARSER_ENABLE_IAM = 'toggle-cluster-iam' sys.path.insert(0, os.path.abspath('..')) @@ -321,6 +322,13 @@ def main(): delete_parser.add_argument('--cfn-stack-name', type=str, default='') delete_parser.add_argument('--aws-region', type=str, default='us-east-1') + # sub parser generate config + config_parser = subparsers.add_parser(SUBPARSER_GENERATE_CONFIG, + help='generate test configuration from supplied cfn stack') + config_parser.add_argument('--cfn-stack-name', type=str, default='') + config_parser.add_argument('--aws-region', type=str, default='us-east-1') + config_parser.add_argument('--iam', action='store_true') + args = parser.parse_args() cfn_client = boto3.client('cloudformation', region_name=args.aws_region) @@ -335,6 +343,9 @@ def main(): set_iam_auth_on_neptune_cluster(cluster_identifier, True, neptune_client) logging.info('waiting for one minute while change is applied...') time.sleep(60) + elif args.which == SUBPARSER_GENERATE_CONFIG: + stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) + generate_config_from_stack(stack, args.aws_region, args.iam) if __name__ == '__main__': From 46b6fe8bfa9bb175cfca53d88667299109415a62 Mon Sep 17 00:00:00 2001 From: Kline Date: Fri, 2 Apr 2021 08:58:28 -0700 Subject: [PATCH 10/16] correct reference to auth mode when generating test config --- test/integration/NeptuneIntegrationWorkflowSteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index 79d64869..50e14eb3 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -268,7 +268,7 @@ def generate_config_from_stack(stack: dict, region: str, iam: bool) -> Configura file.writelines(new_lines) auth = AuthModeEnum.IAM if iam else AuthModeEnum.DEFAULT - conf = Configuration(details['endpoint'], 80, auth.ENV, details['loader_arn'], ssl=True, aws_region=region) + conf = Configuration(details['endpoint'], 80, auth, details['loader_arn'], ssl=True, aws_region=region) logging.info(f'generated configuration for test run: {conf.to_dict()}') return conf From 7ddb0837848686834267391d60360daa85007f0c Mon Sep 17 00:00:00 2001 From: Kline Date: Fri, 2 Apr 2021 09:46:54 -0700 Subject: [PATCH 11/16] update changelog, persist generated config --- ChangeLog.md | 11 ++++++----- test/integration/NeptuneIntegrationWorkflowSteps.py | 4 +++- 2 files changed, 9 insertions(+), 6 deletions(-) diff 
--git a/ChangeLog.md b/ChangeLog.md index 524a81ad..68d89427 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -4,14 +4,15 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming -- Add support for Mode, queueRequest, and Dependencies parameters when running %load command -- Add support for list and dict as map keys in Python Gremlin +- Add support for Mode, queueRequest, and Dependencies parameters when running %load command ([Link to PR](https://github.com/aws/graph-notebook/pull/91)) +- Add support for list and dict as map keys in Python Gremlin ([Link to PR](https://github.com/aws/graph-notebook/pull/100)) +- Refactor modules that call to Neptune or other SPARQL/Gremlin endpoints to use a unified client object ([Link to PR](https://github.com/aws/graph-notebook/pull/104)) ## Release 2.0.12 (Mar 25, 2021) - - Add default parameters for `get_load_status` - - Add ipython as a dependency in `setup.py` ([Link to PT](https://github.com/aws/graph-notebook/pull/95)) - - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage` + - Add default parameters for `get_load_status` ([Link to PR](https://github.com/aws/graph-notebook/pull/96)) + - Add ipython as a dependency in `setup.py` ([Link to PR](https://github.com/aws/graph-notebook/pull/95)) + - Add parameters in `load_status` for `details`, `errors`, `page`, and `errorsPerPage` ([Link to PR](https://github.com/aws/graph-notebook/pull/88)) ## Release 2.0.10 (Mar 18, 2021) diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index 50e14eb3..a0d90958 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -345,7 +345,9 @@ def main(): time.sleep(60) elif args.which == SUBPARSER_GENERATE_CONFIG: stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) - generate_config_from_stack(stack, args.aws_region, args.iam) + config = generate_config_from_stack(stack, args.aws_region, args.iam) + config.write_to_file(TEST_CONFIG_PATH) + if __name__ == '__main__': From de11e7168e1891aee32714e832311d09ca645f49 Mon Sep 17 00:00:00 2001 From: Kline Date: Fri, 2 Apr 2021 10:04:25 -0700 Subject: [PATCH 12/16] flake8 --- test/integration/NeptuneIntegrationWorkflowSteps.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index a0d90958..bef3cb0f 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -349,6 +349,5 @@ def main(): config.write_to_file(TEST_CONFIG_PATH) - if __name__ == '__main__': main() From 1d77a23d8c4eb177e185a60893b55e0cc27c7c6f Mon Sep 17 00:00:00 2001 From: Kline Date: Fri, 2 Apr 2021 10:55:45 -0700 Subject: [PATCH 13/16] ensure cluster is iam enabled for iam tests --- .github/workflows/integration.yml | 5 ----- test/integration/NeptuneIntegrationWorkflowSteps.py | 2 ++ 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ae423180..b0b0f40d 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -53,11 +53,6 @@ jobs: --cfn-s3-bucket ${{ secrets.INTEG_CFN_S3_BUCKET }} \ --cfn-runner-role ${{ secrets.CFN_ROLE }} \ --aws-region ${{ secrets.AWS_REGION }} - - uses: actions/upload-artifact@master - with: - name: graph_notebook_config_integration_test.json - 
path: /tmp/graph_notebook_config_integration_test.json - retention-days: 1 run-tests: runs-on: ubuntu-latest needs: [generate-stack-name, create-stack] diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py index bef3cb0f..36f4ce78 100644 --- a/test/integration/NeptuneIntegrationWorkflowSteps.py +++ b/test/integration/NeptuneIntegrationWorkflowSteps.py @@ -345,6 +345,8 @@ def main(): time.sleep(60) elif args.which == SUBPARSER_GENERATE_CONFIG: stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client) + cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client) + set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client) config = generate_config_from_stack(stack, args.aws_region, args.iam) config.write_to_file(TEST_CONFIG_PATH) From dd318db91ec0e822abc99f5bcd1e119f4fc73ee6 Mon Sep 17 00:00:00 2001 From: Kline Date: Fri, 2 Apr 2021 12:16:55 -0700 Subject: [PATCH 14/16] found other instances of a rename that needs to be reverted --- src/graph_notebook/neptune/client.py | 2 +- .../04-Machine-Learning/neptune_ml_utils.py | 12 ++++++------ src/graph_notebook/system/__init__.py | 0 src/graph_notebook/system/database_reset.py | 0 4 files changed, 7 insertions(+), 7 deletions(-) delete mode 100644 src/graph_notebook/system/__init__.py delete mode 100644 src/graph_notebook/system/database_reset.py diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py index d102476c..ffa7c254 100644 --- a/src/graph_notebook/neptune/client.py +++ b/src/graph_notebook/neptune/client.py @@ -60,7 +60,7 @@ 'LOAD_FAILED_INVALID_REQUEST', ] EXPORT_SERVICE_NAME = 'execute-api' -EXPORT_ACTION = 'ml-export' +EXPORT_ACTION = 'neptune-export' EXTRA_HEADERS = {'content-type': 'application/json'} SPARQL_ACTION = 'sparql' diff --git a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py index c9b30a3f..14622e9d 100644 --- a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py +++ b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py @@ -53,7 +53,7 @@ def get_training_job_name(prefix: str): def check_ml_enabled(): host, port, use_iam = load_configuration() - response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='ml-db') + response = signed_request("GET", url=f'https://{host}:{port}/ml/modeltraining', service='neptune-db') if response.status_code != 200: print('''This Neptune cluster \033[1mis not\033[0m configured to use Neptune ML. 
Please configure the cluster according to the Amazon Neptune ML documentation before proceeding.''')
@@ -79,7 +79,7 @@ def delete_pretrained_data(setup_node_classification: bool,
                            setup_node_regression: bool, setup_link_prediction: bool):
     host, port, use_iam = load_configuration()
     if setup_node_classification:
-        response = signed_request("POST", service='ml-db',
+        response = signed_request("POST", service='neptune-db',
                                   url=f'https://{host}:{port}/gremlin',
                                   headers={'content-type': 'application/json'},
                                   data=json.dumps(
@@ -88,14 +88,14 @@ def delete_pretrained_data(setup_node_classification: bool,
     if response.status_code != 200:
         print(response.content.decode('utf-8'))
     if setup_node_regression:
-        response = signed_request("POST", service='ml-db',
+        response = signed_request("POST", service='neptune-db',
                                   url=f'https://{host}:{port}/gremlin',
                                   headers={'content-type': 'application/json'},
                                   data=json.dumps({'gremlin': "g.V('user_1').out('wrote').properties('score').drop()"}))
     if response.status_code != 200:
         print(response.content.decode('utf-8'))
     if setup_link_prediction:
-        response = signed_request("POST", service='ml-db',
+        response = signed_request("POST", service='neptune-db',
                                   url=f'https://{host}:{port}/gremlin',
                                   headers={'content-type': 'application/json'},
                                   data=json.dumps({'gremlin': "g.V('user_1').outE('rated').drop()"}))
@@ -122,7 +122,7 @@ def delete_endpoint(training_job_name: str, neptune_iam_role_arn=None):
     if neptune_iam_role_arn:
         query_string = f'?neptuneIamRoleArn={neptune_iam_role_arn}'
     host, port, use_iam = load_configuration()
-    response = signed_request("DELETE", service='ml-db',
+    response = signed_request("DELETE", service='neptune-db',
                               url=f'https://{host}:{port}/ml/endpoints/{training_job_name}{query_string}',
                               headers={'content-type': 'application/json'})
     if response.status_code != 200:
@@ -289,7 +289,7 @@ def __upload_to_s3(self, bucketname: str):
                 self.formatted_directory, file), bucket, f'{file_path}/{file}')
 
     def prepare_movielens_data(self, s3_bucket: str):
-        bucket_name = f'{s3_bucket}/ml-formatted/movielens-100k'
+        bucket_name = f'{s3_bucket}/neptune-formatted/movielens-100k'
         self.__download_and_unzip()
         self.__process_movies_genres()
         self.__process_users()
diff --git a/src/graph_notebook/system/__init__.py b/src/graph_notebook/system/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/graph_notebook/system/database_reset.py b/src/graph_notebook/system/database_reset.py
deleted file mode 100644
index e69de29b..00000000

From d7e1d4d3b17ce6d9d3c656ee502b69a25f33d52c Mon Sep 17 00:00:00 2001
From: Kline
Date: Fri, 2 Apr 2021 12:29:33 -0700
Subject: [PATCH 15/16] fix a typo found in ml_utils

---
 .../notebooks/04-Machine-Learning/neptune_ml_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py
index 14622e9d..5541ca68 100644
--- a/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py
+++ b/src/graph_notebook/notebooks/04-Machine-Learning/neptune_ml_utils.py
@@ -71,7 +71,7 @@ def get_export_service_host():
             path = urlparse(parts[1].rstrip())
             return path.hostname + "/v1"
     logging.error(
-        "Unable to determine the Neptune Export Service Endpoint. You will need to enter this assign this manually.")
+        "Unable to determine the Neptune Export Service Endpoint. You will need to enter this or assign it manually.")
     return None
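The final patch below folds `gremlin_explain` and `gremlin_profile` into a single `_gremlin_query_plan` helper, since the two endpoints differ only in the URL suffix. Nothing changes for callers: both methods still POST a `{'gremlin': query}` payload and return the raw `requests.Response`. A minimal usage sketch (assuming `c` is an already-built client; the air-routes query is just an illustrative example):

```python
# Both calls return a requests.Response whose body contains the
# explain or profile report for the given Gremlin query.
query = "g.V().has('code', 'AUS').out().limit(5)"

explain_res = c.gremlin_explain(query)   # POSTs to /gremlin/explain
print(explain_res.text)

profile_res = c.gremlin_profile(query)   # POSTs to /gremlin/profile
print(profile_res.text)
```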
From 3ff0dcb46d8daa13720ef5fd56a0fa0810c8de86 Mon Sep 17 00:00:00 2001
From: Kline
Date: Mon, 5 Apr 2021 11:59:17 -0700
Subject: [PATCH 16/16] combine gremlin profile/explain into one underlying method

---
 src/graph_notebook/neptune/client.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py
index ffa7c254..5bad2701 100644
--- a/src/graph_notebook/neptune/client.py
+++ b/src/graph_notebook/neptune/client.py
@@ -185,14 +185,13 @@ def gremlin_cancel(self, query_id: str):
         return self._query_status('gremlin', query_id=query_id, cancelQuery=True)
 
     def gremlin_explain(self, query: str) -> requests.Response:
-        url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/explain'
-        data = {'gremlin': query}
-        req = self._prepare_request('POST', url, data=json.dumps(data))
-        res = self._http_session.send(req)
-        return res
+        return self._gremlin_query_plan(query, 'explain')
 
     def gremlin_profile(self, query: str) -> requests.Response:
-        url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/profile'
+        return self._gremlin_query_plan(query, 'profile')
+
+    def _gremlin_query_plan(self, query: str, plan_type: str) -> requests.Response:
+        url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/{plan_type}'
         data = {'gremlin': query}
         req = self._prepare_request('POST', url, data=json.dumps(data))
         res = self._http_session.send(req)
@@ -206,7 +205,7 @@ def status(self) -> requests.Response:
 
     def load(self, source: str, source_format: str, iam_role_arn: str, region: str, **kwargs) -> requests.Response:
         """
-        For a full list of allow parameters, see aws documentation on the Neptune loader
+        For a full list of allowed parameters, see aws documentation on the Neptune loader
         endpoint: https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html
         """
         payload = {