diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..d7fc12c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +- Name files for their main export, transforming Python names from PascalCase to underscore_case. For example, the file that defines the EAVFilter class should be named eav_filter.py. \ No newline at end of file diff --git a/examples/eav_example.py b/examples/eav_example.py new file mode 100644 index 0000000..1b350c7 --- /dev/null +++ b/examples/eav_example.py @@ -0,0 +1,29 @@ +import os + +from cvec import CVec, EAVFilter + + +def main() -> None: + cvec = CVec( + host=os.environ.get( + "CVEC_HOST", "https://your-subdomain.cvector.dev" + ), # Replace with your API host + api_key=os.environ.get("CVEC_API_KEY", "your-api-key"), + ) + + # Example: Query with numeric range filter + print("\nQuerying with numeric range filter...") + rows = cvec.select_from_eav_id( + table_id="00000000-0000-0000-0000-000000000000", + column_ids=["abcd", "defg", "hijk"], + filters=[ + EAVFilter(column_id="abcd", numeric_min=45992, numeric_max=45993), + ], + ) + print(f"Found {len(rows)} rows with abcd in range [45992, 45993)") + for row in rows: + print(f"- {row}") + + +if __name__ == "__main__": + main() diff --git a/examples/show_eav_schema.py b/examples/show_eav_schema.py new file mode 100755 index 0000000..4277c07 --- /dev/null +++ b/examples/show_eav_schema.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +"""Show EAV tables and columns.""" + +import os + +from cvec import CVec + + +def main() -> None: + cvec = CVec( + host=os.environ.get("CVEC_HOST", ""), + api_key=os.environ.get("CVEC_API_KEY", ""), + ) + + tables = cvec.get_eav_tables() + print(f"Found {len(tables)} EAV tables\n") + + for table in tables: + print(f"{table.name} (id: {table.id})") + columns = cvec.get_eav_columns(table.id) + for column in columns: + print(f" - {column.name} ({column.type}, id: {column.eav_column_id})") + print() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 8aade12..f6129e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cvec" -version = "1.2.0" +version = "1.3.0" description = "SDK for CVector Energy" authors = [{ name = "CVector", email = "support@cvector.energy" }] readme = "README.md" diff --git a/src/cvec/__init__.py b/src/cvec/__init__.py index e25fa13..6631527 100644 --- a/src/cvec/__init__.py +++ b/src/cvec/__init__.py @@ -1,3 +1,4 @@ from .cvec import CVec +from .models import EAVColumn, EAVFilter, EAVTable -__all__ = ["CVec"] +__all__ = ["CVec", "EAVColumn", "EAVFilter", "EAVTable"] diff --git a/src/cvec/cvec.py b/src/cvec/cvec.py index eacb6a7..b61306e 100644 --- a/src/cvec/cvec.py +++ b/src/cvec/cvec.py @@ -7,6 +7,9 @@ from urllib.parse import urlencode, urljoin from urllib.request import Request, urlopen +from cvec.models.eav_column import EAVColumn +from cvec.models.eav_filter import EAVFilter +from cvec.models.eav_table import EAVTable from cvec.models.metric import Metric, MetricDataPoint from cvec.models.span import Span from cvec.utils.arrow_converter import ( @@ -30,6 +33,7 @@ class CVec: _refresh_token: Optional[str] _publishable_key: Optional[str] _api_key: Optional[str] + _tenant_id: int def __init__( self, @@ -52,13 +56,17 @@ def __init__( raise ValueError( "CVEC_HOST must be set either as an argument or environment variable" ) + + # Add https:// scheme if not provided + if not self.host.startswith("http://") and not self.host.startswith("https://"): + self.host = f"https://{self.host}" if not self._api_key: raise ValueError( "CVEC_API_KEY must be set either as an argument or environment variable" ) - # Fetch publishable key from host config - self._publishable_key = self._fetch_publishable_key() + # Fetch config (publishable key and tenant ID) + self._publishable_key = self._fetch_config() # Handle authentication email = self._construct_email_from_api_key() @@ -478,15 +486,17 @@ def _refresh_supabase_token(self) -> None: self._access_token = data["access_token"] self._refresh_token = data["refresh_token"] - def _fetch_publishable_key(self) -> str: + def _fetch_config(self) -> str: """ - Fetch the publishable key from the host's config endpoint. + Fetch configuration from the host's config endpoint. + + Sets the tenant_id on the instance and returns the publishable key. Returns: The publishable key from the config response Raises: - ValueError: If the config endpoint is not accessible or doesn't contain the key + ValueError: If the config endpoint is not accessible or doesn't contain required fields """ try: config_url = f"{self.host}/config" @@ -497,13 +507,363 @@ def _fetch_publishable_key(self) -> str: config_data = json.loads(response_data.decode("utf-8")) publishable_key = config_data.get("supabasePublishableKey") + tenant_id = config_data.get("tenantId") if not publishable_key: raise ValueError(f"Configuration fetched from {config_url} is invalid") + if tenant_id is None: + raise ValueError(f"tenantId not found in config from {config_url}") + self._tenant_id = int(tenant_id) return str(publishable_key) except (HTTPError, URLError) as e: raise ValueError(f"Failed to fetch config from {self.host}/config: {e}") except (KeyError, ValueError) as e: raise ValueError(f"Invalid config response: {e}") + + def _call_rpc( + self, + function_name: str, + params: Optional[Dict[str, Any]] = None, + ) -> Any: + """ + Call a Supabase RPC function. + + Args: + function_name: The name of the RPC function to call + params: Optional dictionary of parameters to pass to the function + + Returns: + The response data from the RPC call + """ + if not self._access_token: + raise ValueError("No access token available. Please login first.") + if not self._publishable_key: + raise ValueError("Publishable key not available") + + url = f"{self.host}/supabase/rest/v1/rpc/{function_name}" + + headers = { + "Accept": "application/json", + "Apikey": self._publishable_key, + "Authorization": f"Bearer {self._access_token}", + "Content-Profile": "app_data", + "Content-Type": "application/json", + } + + request_body = json.dumps(params or {}).encode("utf-8") + + def make_rpc_request() -> Any: + """Inner function to make the actual RPC request.""" + req = Request(url, data=request_body, headers=headers, method="POST") + with urlopen(req) as response: + response_data = response.read() + return json.loads(response_data.decode("utf-8")) + + try: + return make_rpc_request() + except HTTPError as e: + # Handle 401 Unauthorized with token refresh + if e.code == 401 and self._access_token and self._refresh_token: + try: + self._refresh_supabase_token() + # Update headers with new token + headers["Authorization"] = f"Bearer {self._access_token}" + + # Retry the request + req = Request( + url, data=request_body, headers=headers, method="POST" + ) + with urlopen(req) as response: + response_data = response.read() + return json.loads(response_data.decode("utf-8")) + except (HTTPError, URLError, ValueError, KeyError) as refresh_error: + logger.warning( + "Token refresh failed, continuing with original request: %s", + refresh_error, + exc_info=True, + ) + # If refresh fails, re-raise the original 401 error + raise e + raise + + def _query_table( + self, + table_name: str, + query_params: Optional[Dict[str, str]] = None, + ) -> Any: + """ + Query a Supabase table via PostgREST. + + Args: + table_name: The name of the table to query + query_params: Optional dict of PostgREST query parameters + (e.g., {"name": "eq.foo", "order": "name"}) + + Returns: + The response data from the query + """ + if not self._access_token: + raise ValueError("No access token available. Please login first.") + if not self._publishable_key: + raise ValueError("Publishable key not available") + + url = f"{self.host}/supabase/rest/v1/{table_name}" + if query_params: + encoded_params = urlencode(query_params) + url = f"{url}?{encoded_params}" + + headers = { + "Accept": "application/json", + "Accept-Profile": "app_data", + "Apikey": self._publishable_key, + "Authorization": f"Bearer {self._access_token}", + } + + def make_query_request() -> Any: + """Inner function to make the actual query request.""" + req = Request(url, headers=headers, method="GET") + with urlopen(req) as response: + response_data = response.read() + return json.loads(response_data.decode("utf-8")) + + try: + return make_query_request() + except HTTPError as e: + # Handle 401 Unauthorized with token refresh + if e.code == 401 and self._access_token and self._refresh_token: + try: + self._refresh_supabase_token() + # Update headers with new token + headers["Authorization"] = f"Bearer {self._access_token}" + + # Retry the request + req = Request(url, headers=headers, method="GET") + with urlopen(req) as response: + response_data = response.read() + return json.loads(response_data.decode("utf-8")) + except (HTTPError, URLError, ValueError, KeyError) as refresh_error: + logger.warning( + "Token refresh failed, continuing with original request: %s", + refresh_error, + exc_info=True, + ) + # If refresh fails, re-raise the original 401 error + raise e + raise + + def get_eav_tables(self) -> List[EAVTable]: + """ + Get all EAV tables for the tenant. + + Returns: + List of EAVTable objects + """ + response_data = self._query_table( + "eav_tables", + {"tenant_id": f"eq.{self._tenant_id}", "order": "name"}, + ) + return [EAVTable.model_validate(table) for table in response_data] + + def get_eav_columns(self, table_id: str) -> List[EAVColumn]: + """ + Get all columns for an EAV table. + + Args: + table_id: The UUID of the EAV table + + Returns: + List of EAVColumn objects + """ + response_data = self._query_table( + "eav_columns", + {"eav_table_id": f"eq.{table_id}", "order": "name"}, + ) + return [EAVColumn.model_validate(column) for column in response_data] + + def select_from_eav_id( + self, + table_id: str, + column_ids: Optional[List[str]] = None, + filters: Optional[List[EAVFilter]] = None, + ) -> List[Dict[str, Any]]: + """ + Query pivoted data from EAV tables using table and column IDs directly. + + This is the lower-level method that works with IDs. For a more user-friendly + interface using names, see select_from_eav(). + + Args: + table_id: The UUID of the EAV table to query + column_ids: Optional list of column IDs to include in the result. + If None, all columns are returned. + filters: Optional list of EAVFilter objects to filter the results. + Each filter must use column_id (not column_name) and can specify: + - column_id: The EAV column ID to filter on (required) + - numeric_min: Minimum numeric value (inclusive) + - numeric_max: Maximum numeric value (exclusive) + - string_value: Exact string value to match + - boolean_value: Boolean value to match + + Returns: + List of dictionaries, each representing a row with column values. + Each row contains an 'id' field plus fields for each column_id + with their corresponding values (number, string, or boolean). + + Example: + >>> filters = [ + ... EAVFilter(column_id="MTnaC", numeric_min=100, numeric_max=200), + ... EAVFilter(column_id="z09PL", string_value="ACTIVE"), + ... ] + >>> rows = client.select_from_eav_id( + ... table_id="550e8400-e29b-41d4-a716-446655440000", + ... column_ids=["MTnaC", "z09PL", "ZNAGI"], + ... filters=filters, + ... ) + """ + # Convert EAVFilter objects to dictionaries + filters_json: List[Dict[str, Any]] = [] + if filters: + for f in filters: + if f.column_id is None: + raise ValueError( + "Filters for select_from_eav_id must use column_id, " + "not column_name" + ) + filter_dict: Dict[str, Any] = {"column_id": f.column_id} + if f.numeric_min is not None: + filter_dict["numeric_min"] = f.numeric_min + if f.numeric_max is not None: + filter_dict["numeric_max"] = f.numeric_max + if f.string_value is not None: + filter_dict["string_value"] = f.string_value + if f.boolean_value is not None: + filter_dict["boolean_value"] = f.boolean_value + filters_json.append(filter_dict) + + params: Dict[str, Any] = { + "tenant_id": self._tenant_id, + "table_id": table_id, + "column_ids": column_ids, + "filters": filters_json, + } + + response_data = self._call_rpc("select_from_eav", params) + return list(response_data) if response_data else [] + + def select_from_eav( + self, + table_name: str, + column_names: Optional[List[str]] = None, + filters: Optional[List[EAVFilter]] = None, + ) -> List[Dict[str, Any]]: + """ + Query pivoted data from EAV tables using human-readable names. + + This method looks up table and column IDs from names, then calls + select_from_eav_id(). For direct ID access, use select_from_eav_id(). + + Args: + table_name: The name of the EAV table to query + column_names: Optional list of column names to include in the result. + If None, all columns are returned. + filters: Optional list of EAVFilter objects to filter the results. + Each filter must use column_name (not column_id) and can specify: + - column_name: The EAV column name to filter on (required) + - numeric_min: Minimum numeric value (inclusive) + - numeric_max: Maximum numeric value (exclusive) + - string_value: Exact string value to match + - boolean_value: Boolean value to match + + Returns: + List of dictionaries, each representing a row with column values. + Each row contains an 'id' field plus fields for each column name + with their corresponding values (number, string, or boolean). + + Example: + >>> filters = [ + ... EAVFilter(column_name="Weight", numeric_min=100, numeric_max=200), + ... EAVFilter(column_name="Status", string_value="ACTIVE"), + ... ] + >>> rows = client.select_from_eav( + ... table_name="BT/Scrap Entry", + ... column_names=["Weight", "Status", "Is Verified"], + ... filters=filters, + ... ) + """ + # Look up the table ID from the table name + tables_response = self._query_table( + "eav_tables", + { + "tenant_id": f"eq.{self._tenant_id}", + "name": f"eq.{table_name}", + "limit": "1", + }, + ) + if not tables_response: + raise ValueError(f"Table '{table_name}' not found") + table_id = tables_response[0]["id"] + + # Get all columns for the table to build name <-> id mappings + columns = self.get_eav_columns(table_id) + column_name_to_id = {col.name: col.eav_column_id for col in columns} + column_id_to_name = {col.eav_column_id: col.name for col in columns} + + # Convert column names to column IDs + column_ids: Optional[List[str]] = None + if column_names: + column_ids = [] + for name in column_names: + if name not in column_name_to_id: + raise ValueError( + f"Column '{name}' not found in table '{table_name}'" + ) + column_ids.append(column_name_to_id[name]) + + # Convert filters with column_name to filters with column_id + id_filters: Optional[List[EAVFilter]] = None + if filters: + id_filters = [] + for f in filters: + if f.column_name is None: + raise ValueError( + "Filters for select_from_eav must use column_name, " + "not column_id" + ) + if f.column_name not in column_name_to_id: + raise ValueError( + f"Filter column '{f.column_name}' not found in table " + f"'{table_name}'" + ) + id_filters.append( + EAVFilter( + column_id=column_name_to_id[f.column_name], + numeric_min=f.numeric_min, + numeric_max=f.numeric_max, + string_value=f.string_value, + boolean_value=f.boolean_value, + ) + ) + + # Call the ID-based method + response_data = self.select_from_eav_id( + table_id=table_id, + column_ids=column_ids, + filters=id_filters, + ) + + # Convert column IDs back to names in the response + result: List[Dict[str, Any]] = [] + for row in response_data: + converted_row: Dict[str, Any] = {} + for key, value in row.items(): + if key == "id": + converted_row[key] = value + elif key in column_id_to_name: + converted_row[column_id_to_name[key]] = value + else: + converted_row[key] = value + result.append(converted_row) + + return result diff --git a/src/cvec/models/__init__.py b/src/cvec/models/__init__.py index e3172fd..dd1a3b0 100644 --- a/src/cvec/models/__init__.py +++ b/src/cvec/models/__init__.py @@ -1,7 +1,13 @@ +from .eav_column import EAVColumn +from .eav_filter import EAVFilter +from .eav_table import EAVTable from .metric import Metric, MetricDataPoint from .span import Span __all__ = [ + "EAVColumn", + "EAVFilter", + "EAVTable", "Metric", "MetricDataPoint", "Span", diff --git a/src/cvec/models/eav_column.py b/src/cvec/models/eav_column.py new file mode 100644 index 0000000..a5222c3 --- /dev/null +++ b/src/cvec/models/eav_column.py @@ -0,0 +1,18 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, ConfigDict + + +class EAVColumn(BaseModel): + """ + Represents an EAV column metadata record. + """ + + eav_table_id: str + eav_column_id: str + name: str + type: str + created_at: Optional[datetime] = None + + model_config = ConfigDict(json_encoders={datetime: lambda dt: dt.isoformat()}) diff --git a/src/cvec/models/eav_filter.py b/src/cvec/models/eav_filter.py new file mode 100644 index 0000000..3be49f6 --- /dev/null +++ b/src/cvec/models/eav_filter.py @@ -0,0 +1,33 @@ +from typing import Optional, Union + +from pydantic import BaseModel, model_validator + + +class EAVFilter(BaseModel): + """ + Represents a filter for querying EAV data. + + Filters are used to narrow down results based on column values: + - Use column_name with select_from_eav() for human-readable column names + - Use column_id with select_from_eav_id() for direct column IDs + - Use numeric_min/numeric_max for numeric range filtering (min inclusive, max exclusive) + - Use string_value for exact string matching + - Use boolean_value for boolean matching + + Exactly one of column_name or column_id must be provided. + """ + + column_name: Optional[str] = None + column_id: Optional[str] = None + numeric_min: Optional[Union[int, float]] = None + numeric_max: Optional[Union[int, float]] = None + string_value: Optional[str] = None + boolean_value: Optional[bool] = None + + @model_validator(mode="after") + def check_column_identifier(self) -> "EAVFilter": + if self.column_name is None and self.column_id is None: + raise ValueError("Either column_name or column_id must be provided") + if self.column_name is not None and self.column_id is not None: + raise ValueError("Only one of column_name or column_id should be provided") + return self diff --git a/src/cvec/models/eav_table.py b/src/cvec/models/eav_table.py new file mode 100644 index 0000000..2e7f023 --- /dev/null +++ b/src/cvec/models/eav_table.py @@ -0,0 +1,22 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, ConfigDict + + +class EAVTable(BaseModel): + """ + Represents an EAV table metadata record. + """ + + id: str + tenant_id: int + name: str + continuation_token: Optional[str] = None + last_sync_at: Optional[datetime] = None + total_rows_synced: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + last_etag: Optional[str] = None + + model_config = ConfigDict(json_encoders={datetime: lambda dt: dt.isoformat()}) diff --git a/tests/test_cvec.py b/tests/test_cvec.py index 31378bd..9acf00d 100644 --- a/tests/test_cvec.py +++ b/tests/test_cvec.py @@ -1,36 +1,90 @@ -import pytest +import io import os -from unittest.mock import patch from datetime import datetime -from cvec import CVec -from cvec.models.metric import Metric +from typing import Any +from unittest.mock import patch + import pyarrow as pa # type: ignore[import-untyped] import pyarrow.ipc as ipc # type: ignore[import-untyped] -import io -from typing import Any +import pytest + +from cvec import CVec, EAVFilter +from cvec.models.metric import Metric + + +def mock_fetch_config_side_effect(instance: CVec) -> str: + """Side effect for _fetch_config mock that sets tenant_id.""" + instance._tenant_id = 1 + return "test_publishable_key" class TestCVecConstructor: @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_constructor_with_arguments( self, mock_fetch_key: Any, mock_login: Any ) -> None: """Test CVec constructor with all arguments provided.""" client = CVec( - host="test_host", + host="https://test_host", default_start_at=datetime(2023, 1, 1, 0, 0, 0), default_end_at=datetime(2023, 1, 2, 0, 0, 0), api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", ) - assert client.host == "test_host" + assert client.host == "https://test_host" assert client.default_start_at == datetime(2023, 1, 1, 0, 0, 0) assert client.default_end_at == datetime(2023, 1, 2, 0, 0, 0) assert client._publishable_key == "test_publishable_key" assert client._api_key == "cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="env_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_constructor_adds_https_scheme( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test CVec constructor adds https:// scheme if not provided.""" + client = CVec( + host="example.cvector.dev", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + assert client.host == "https://example.cvector.dev" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_constructor_preserves_https_scheme( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test CVec constructor preserves https:// scheme if already provided.""" + client = CVec( + host="https://example.cvector.dev", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + assert client.host == "https://example.cvector.dev" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_constructor_preserves_http_scheme( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test CVec constructor preserves http:// scheme if provided.""" + client = CVec( + host="http://localhost:3000", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + assert client.host == "http://localhost:3000" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch.dict( os.environ, { @@ -47,14 +101,16 @@ def test_constructor_with_env_vars( default_start_at=datetime(2023, 2, 1, 0, 0, 0), default_end_at=datetime(2023, 2, 2, 0, 0, 0), ) - assert client.host == "env_host" - assert client._publishable_key == "env_publishable_key" + assert client.host == "https://env_host" + assert client._publishable_key == "test_publishable_key" assert client._api_key == "cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O" assert client.default_start_at == datetime(2023, 2, 1, 0, 0, 0) assert client.default_end_at == datetime(2023, 2, 2, 0, 0, 0) @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch.dict(os.environ, {}, clear=True) def test_constructor_missing_host_raises_value_error( self, mock_fetch_key: Any, mock_login: Any @@ -67,7 +123,9 @@ def test_constructor_missing_host_raises_value_error( CVec(api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O") @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch.dict(os.environ, {}, clear=True) def test_constructor_missing_api_key_raises_value_error( self, mock_fetch_key: Any, mock_login: Any @@ -80,7 +138,9 @@ def test_constructor_missing_api_key_raises_value_error( CVec(host="test_host") @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_constructor_args_override_env_vars( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -99,13 +159,15 @@ def test_constructor_args_override_env_vars( default_end_at=datetime(2023, 3, 2, 0, 0, 0), api_key="cva_differentKeyKALxMnxUdI9hanF0TBPvvvr1", ) - assert client.host == "arg_host" + assert client.host == "https://arg_host" assert client._api_key == "cva_differentKeyKALxMnxUdI9hanF0TBPvvvr1" assert client.default_start_at == datetime(2023, 3, 1, 0, 0, 0) assert client.default_end_at == datetime(2023, 3, 2, 0, 0, 0) @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_construct_email_from_api_key( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -118,7 +180,9 @@ def test_construct_email_from_api_key( assert email == "cva+hHs0@cvector.app" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_construct_email_from_api_key_invalid_format( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -132,7 +196,9 @@ def test_construct_email_from_api_key_invalid_format( client._construct_email_from_api_key() @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_construct_email_from_api_key_invalid_length( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -150,7 +216,9 @@ def test_construct_email_from_api_key_invalid_length( class TestCVecGetSpans: @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_spans_basic_case(self, mock_fetch_key: Any, mock_login: Any) -> None: # Simulate backend response response_data = [ @@ -190,7 +258,9 @@ def test_get_spans_basic_case(self, mock_fetch_key: Any, mock_login: Any) -> Non class TestCVecGetMetrics: @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metrics_no_interval( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -222,7 +292,9 @@ def test_get_metrics_no_interval( assert metrics[1].name == "metric2" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metrics_with_interval( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -247,7 +319,9 @@ def test_get_metrics_with_interval( assert metrics[0].name == "metric_in_interval" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metrics_no_data_found( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -264,7 +338,9 @@ def test_get_metrics_no_data_found( class TestCVecGetMetricData: @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metric_data_basic_case( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -299,7 +375,9 @@ def test_get_metric_data_basic_case( assert data_points[2].value_string == "val_str" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metric_data_no_data_points( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -312,7 +390,9 @@ def test_get_metric_data_no_data_points( assert data_points == [] @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metric_arrow_basic_case( self, mock_fetch_key: Any, mock_login: Any ) -> None: @@ -355,7 +435,9 @@ def test_get_metric_arrow_basic_case( ] @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) def test_get_metric_arrow_empty(self, mock_fetch_key: Any, mock_login: Any) -> None: table = pa.table( { @@ -381,3 +463,437 @@ def test_get_metric_arrow_empty(self, mock_fetch_key: Any, mock_login: Any) -> N assert result_table.column("name").to_pylist() == [] assert result_table.column("value_double").to_pylist() == [] assert result_table.column("value_string").to_pylist() == [] + + +class TestEAVFilter: + def test_eav_filter_with_column_name(self) -> None: + """Test EAVFilter with column_name.""" + filter_obj = EAVFilter(column_name="Date") + assert filter_obj.column_name == "Date" + assert filter_obj.column_id is None + + def test_eav_filter_with_column_id(self) -> None: + """Test EAVFilter with column_id.""" + filter_obj = EAVFilter(column_id="MTnaC") + assert filter_obj.column_id == "MTnaC" + assert filter_obj.column_name is None + + def test_eav_filter_numeric_range(self) -> None: + """Test EAVFilter with numeric range.""" + filter_obj = EAVFilter(column_name="Date", numeric_min=100, numeric_max=200) + assert filter_obj.column_name == "Date" + assert filter_obj.numeric_min == 100 + assert filter_obj.numeric_max == 200 + + def test_eav_filter_string_value(self) -> None: + """Test EAVFilter with string value.""" + filter_obj = EAVFilter(column_name="Status", string_value="failure") + assert filter_obj.column_name == "Status" + assert filter_obj.string_value == "failure" + + def test_eav_filter_boolean_value(self) -> None: + """Test EAVFilter with boolean value.""" + filter_obj = EAVFilter(column_name="Is Active", boolean_value=False) + assert filter_obj.column_name == "Is Active" + assert filter_obj.boolean_value is False + + def test_eav_filter_requires_column_identifier(self) -> None: + """Test EAVFilter raises error when neither column_name nor column_id.""" + with pytest.raises(ValueError, match="Either column_name or column_id"): + EAVFilter(numeric_min=100) + + def test_eav_filter_rejects_both_identifiers(self) -> None: + """Test EAVFilter raises error when both column_name and column_id.""" + with pytest.raises(ValueError, match="Only one of column_name or column_id"): + EAVFilter(column_name="Date", column_id="MTnaC") + + +class TestCVecSelectFromEAV: + """Tests for select_from_eav using table_name and column_names.""" + + def _mock_query_table( + self, table_id: str = "7a80f3a2-6fa1-43ce-8483-76bd00dc93c6" + ) -> Any: + """Create a mock for _query_table that returns table and column data.""" + + def mock_query( + table_name: str, query_params: dict[str, str] | None = None + ) -> Any: + if table_name == "eav_tables": + return [{"id": table_id, "tenant_id": 1, "name": "Test Table"}] + elif table_name == "eav_columns": + return [ + { + "eav_table_id": table_id, + "eav_column_id": "col1_id", + "name": "Column 1", + "type": "number", + }, + { + "eav_table_id": table_id, + "eav_column_id": "col2_id", + "name": "Column 2", + "type": "string", + }, + { + "eav_table_id": table_id, + "eav_column_id": "is_active_id", + "name": "Is Active", + "type": "boolean", + }, + ] + return [] + + return mock_query + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_basic(self, mock_fetch_key: Any, mock_login: Any) -> None: + """Test select_from_eav with no filters.""" + # Response uses column IDs + rpc_response = [ + {"id": "row1", "col1_id": 100.5, "col2_id": "value1"}, + {"id": "row2", "col1_id": 200.0, "col2_id": "value2"}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._query_table = self._mock_query_table() # type: ignore[method-assign] + client._call_rpc = lambda *args, **kwargs: rpc_response # type: ignore[method-assign] + + result = client.select_from_eav( + table_name="Test Table", + ) + + # Result should have column names, not IDs + assert len(result) == 2 + assert result[0]["id"] == "row1" + assert result[0]["Column 1"] == 100.5 + assert result[1]["Column 2"] == "value2" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_with_column_names( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav with specific column_names.""" + rpc_response = [ + {"id": "row1", "col1_id": 100.5}, + {"id": "row2", "col1_id": 200.0}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + captured_params: dict[str, Any] = {} + + def mock_call_rpc(name: str, params: Any) -> Any: + captured_params.update(params) + return rpc_response + + client._query_table = self._mock_query_table() # type: ignore[method-assign] + client._call_rpc = mock_call_rpc # type: ignore[assignment, method-assign] + + result = client.select_from_eav( + table_name="Test Table", + column_names=["Column 1"], + ) + + assert len(result) == 2 + # Should translate column name to column ID for the RPC call + assert captured_params["column_ids"] == ["col1_id"] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_with_filters( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav with filters.""" + rpc_response = [ + {"id": "row1", "col1_id": 150.0, "col2_id": "ACTIVE"}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + captured_params: dict[str, Any] = {} + + def mock_call_rpc(name: str, params: Any) -> Any: + captured_params.update(params) + return rpc_response + + client._query_table = self._mock_query_table() # type: ignore[method-assign] + client._call_rpc = mock_call_rpc # type: ignore[assignment, method-assign] + + filters = [ + EAVFilter(column_name="Column 1", numeric_min=100, numeric_max=200), + EAVFilter(column_name="Column 2", string_value="ACTIVE"), + ] + + result = client.select_from_eav( + table_name="Test Table", + filters=filters, + ) + + assert len(result) == 1 + assert result[0]["Column 1"] == 150.0 + # Filters should use column IDs in RPC call + assert captured_params["filters"] == [ + {"column_id": "col1_id", "numeric_min": 100, "numeric_max": 200}, + {"column_id": "col2_id", "string_value": "ACTIVE"}, + ] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_with_boolean_filter( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav with boolean filter.""" + rpc_response = [ + {"id": "row1", "is_active_id": True}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + captured_params: dict[str, Any] = {} + + def mock_call_rpc(name: str, params: Any) -> Any: + captured_params.update(params) + return rpc_response + + client._query_table = self._mock_query_table() # type: ignore[method-assign] + client._call_rpc = mock_call_rpc # type: ignore[assignment, method-assign] + + filters = [EAVFilter(column_name="Is Active", boolean_value=True)] + + result = client.select_from_eav( + table_name="Test Table", + filters=filters, + ) + + assert len(result) == 1 + assert captured_params["filters"] == [ + {"column_id": "is_active_id", "boolean_value": True} + ] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_empty_result( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav with empty result.""" + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._query_table = self._mock_query_table() # type: ignore[method-assign] + client._call_rpc = lambda *args, **kwargs: [] # type: ignore[method-assign] + + result = client.select_from_eav( + table_name="Test Table", + ) + + assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_table_not_found( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav raises error when table not found.""" + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._query_table = lambda *args, **kwargs: [] # type: ignore[method-assign] + + with pytest.raises(ValueError, match="Table 'Unknown Table' not found"): + client.select_from_eav( + table_name="Unknown Table", + ) + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_column_not_found( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav raises error when column not found.""" + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._query_table = self._mock_query_table() # type: ignore[method-assign] + + with pytest.raises( + ValueError, match="Column 'Unknown Column' not found in table 'Test Table'" + ): + client.select_from_eav( + table_name="Test Table", + column_names=["Unknown Column"], + ) + + +class TestCVecSelectFromEAVId: + """Tests for select_from_eav_id using table_id and column_ids directly.""" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_id_basic( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav_id with no filters.""" + rpc_response = [ + {"id": "row1", "col1_id": 100.5, "col2_id": "value1"}, + {"id": "row2", "col1_id": 200.0, "col2_id": "value2"}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._call_rpc = lambda *args, **kwargs: rpc_response # type: ignore[method-assign] + + result = client.select_from_eav_id( + table_id="7a80f3a2-6fa1-43ce-8483-76bd00dc93c6", + ) + + # Result keeps column IDs (no name translation) + assert len(result) == 2 + assert result[0]["id"] == "row1" + assert result[0]["col1_id"] == 100.5 + assert result[1]["col2_id"] == "value2" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_id_with_column_ids( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav_id with specific column_ids.""" + rpc_response = [ + {"id": "row1", "col1_id": 100.5}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + captured_params: dict[str, Any] = {} + + def mock_call_rpc(name: str, params: Any) -> Any: + captured_params.update(params) + return rpc_response + + client._call_rpc = mock_call_rpc # type: ignore[assignment, method-assign] + + result = client.select_from_eav_id( + table_id="7a80f3a2-6fa1-43ce-8483-76bd00dc93c6", + column_ids=["col1_id"], + ) + + assert len(result) == 1 + assert captured_params["column_ids"] == ["col1_id"] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_id_with_filters( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav_id with filters using column_id.""" + rpc_response = [ + {"id": "row1", "col1_id": 150.0, "col2_id": "ACTIVE"}, + ] + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + captured_params: dict[str, Any] = {} + + def mock_call_rpc(name: str, params: Any) -> Any: + captured_params.update(params) + return rpc_response + + client._call_rpc = mock_call_rpc # type: ignore[assignment, method-assign] + + filters = [ + EAVFilter(column_id="col1_id", numeric_min=100, numeric_max=200), + EAVFilter(column_id="col2_id", string_value="ACTIVE"), + ] + + result = client.select_from_eav_id( + table_id="7a80f3a2-6fa1-43ce-8483-76bd00dc93c6", + filters=filters, + ) + + assert len(result) == 1 + assert captured_params["filters"] == [ + {"column_id": "col1_id", "numeric_min": 100, "numeric_max": 200}, + {"column_id": "col2_id", "string_value": "ACTIVE"}, + ] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_id_rejects_column_name_filter( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav_id raises error when filter uses column_name.""" + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + filters = [EAVFilter(column_name="Column 1", numeric_min=100)] + + with pytest.raises( + ValueError, match="Filters for select_from_eav_id must use column_id" + ): + client.select_from_eav_id( + table_id="7a80f3a2-6fa1-43ce-8483-76bd00dc93c6", + filters=filters, + ) + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + def test_select_from_eav_id_empty_result( + self, mock_fetch_key: Any, mock_login: Any + ) -> None: + """Test select_from_eav_id with empty result.""" + client = CVec( + host="test_host", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + client._call_rpc = lambda *args, **kwargs: [] # type: ignore[method-assign] + + result = client.select_from_eav_id( + table_id="7a80f3a2-6fa1-43ce-8483-76bd00dc93c6", + ) + + assert result == [] diff --git a/tests/test_modeling.py b/tests/test_modeling.py index 8eea9bd..4924c4e 100644 --- a/tests/test_modeling.py +++ b/tests/test_modeling.py @@ -10,21 +10,24 @@ from cvec.cvec import CVec +def mock_fetch_config_side_effect(instance: CVec) -> str: + """Side effect for _fetch_config mock that sets tenant_id.""" + instance._tenant_id = 1 + return "test_publishable_key" + + class TestModelingMethods: """Test the modeling methods in the CVec class.""" - @patch("cvec.cvec.CVec._fetch_publishable_key") - @patch("cvec.cvec.CVec._login_with_supabase") - @patch("cvec.cvec.CVec._make_request") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_make_request") def test_get_modeling_metrics( self, mock_make_request: Mock, mock_login: Mock, mock_fetch_key: Mock ) -> None: """Test get_modeling_metrics method.""" - # Mock the publishable key fetch - mock_fetch_key.return_value = "test_publishable_key" - - # Mock the login method - mock_login.return_value = None # Mock the response mock_response = [ @@ -64,18 +67,15 @@ def test_get_modeling_metrics( assert call_args[1]["params"]["start_at"] == "2024-01-01T12:00:00" assert call_args[1]["params"]["end_at"] == "2024-01-01T13:00:00" - @patch("cvec.cvec.CVec._fetch_publishable_key") - @patch("cvec.cvec.CVec._login_with_supabase") - @patch("cvec.cvec.CVec._make_request") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_make_request") def test_get_modeling_metrics_data( self, mock_make_request: Mock, mock_login: Mock, mock_fetch_key: Mock ) -> None: """Test get_modeling_metrics_data method.""" - # Mock the publishable key fetch - mock_fetch_key.return_value = "test_publishable_key" - - # Mock the login method - mock_login.return_value = None # Mock the response mock_response = [ @@ -117,18 +117,15 @@ def test_get_modeling_metrics_data( assert call_args[1]["params"]["start_at"] == "2024-01-01T12:00:00" assert call_args[1]["params"]["end_at"] == "2024-01-01T13:00:00" - @patch("cvec.cvec.CVec._fetch_publishable_key") - @patch("cvec.cvec.CVec._login_with_supabase") - @patch("cvec.cvec.CVec._make_request") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_make_request") def test_get_modeling_metrics_data_arrow( self, mock_make_request: Mock, mock_login: Mock, mock_fetch_key: Mock ) -> None: """Test get_modeling_metrics_data_arrow method.""" - # Mock the publishable key fetch - mock_fetch_key.return_value = "test_publishable_key" - - # Mock the login method - mock_login.return_value = None # Mock the response (Arrow data as bytes) mock_response = b"fake_arrow_data" diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py index 890c06b..c902678 100644 --- a/tests/test_token_refresh.py +++ b/tests/test_token_refresh.py @@ -9,11 +9,19 @@ from cvec import CVec +def mock_fetch_config_side_effect(instance: CVec) -> str: + """Side effect for _fetch_config mock that sets tenant_id.""" + instance._tenant_id = 1 + return "test_publishable_key" + + class TestTokenRefresh: """Test cases for automatic token refresh functionality.""" @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch("cvec.cvec.urlopen") def test_token_refresh_on_401( self, @@ -67,7 +75,9 @@ def mock_refresh() -> None: assert result == [] @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch("cvec.cvec.urlopen") def test_token_refresh_handles_network_errors_gracefully( self, @@ -110,7 +120,9 @@ def mock_refresh_with_error() -> None: assert exc_info.value.code == 401 @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch.object( + CVec, "_fetch_config", autospec=True, side_effect=mock_fetch_config_side_effect + ) @patch("cvec.cvec.urlopen") def test_token_refresh_handles_missing_refresh_token( self,