From 5a37f92b498997442496635ba80a35713c77b856 Mon Sep 17 00:00:00 2001 From: Gerard Simons Date: Sun, 22 Feb 2026 10:50:26 +1300 Subject: [PATCH] fix(vertex): resolve ZeroShotVertexClassifier ImportError and update default to gemini-2.5-flash - Fixed ImportError by re-adding the missing TEXT_BISON_MODEL constant to model_constants.py (it replaces VERTEX_DEFAULT_MODEL) and added VERTEX_GEMINI_MODEL - Updated ZeroShotVertexClassifier and MultiLabelZeroShotVertexClassifier to use gemini-2.5-flash as the default model - Added mocked unit tests for Vertex zero-shot classifiers - Added live integration tests (opt-in via SKLLM_RUN_LIVE_TESTS=True) with optional .env support for Vertex AI - Updated dev dependencies with pytest, pytest-mock and python-dotenv - Added anthropic>=0.83.0 to the runtime dependencies in pyproject.toml --- pyproject.toml | 3 +- requirements-dev.txt | 3 + skllm/model_constants.py | 3 +- .../models/vertex/classification/zero_shot.py | 6 +- tests/llm/vertex/test_vertex_live.py | 60 +++++++++++++++++++ tests/llm/vertex/test_vertex_zero_shot.py | 56 +++++++++++++++++ 6 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 tests/llm/vertex/test_vertex_live.py create mode 100644 tests/llm/vertex/test_vertex_zero_shot.py diff --git a/pyproject.toml b/pyproject.toml index e3dedc7..a4151d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,8 @@ dependencies = [ "pandas>=1.5.0,<3.0.0", "openai>=1.2.0,<2.0.0", "tqdm>=4.60.0,<5.0.0", - "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0" + "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0", + "anthropic>=0.83.0", ] name = "scikit-llm" version = "1.4.3" diff --git a/requirements-dev.txt b/requirements-dev.txt index c2de4e7..aaa26ab 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,3 +6,6 @@ docformatter interrogate numpy pandas +pytest>=8.4.2 +pytest-mock>=3.15.1 +python-dotenv>=1.2.1 diff --git a/skllm/model_constants.py b/skllm/model_constants.py index 91c408e..506d2e5 100644 --- a/skllm/model_constants.py +++ b/skllm/model_constants.py @@ -7,4 +7,5 @@ ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307" # Vertex AI models -VERTEX_DEFAULT_MODEL = "text-bison@002" 
+TEXT_BISON_MODEL = "text-bison@002" +VERTEX_GEMINI_MODEL = "gemini-2.5-flash" diff --git a/skllm/models/vertex/classification/zero_shot.py b/skllm/models/vertex/classification/zero_shot.py index 59843c8..c44f3d4 100644 --- a/skllm/models/vertex/classification/zero_shot.py +++ b/skllm/models/vertex/classification/zero_shot.py @@ -5,7 +5,7 @@ MultiLabelMixin as _MultiLabelMixin, ) from typing import Optional -from skllm.model_constants import TEXT_BISON_MODEL +from skllm.model_constants import TEXT_BISON_MODEL, VERTEX_GEMINI_MODEL class ZeroShotVertexClassifier( @@ -13,7 +13,7 @@ class ZeroShotVertexClassifier( ): def __init__( self, - model: str = TEXT_BISON_MODEL, + model: str = VERTEX_GEMINI_MODEL, default_label: str = "Random", prompt_template: Optional[str] = None, **kwargs, @@ -43,7 +43,7 @@ class MultiLabelZeroShotVertexClassifier( ): def __init__( self, - model: str = TEXT_BISON_MODEL, + model: str = VERTEX_GEMINI_MODEL, default_label: str = "Random", prompt_template: Optional[str] = None, max_labels: Optional[int] = 5, diff --git a/tests/llm/vertex/test_vertex_live.py b/tests/llm/vertex/test_vertex_live.py new file mode 100644 index 0000000..96ac50e --- /dev/null +++ b/tests/llm/vertex/test_vertex_live.py @@ -0,0 +1,60 @@ +import unittest +import os +from skllm.models.vertex.classification.zero_shot import ( + ZeroShotVertexClassifier, + MultiLabelZeroShotVertexClassifier +) +from skllm.config import SKLLMConfig + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +# Run this with: +# SKLLM_RUN_LIVE_TESTS=True GOOGLE_CLOUD_PROJECT=your-project uv run pytest tests/llm/vertex/test_vertex_live.py +# Or use a .env file in project root + +@unittest.skipIf(os.environ.get("SKLLM_RUN_LIVE_TESTS") != "True", "Skipping live API test") +class TestVertexLive(unittest.TestCase): + """ + Live tests for Vertex AI Gemini. 
+ """ + + def setUp(self): + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + if project: + SKLLMConfig.set_google_project(project) + + def test_zero_shot_predict_live(self): + """Test single-label zero-shot classification with real Gemini API.""" + X = ["This is a fantastic product!"] + y = ["positive", "negative"] + + clf = ZeroShotVertexClassifier() # Uses default Gemini 2.5 flash + clf.fit(None, y) + + labels = clf.predict(X) + self.assertEqual(len(labels), 1) + self.assertIn(labels[0], y) + print(f"\n[Live Test] Single-label prediction: {labels[0]}") + + def test_multi_label_predict_live(self): + """Test multi-label zero-shot classification with real Gemini API.""" + X = ["The new smartphone has a great camera and long battery life."] + y = ["camera", "battery", "display", "price"] + + # We expect at least 'camera' and 'battery' + clf = MultiLabelZeroShotVertexClassifier(max_labels=2) + clf.fit(None, y) + + labels = clf.predict(X) + self.assertEqual(len(labels), 1) + # The mixin returns a list padded to max_labels + self.assertIn("camera", labels[0]) + self.assertIn("battery", labels[0]) + print(f"\n[Live Test] Multi-label prediction: {labels[0]}") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/llm/vertex/test_vertex_zero_shot.py b/tests/llm/vertex/test_vertex_zero_shot.py new file mode 100644 index 0000000..134b3da --- /dev/null +++ b/tests/llm/vertex/test_vertex_zero_shot.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import patch +from skllm.models.vertex.classification.zero_shot import ( + ZeroShotVertexClassifier, + MultiLabelZeroShotVertexClassifier +) +from skllm.model_constants import VERTEX_GEMINI_MODEL + +class TestZeroShotVertexClassifier(unittest.TestCase): + + def test_initialization_default(self): + """Test if the classifier initializes with the new default Gemini model.""" + clf = ZeroShotVertexClassifier() + self.assertEqual(clf.model, VERTEX_GEMINI_MODEL) + + 
@patch("skllm.llm.vertex.mixin.get_completion_chat_gemini") + def test_single_label_predict(self, mock_gemini): + """Test single-label fit and predict with mocked Gemini completion.""" + mock_gemini.return_value = '{"label": "positive"}' + + X = ["I love this!", "This is bad."] + y = ["positive", "negative"] + + clf = ZeroShotVertexClassifier() + clf.fit(X, y) + + predictions = clf.predict(["I am happy"]) + + self.assertEqual(predictions[0], "positive") + mock_gemini.assert_called() + + # Verify it uses the correct model + args, _ = mock_gemini.call_args + self.assertEqual(args[0], VERTEX_GEMINI_MODEL) + + @patch("skllm.llm.vertex.mixin.get_completion_chat_gemini") + def test_multi_label_predict(self, mock_gemini): + """Test multi-label fit and predict with mocked Gemini completion.""" + mock_gemini.return_value = '{"label": ["tech", "science"]}' + + X = ["New discovery in AI.", "Space exploration."] + y = [["tech", "science"], ["science"]] + + clf = MultiLabelZeroShotVertexClassifier() + clf.fit(X, y) + + predictions = clf.predict(["AI in space"]) + + # The mixin pads the result to max_labels (default 5) with empty strings + self.assertIn("tech", predictions[0]) + self.assertIn("science", predictions[0]) + self.assertEqual(len(predictions[0]), 5) + mock_gemini.assert_called() + +if __name__ == "__main__": + unittest.main()