Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ dependencies = [
"pandas>=1.5.0,<3.0.0",
"openai>=1.2.0,<2.0.0",
"tqdm>=4.60.0,<5.0.0",
"google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
"google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0",
"anthropic>=0.83.0",
]
name = "scikit-llm"
version = "1.4.3"
Expand Down
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ docformatter
interrogate
numpy
pandas
pytest>=8.4.2
pytest-mock>=3.15.1
python-dotenv>=1.2.1
3 changes: 2 additions & 1 deletion skllm/model_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307"

# Vertex AI models
VERTEX_DEFAULT_MODEL = "text-bison@002"
TEXT_BISON_MODEL = "text-bison@002"
VERTEX_GEMINI_MODEL = "gemini-2.5-flash"
6 changes: 3 additions & 3 deletions skllm/models/vertex/classification/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
MultiLabelMixin as _MultiLabelMixin,
)
from typing import Optional
from skllm.model_constants import TEXT_BISON_MODEL
from skllm.model_constants import TEXT_BISON_MODEL, VERTEX_GEMINI_MODEL


class ZeroShotVertexClassifier(
_BaseZeroShotClassifier, _SingleLabelMixin, _VertexClassifierMixin
):
def __init__(
self,
model: str = TEXT_BISON_MODEL,
model: str = VERTEX_GEMINI_MODEL,
default_label: str = "Random",
prompt_template: Optional[str] = None,
**kwargs,
Expand Down Expand Up @@ -43,7 +43,7 @@ class MultiLabelZeroShotVertexClassifier(
):
def __init__(
self,
model: str = TEXT_BISON_MODEL,
model: str = VERTEX_GEMINI_MODEL,
default_label: str = "Random",
prompt_template: Optional[str] = None,
max_labels: Optional[int] = 5,
Expand Down
60 changes: 60 additions & 0 deletions tests/llm/vertex/test_vertex_live.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest
import os
from skllm.models.vertex.classification.zero_shot import (
ZeroShotVertexClassifier,
MultiLabelZeroShotVertexClassifier
)
from skllm.config import SKLLMConfig

try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass

# Run this with:
# SKLLM_RUN_LIVE_TESTS=True GOOGLE_CLOUD_PROJECT=your-project uv run pytest tests/llm/vertex/test_vertex_live.py
# Or use a .env file in project root

@unittest.skipIf(os.environ.get("SKLLM_RUN_LIVE_TESTS") != "True", "Skipping live API test")
class TestVertexLive(unittest.TestCase):
"""
Live tests for Vertex AI Gemini.
"""

def setUp(self):
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
if project:
SKLLMConfig.set_google_project(project)

def test_zero_shot_predict_live(self):
"""Test single-label zero-shot classification with real Gemini API."""
X = ["This is a fantastic product!"]
y = ["positive", "negative"]

clf = ZeroShotVertexClassifier() # Uses default Gemini 2.5 flash
clf.fit(None, y)

labels = clf.predict(X)
self.assertEqual(len(labels), 1)
self.assertIn(labels[0], y)
print(f"\n[Live Test] Single-label prediction: {labels[0]}")

def test_multi_label_predict_live(self):
"""Test multi-label zero-shot classification with real Gemini API."""
X = ["The new smartphone has a great camera and long battery life."]
y = ["camera", "battery", "display", "price"]

# We expect at least 'camera' and 'battery'
clf = MultiLabelZeroShotVertexClassifier(max_labels=2)
clf.fit(None, y)

labels = clf.predict(X)
self.assertEqual(len(labels), 1)
# The mixin returns a list padded to max_labels
self.assertIn("camera", labels[0])
self.assertIn("battery", labels[0])
print(f"\n[Live Test] Multi-label prediction: {labels[0]}")

if __name__ == "__main__":
unittest.main()
56 changes: 56 additions & 0 deletions tests/llm/vertex/test_vertex_zero_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
from unittest.mock import patch
from skllm.models.vertex.classification.zero_shot import (
ZeroShotVertexClassifier,
MultiLabelZeroShotVertexClassifier
)
from skllm.model_constants import VERTEX_GEMINI_MODEL

class TestZeroShotVertexClassifier(unittest.TestCase):

def test_initialization_default(self):
"""Test if the classifier initializes with the new default Gemini model."""
clf = ZeroShotVertexClassifier()
self.assertEqual(clf.model, VERTEX_GEMINI_MODEL)

@patch("skllm.llm.vertex.mixin.get_completion_chat_gemini")
def test_single_label_predict(self, mock_gemini):
"""Test single-label fit and predict with mocked Gemini completion."""
mock_gemini.return_value = '{"label": "positive"}'

X = ["I love this!", "This is bad."]
y = ["positive", "negative"]

clf = ZeroShotVertexClassifier()
clf.fit(X, y)

predictions = clf.predict(["I am happy"])

self.assertEqual(predictions[0], "positive")
mock_gemini.assert_called()

# Verify it uses the correct model
args, _ = mock_gemini.call_args
self.assertEqual(args[0], VERTEX_GEMINI_MODEL)

@patch("skllm.llm.vertex.mixin.get_completion_chat_gemini")
def test_multi_label_predict(self, mock_gemini):
"""Test multi-label fit and predict with mocked Gemini completion."""
mock_gemini.return_value = '{"label": ["tech", "science"]}'

X = ["New discovery in AI.", "Space exploration."]
y = [["tech", "science"], ["science"]]

clf = MultiLabelZeroShotVertexClassifier()
clf.fit(X, y)

predictions = clf.predict(["AI in space"])

# The mixin pads the result to max_labels (default 5) with empty strings
self.assertIn("tech", predictions[0])
self.assertIn("science", predictions[0])
self.assertEqual(len(predictions[0]), 5)
mock_gemini.assert_called()

if __name__ == "__main__":
unittest.main()