diff --git a/examples/gds-example.ipynb b/examples/gds-example.ipynb
index db71c07..dd74664 100644
--- a/examples/gds-example.ipynb
+++ b/examples/gds-example.ipynb
@@ -17,11 +17,30 @@
"%pip install matplotlib"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "load_dotenv()"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Setup GDS graph"
+ "## Setup GDS graph\n",
+ "\n",
+ "To use GDS, you can either use GDS as a plugin or Aura Graph Analytics.\n",
+ "In the following, you can choose:\n",
+ "\n",
+ " * Provide Aura API credentials and and use Aura Graph Analytics.\n",
+ " * Use Neo4j + GDS Plugin.\n",
+ "\n",
+ "For more information, see the [GDS documentation](https://neo4j.com/docs/graph-data-science/current/installation/)."
]
},
{
@@ -32,18 +51,33 @@
"source": [
"import os\n",
"\n",
+ "from graphdatascience.session import GdsSessions, DbmsConnectionInfo, AuraAPICredentials, SessionMemory\n",
"from graphdatascience import GraphDataScience\n",
"\n",
"# Get Neo4j DB URI, credentials and name from environment if applicable\n",
- "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
- "NEO4J_AUTH = (\"neo4j\", None)\n",
- "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
- "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
- " NEO4J_AUTH = (\n",
- " os.environ.get(\"NEO4J_USER\"),\n",
- " os.environ.get(\"NEO4J_PASSWORD\"),\n",
- " )\n",
- "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)"
+ "db_connection = DbmsConnectionInfo(\n",
+ " aura_instance_id=os.environ.get(\"AURA_INSTANCEID\"),\n",
+ " username=os.environ[\"NEO4J_USERNAME\"],\n",
+ " password=os.environ[\"NEO4J_PASSWORD\"],\n",
+ " uri=os.environ[\"NEO4J_URI\"],\n",
+ ")\n",
+ "\n",
+ "session_name = \"neo4j-viz-gds-example\"\n",
+ "if os.environ.get(\"AURA_API_CLIENT_ID\"):\n",
+ " # Use Aura Graph Analytics\n",
+ " sessions = GdsSessions(api_credentials=AuraAPICredentials(\n",
+ " client_id=os.environ[\"AURA_API_CLIENT_ID\"],\n",
+ " client_secret=os.environ[\"AURA_API_CLIENT_SECRET\"],\n",
+ " project_id=os.environ.get(\"AURA_API_PROJECT_ID\"),\n",
+ " ))\n",
+ " gds = sessions.get_or_create(session_name=session_name, memory=SessionMemory.m_2GB, db_connection=db_connection)\n",
+ "else:\n",
+ " # Use GDS Plugin\n",
+ " sessions = None\n",
+ " gds = GraphDataScience(\n",
+ " endpoint=db_connection.get_uri(),\n",
+ " auth=(db_connection.username, db_connection.password),\n",
+ " )"
]
},
{
@@ -55,6 +89,11 @@
"G = gds.graph.load_cora(graph_name=\"cora\")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -93,139 +132,13 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"VG.render()"
]
@@ -252,139 +165,13 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"VG.render()"
]
@@ -429,139 +216,13 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"VG.color_nodes(property=\"subject\")\n",
"VG.render()"
@@ -578,139 +239,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"from neo4j_viz.colors import ColorSpace\n",
"\n",
@@ -757,139 +292,13 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"VG.render()"
]
@@ -911,139 +320,13 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"from neo4j_viz import Layout\n",
"\n",
@@ -1093,6 +376,20 @@
"source": [
"gds.graph.drop(\"cora\")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "teardown"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "if sessions:\n",
+ " sessions.delete(session_name=session_name)"
+ ]
}
],
"metadata": {
diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml
index 911286a..701b7aa 100644
--- a/python-wrapper/pyproject.toml
+++ b/python-wrapper/pyproject.toml
@@ -33,7 +33,8 @@ dependencies = [
"ipython >=7, <10",
"pydantic >=2 , <3",
"pydantic-extra-types >=2, <3",
- "enum-tools==0.13.0"
+ "enum-tools==0.13.0",
+ "python-dotenv>=1.2.1",
]
requires-python = ">=3.10"
@@ -72,7 +73,7 @@ notebook = [
"palettable>=3.3.3",
"matplotlib>=3.9.4",
"snowflake-snowpark-python==1.42.0",
- "dotenv"
+ "python-dotenv"
]
[project.urls]
diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py
index 40f7f4e..d96d85f 100644
--- a/python-wrapper/tests/conftest.py
+++ b/python-wrapper/tests/conftest.py
@@ -1,4 +1,5 @@
import os
+import random
from typing import Any, Generator
import pytest
@@ -31,45 +32,56 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None:
@pytest.fixture(scope="package")
-def aura_ds_instance() -> Generator[Any, None, None]:
+def aura_db_instance() -> Generator[Any, None, None]:
if os.environ.get("AURA_API_CLIENT_ID", None) is None:
yield None
return
- from tests.gds_helper import aura_api, create_aurads_instance
+ from tests.gds_helper import aura_api, create_auradb_instance
api = aura_api()
- id, dbms_connection_info = create_aurads_instance(api)
+ dbms_connection_info = create_auradb_instance(api)
+ old_uri = os.environ.get("NEO4J_URI", "")
# setting as environment variables to run notebooks with this connection
os.environ["NEO4J_URI"] = dbms_connection_info.get_uri()
assert isinstance(dbms_connection_info.username, str)
os.environ["NEO4J_USER"] = dbms_connection_info.username
assert isinstance(dbms_connection_info.password, str)
os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password
+ old_instance = os.environ.get("AURA_INSTANCEID", "")
+ if dbms_connection_info.aura_instance_id:
+ os.environ["AURA_INSTANCEID"] = dbms_connection_info.aura_instance_id
+
yield dbms_connection_info
# Clear Neo4j_URI after test (rerun should create a new instance)
- os.environ["NEO4J_URI"] = ""
- api.delete_instance(id)
+ os.environ["NEO4J_URI"] = old_uri
+ os.environ["AURA_INSTANCEID"] = old_instance
+ assert dbms_connection_info.aura_instance_id is not None
+ api.delete_instance(dbms_connection_info.aura_instance_id)
@pytest.fixture(scope="package")
-def gds(aura_ds_instance: Any) -> Generator[Any, None, None]:
- from graphdatascience import GraphDataScience
+def gds(aura_db_instance: Any) -> Generator[Any, None, None]:
+ from graphdatascience.session import SessionMemory
- from tests.gds_helper import connect_to_plugin_gds
+ from tests.gds_helper import connect_to_plugin_gds, gds_sessions
- if aura_ds_instance:
- yield GraphDataScience(
- endpoint=aura_ds_instance.uri,
- auth=(aura_ds_instance.username, aura_ds_instance.password),
- aura_ds=True,
- database="neo4j",
+ if aura_db_instance:
+ sessions = gds_sessions()
+
+ gds = sessions.get_or_create(
+ f"neo4j-viz-ci-{os.environ.get('GITHUB_RUN_ID', random.randint(0, 10**6))}",
+ memory=SessionMemory.m_2GB,
+ db_connection=aura_db_instance,
)
+
+ yield gds
+ gds.delete()
else:
- NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
- gds = connect_to_plugin_gds(NEO4J_URI)
+ NEO4J_URI = os.environ["NEO4J_URI"]
+ gds = connect_to_plugin_gds(NEO4J_URI) # type: ignore
yield gds
gds.close()
diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py
index e5a0d3d..3d6a802 100644
--- a/python-wrapper/tests/gds_helper.py
+++ b/python-wrapper/tests/gds_helper.py
@@ -1,9 +1,9 @@
import os
import re
-from graphdatascience import GraphDataScience
+from graphdatascience import GdsSessions, GraphDataScience
from graphdatascience.semantic_version.semantic_version import SemanticVersion
-from graphdatascience.session import DbmsConnectionInfo, SessionMemory
+from graphdatascience.session import AuraAPICredentials, DbmsConnectionInfo, SessionMemory
from graphdatascience.session.aura_api import AuraApi
from graphdatascience.session.aura_api_responses import InstanceCreateDetails
from graphdatascience.version import __version__
@@ -49,21 +49,29 @@ def aura_api() -> AuraApi:
)
-def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]:
- # Switch to Sessions once they can be created without a DB
+def gds_sessions() -> GdsSessions:
+ return GdsSessions(
+ api_credentials=AuraAPICredentials(
+ client_id=os.environ["AURA_API_CLIENT_ID"],
+ client_secret=os.environ["AURA_API_CLIENT_SECRET"],
+ project_id=os.environ.get("AURA_API_TENANT_ID"),
+ )
+ )
+
+
+def create_auradb_instance(api: AuraApi) -> DbmsConnectionInfo:
instance_details: InstanceCreateDetails = api.create_instance(
- name="ci-neo4j-viz-session",
- memory=SessionMemory.m_8GB.value,
+ name="ci-neo4j-viz-db",
+ memory=SessionMemory.m_2GB.value,
cloud_provider="gcp",
region="europe-west1",
+ type="enterprise-db",
)
wait_result = api.wait_for_instance_running(instance_id=instance_details.id)
if wait_result.error:
raise Exception(f"Error while waiting for instance to be running: {wait_result.error}")
- return instance_details.id, DbmsConnectionInfo(
- uri=wait_result.connection_url,
- username="neo4j",
- password=instance_details.password,
+ return DbmsConnectionInfo(
+ username="neo4j", password=instance_details.password, aura_instance_id=instance_details.id
)