diff --git a/pyproject.toml b/pyproject.toml index 4a20928..05859b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ requires-python = ">=3.10" dependencies = [ "httpx", "pydantic>1", + "neo4j", + "pinecone-client", ] dynamic = ["version"] readme = "README.md" diff --git a/src/whyhow/apis/graph.py b/src/whyhow/apis/graph.py index 273eeaa..29a5fa1 100644 --- a/src/whyhow/apis/graph.py +++ b/src/whyhow/apis/graph.py @@ -55,10 +55,8 @@ def add_documents(self, namespace: str, documents: list[str]) -> str: ) if len(document_paths) > 3: - raise ValueError( - """Too many documents - please limit uploads to 3 files during the beta.""" - ) + raise ValueError("""Too many documents + please limit uploads to 3 files during the beta.""") files = [ ( diff --git a/src/whyhow/client.py b/src/whyhow/client.py index 3df4897..acc2baf 100644 --- a/src/whyhow/client.py +++ b/src/whyhow/client.py @@ -6,6 +6,9 @@ from httpx import AsyncClient, Auth, Client, Request, Response from whyhow.apis.graph import GraphAPI +from whyhow.validations import VerifyConnectivity + +BASE_URL = "https://43nq5c1b4c.execute-api.us-east-2.amazonaws.com" class APIKeyAuth(Auth): @@ -70,8 +73,7 @@ def __init__( neo4j_url: str | None = None, neo4j_user: str | None = None, neo4j_password: str | None = None, - base_url: - str = "https://43nq5c1b4c.execute-api.us-east-2.amazonaws.com", + base_url: str = BASE_URL, httpx_kwargs: dict[str, Any] | None = None, ) -> None: """Initialize the client.""" @@ -114,6 +116,10 @@ def __init__( if neo4j_url is None: raise ValueError("NEO4J_URL must be set.") + VerifyConnectivity( + neo4j_url, neo4j_user, neo4j_password, pinecone_api_key + ) + auth = APIKeyAuth( api_key, openai_api_key, @@ -166,8 +172,7 @@ def __init__( neo4j_user: str | None = None, neo4j_password: str | None = None, neo4j_url: str | None = None, - base_url: - str = "https://43nq5c1b4c.execute-api.us-east-2.amazonaws.com", + base_url: str = BASE_URL, httpx_kwargs: dict[str, Any] | None = None, ) -> None: """Initialize the client.""" diff --git a/src/whyhow/validations.py b/src/whyhow/validations.py new file mode 100644 index 0000000..d04abf9 --- /dev/null +++ b/src/whyhow/validations.py @@ -0,0 +1,35 @@ +"""Custom validators for whyhow sdk.""" + +from dataclasses import dataclass + +import pinecone +from neo4j import GraphDatabase + + +@dataclass +class VerifyConnectivity: + """This class will verify the connectivity with databases.""" + + neo4j_url: str + neo4j_user: str + neo4j_password: str + pinecone_api_key: str + + def __post_init__(self): + """ + Verify neo4j and pinecone connectivity. + + :return: None + """ + self._verify_neo4j_connectivity() + self._verify_pinecone_connectivity() + + def _verify_neo4j_connectivity(self): + auth = (self.neo4j_user, self.neo4j_password) + + with GraphDatabase.driver(self.neo4j_url, auth=auth) as driver: + driver.verify_connectivity() + + def _verify_pinecone_connectivity(self): + pc_client = pinecone.Pinecone(api_key=self.pinecone_api_key) + pc_client.list_indexes() diff --git a/tests/apis/test_graph.py b/tests/apis/test_graph.py index 0af2594..0619611 100644 --- a/tests/apis/test_graph.py +++ b/tests/apis/test_graph.py @@ -1,6 +1,7 @@ """Tests focused on the graph API.""" import os +from unittest.mock import Mock import pytest @@ -11,6 +12,7 @@ QueryGraphResponse, QueryGraphReturn, ) +from whyhow.validations import VerifyConnectivity # Set fake environment variables os.environ["WHYHOW_API_KEY"] = "fake_api_key" @@ -35,8 +37,14 @@ class TestGraphAPIQuery: """Tests for the query_graph method.""" - def test_query_graph(self, httpx_mock): + def test_query_graph(self, httpx_mock, monkeypatch): """Test querying the graph.""" + connectivity_client = Mock(spec=VerifyConnectivity) + connectivity_client.return_value = None + monkeypatch.setattr( + "whyhow.client.VerifyConnectivity", connectivity_client + ) + client = WhyHow() query = "What friends does Alice have?" @@ -69,8 +77,14 @@ def test_query_graph(self, httpx_mock): class TestGraphAPIAddDocuments: """Tests for the add_documents method.""" - def test_errors(self, httpx_mock, tmp_path): + def test_errors(self, monkeypatch, tmp_path): """Test error handling.""" + connectivity_client = Mock(spec=VerifyConnectivity) + connectivity_client.return_value = None + monkeypatch.setattr( + "whyhow.client.VerifyConnectivity", connectivity_client + ) + client = WhyHow() with pytest.raises(ValueError, match="No documents provided"): diff --git a/tests/test_client.py b/tests/test_client.py index 41568b3..318a37a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,8 +4,10 @@ import pytest from httpx import Client +from neo4j.exceptions import ConfigurationError from whyhow.client import WhyHow +from whyhow.validations import VerifyConnectivity class TestWhyHow: @@ -23,6 +25,13 @@ def test_httpx_kwargs(self, monkeypatch): fake_httpx_client_class = Mock(return_value=fake_httpx_client_inst) monkeypatch.setattr("whyhow.client.Client", fake_httpx_client_class) + + connectivity_client = Mock(spec=VerifyConnectivity) + connectivity_client.return_value = None + monkeypatch.setattr( + "whyhow.client.VerifyConnectivity", connectivity_client + ) + httpx_kwargs = {"verify": False} client = WhyHow( api_key="key", @@ -41,8 +50,35 @@ def test_httpx_kwargs(self, monkeypatch): assert client.httpx_client is fake_httpx_client_class.return_value - def test_base_url_twice(self): + def test_credentials_verification(self, monkeypatch): + """Test connectivity with databases.""" + connectivity_client = Mock(spec=VerifyConnectivity) + connectivity_client.side_effect = ConfigurationError( + "Invalid credentials" + ) + monkeypatch.setattr( + "whyhow.client.VerifyConnectivity", connectivity_client + ) + with pytest.raises(ConfigurationError) as exc_info: + WhyHow( + api_key="mock_api_key", + neo4j_user="mock_neo4j_user", + neo4j_url="mock_neo4j_url", + neo4j_password="mock_neo4j_password", + pinecone_api_key="mock_pinecone_api_key", + ) + assert "Invalid credentials" in str( + exc_info.value + ), "The exception message is not as expected" + + def test_base_url_twice(self, monkeypatch): """Test setting base_url in httpx_kwargs.""" + connectivity_client = Mock(spec=VerifyConnectivity) + connectivity_client.return_value = None + monkeypatch.setattr( + "whyhow.client.VerifyConnectivity", connectivity_client + ) + with pytest.raises( ValueError, match="base_url cannot be set in httpx_kwargs." ):