diff --git a/examples/gds-example.ipynb b/examples/gds-example.ipynb index db71c07..dd74664 100644 --- a/examples/gds-example.ipynb +++ b/examples/gds-example.ipynb @@ -17,11 +17,30 @@ "%pip install matplotlib" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Setup GDS graph" + "## Setup GDS graph\n", + "\n", + "To use GDS, you can either use GDS as a plugin or Aura Graph Analytics.\n", + "In the following, you can choose:\n", + "\n", + " * Provide Aura API credentials and and use Aura Graph Analytics.\n", + " * Use Neo4j + GDS Plugin.\n", + "\n", + "For more information, see the [GDS documentation](https://neo4j.com/docs/graph-data-science/current/installation/)." ] }, { @@ -32,18 +51,33 @@ "source": [ "import os\n", "\n", + "from graphdatascience.session import GdsSessions, DbmsConnectionInfo, AuraAPICredentials, SessionMemory\n", "from graphdatascience import GraphDataScience\n", "\n", "# Get Neo4j DB URI, credentials and name from environment if applicable\n", - "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", - "NEO4J_AUTH = (\"neo4j\", None)\n", - "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", - "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", - " NEO4J_AUTH = (\n", - " os.environ.get(\"NEO4J_USER\"),\n", - " os.environ.get(\"NEO4J_PASSWORD\"),\n", - " )\n", - "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)" + "db_connection = DbmsConnectionInfo(\n", + " aura_instance_id=os.environ.get(\"AURA_INSTANCEID\"),\n", + " username=os.environ[\"NEO4J_USERNAME\"],\n", + " password=os.environ[\"NEO4J_PASSWORD\"],\n", + " uri=os.environ[\"NEO4J_URI\"],\n", + ")\n", + "\n", + "session_name = \"neo4j-viz-gds-example\"\n", + "if os.environ.get(\"AURA_API_CLIENT_ID\"):\n", + " # Use Aura Graph Analytics\n", + " sessions = GdsSessions(api_credentials=AuraAPICredentials(\n", + " client_id=os.environ[\"AURA_API_CLIENT_ID\"],\n", + " client_secret=os.environ[\"AURA_API_CLIENT_SECRET\"],\n", + " project_id=os.environ.get(\"AURA_API_PROJECT_ID\"),\n", + " ))\n", + " gds = sessions.get_or_create(session_name=session_name, memory=SessionMemory.m_2GB, db_connection=db_connection)\n", + "else:\n", + " # Use GDS Plugin\n", + " sessions = None\n", + " gds = GraphDataScience(\n", + " endpoint=db_connection.get_uri(),\n", + " auth=(db_connection.username, db_connection.password),\n", + " )" ] }, { @@ -55,6 +89,11 @@ "G = gds.graph.load_cora(graph_name=\"cora\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -93,139 +132,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "VG.render()" ] @@ -252,139 +165,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "VG.render()" ] @@ -429,139 +216,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "VG.color_nodes(property=\"subject\")\n", "VG.render()" @@ -578,139 +239,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from neo4j_viz.colors import ColorSpace\n", "\n", @@ -757,139 +292,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "VG.render()" ] @@ -911,139 +320,13 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from neo4j_viz import Layout\n", "\n", @@ -1093,6 +376,20 @@ "source": [ "gds.graph.drop(\"cora\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "teardown" + ] + }, + "outputs": [], + "source": [ + "if sessions:\n", + " sessions.delete(session_name=session_name)" + ] } ], "metadata": { diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml index 911286a..701b7aa 100644 --- a/python-wrapper/pyproject.toml +++ b/python-wrapper/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "ipython >=7, <10", "pydantic >=2 , <3", "pydantic-extra-types >=2, <3", - "enum-tools==0.13.0" + "enum-tools==0.13.0", + "python-dotenv>=1.2.1", ] requires-python = ">=3.10" @@ -72,7 +73,7 @@ notebook = [ "palettable>=3.3.3", "matplotlib>=3.9.4", "snowflake-snowpark-python==1.42.0", - "dotenv" + "python-dotenv" ] [project.urls] diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py index 40f7f4e..d96d85f 100644 --- a/python-wrapper/tests/conftest.py +++ b/python-wrapper/tests/conftest.py @@ -1,4 +1,5 @@ import os +import random from typing import Any, Generator import pytest @@ -31,45 +32,56 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: @pytest.fixture(scope="package") -def aura_ds_instance() -> Generator[Any, None, None]: +def aura_db_instance() -> Generator[Any, None, None]: if os.environ.get("AURA_API_CLIENT_ID", None) is None: yield None return - from tests.gds_helper import aura_api, create_aurads_instance + from tests.gds_helper import aura_api, create_auradb_instance api = aura_api() - id, dbms_connection_info = create_aurads_instance(api) + dbms_connection_info = create_auradb_instance(api) + old_uri = os.environ.get("NEO4J_URI", "") # setting as environment variables to run notebooks with this connection os.environ["NEO4J_URI"] = dbms_connection_info.get_uri() assert isinstance(dbms_connection_info.username, str) os.environ["NEO4J_USER"] = dbms_connection_info.username assert isinstance(dbms_connection_info.password, str) os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password + old_instance = os.environ.get("AURA_INSTANCEID", "") + if dbms_connection_info.aura_instance_id: + os.environ["AURA_INSTANCEID"] = dbms_connection_info.aura_instance_id + yield dbms_connection_info # Clear Neo4j_URI after test (rerun should create a new instance) - os.environ["NEO4J_URI"] = "" - api.delete_instance(id) + os.environ["NEO4J_URI"] = old_uri + os.environ["AURA_INSTANCEID"] = old_instance + assert dbms_connection_info.aura_instance_id is not None + api.delete_instance(dbms_connection_info.aura_instance_id) @pytest.fixture(scope="package") -def gds(aura_ds_instance: Any) -> Generator[Any, None, None]: - from graphdatascience import GraphDataScience +def gds(aura_db_instance: Any) -> Generator[Any, None, None]: + from graphdatascience.session import SessionMemory - from tests.gds_helper import connect_to_plugin_gds + from tests.gds_helper import connect_to_plugin_gds, gds_sessions - if aura_ds_instance: - yield GraphDataScience( - endpoint=aura_ds_instance.uri, - auth=(aura_ds_instance.username, aura_ds_instance.password), - aura_ds=True, - database="neo4j", + if aura_db_instance: + sessions = gds_sessions() + + gds = sessions.get_or_create( + f"neo4j-viz-ci-{os.environ.get('GITHUB_RUN_ID', random.randint(0, 10**6))}", + memory=SessionMemory.m_2GB, + db_connection=aura_db_instance, ) + + yield gds + gds.delete() else: - NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") - gds = connect_to_plugin_gds(NEO4J_URI) + NEO4J_URI = os.environ["NEO4J_URI"] + gds = connect_to_plugin_gds(NEO4J_URI) # type: ignore yield gds gds.close() diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py index e5a0d3d..3d6a802 100644 --- a/python-wrapper/tests/gds_helper.py +++ b/python-wrapper/tests/gds_helper.py @@ -1,9 +1,9 @@ import os import re -from graphdatascience import GraphDataScience +from graphdatascience import GdsSessions, GraphDataScience from graphdatascience.semantic_version.semantic_version import SemanticVersion -from graphdatascience.session import DbmsConnectionInfo, SessionMemory +from graphdatascience.session import AuraAPICredentials, DbmsConnectionInfo, SessionMemory from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import InstanceCreateDetails from graphdatascience.version import __version__ @@ -49,21 +49,29 @@ def aura_api() -> AuraApi: ) -def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]: - # Switch to Sessions once they can be created without a DB +def gds_sessions() -> GdsSessions: + return GdsSessions( + api_credentials=AuraAPICredentials( + client_id=os.environ["AURA_API_CLIENT_ID"], + client_secret=os.environ["AURA_API_CLIENT_SECRET"], + project_id=os.environ.get("AURA_API_TENANT_ID"), + ) + ) + + +def create_auradb_instance(api: AuraApi) -> DbmsConnectionInfo: instance_details: InstanceCreateDetails = api.create_instance( - name="ci-neo4j-viz-session", - memory=SessionMemory.m_8GB.value, + name="ci-neo4j-viz-db", + memory=SessionMemory.m_2GB.value, cloud_provider="gcp", region="europe-west1", + type="enterprise-db", ) wait_result = api.wait_for_instance_running(instance_id=instance_details.id) if wait_result.error: raise Exception(f"Error while waiting for instance to be running: {wait_result.error}") - return instance_details.id, DbmsConnectionInfo( - uri=wait_result.connection_url, - username="neo4j", - password=instance_details.password, + return DbmsConnectionInfo( + username="neo4j", password=instance_details.password, aura_instance_id=instance_details.id )