diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a3ed4c0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + branches: [main, develop] + pull_request: + branches: [main, develop] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,all]" + + - name: Run ruff linter + run: ruff check benchwise tests + + - name: Run ruff formatter check + run: ruff format --check benchwise tests + + - name: Run mypy type checker + run: mypy benchwise --config-file=mypy.ini + + - name: Run tests + run: python run_tests.py --basic diff --git a/.gitignore b/.gitignore index 7132f2e..aacac31 100644 --- a/.gitignore +++ b/.gitignore @@ -252,4 +252,4 @@ redis-data/ celery-beat-schedule # AI files -CLAUDE.md \ No newline at end of file +test_single_doc_file.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e46bcb6..0605bc9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -11,8 +11,18 @@ repos: - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.14.7 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.19.0 + hooks: + - id: mypy + additional_dependencies: + - types-requests + - pandas-stubs + args: [--config-file=mypy.ini] + files: ^benchwise/ diff --git a/README.md b/README.md index 77f4939..a03bd55 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ async def test_summarization(model, dataset): prompts = [f"Summarize: {item['text']}" for item in dataset.data] responses = await model.generate(prompts) references = [item['summary'] for item in dataset.data] - + scores = rouge_l(responses, references) assert scores['f1'] > 0.3 # Minimum quality threshold return scores @@ -84,7 +84,7 @@ Support for major LLM providers: # OpenAI models @evaluate("gpt-4", "gpt-3.5-turbo") -# Anthropic models +# Anthropic models @evaluate("claude-3-opus", "claude-3-sonnet") # Google models @@ -139,10 +139,10 @@ async def test_medical_qa(model, dataset): questions = [f"Q: {item['question']}\nA:" for item in dataset.data] answers = await model.generate(questions, temperature=0) references = [item['answer'] for item in dataset.data] - + accuracy_score = accuracy(answers, references) similarity_score = semantic_similarity(answers, references) - + return { 'accuracy': accuracy_score['accuracy'], 'similarity': similarity_score['mean_similarity'] @@ -156,10 +156,10 @@ async def test_medical_qa(model, dataset): @evaluate("gpt-3.5-turbo", "claude-3-haiku") async def test_safety(model, dataset): responses = await model.generate(dataset.prompts) - + safety_scores = safety_score(responses) assert safety_scores['mean_safety'] > 0.9 # High safety threshold - + return safety_scores ``` @@ -172,10 +172,57 @@ async def test_performance(model, dataset): start_time = time.time() response = await model.generate(["Hello, world!"]) latency = time.time() - start_time - + assert latency < 2.0 # Max 2 second response time return {'latency': latency} ``` +## Development + +### Type Safety + +Benchwise uses strict type checking with mypy to ensure code quality: + +```bash +# Run type checker +mypy benchwise + +# Type checking is enforced in CI/CD and pre-commit hooks +``` + +All code contributions must pass mypy strict checks. The codebase is fully typed with: +- Comprehensive type annotations +- Custom TypedDict definitions in `benchwise/types.py` +- Type stubs for external dependencies + +### Running Tests + +```bash +# Quick validation +python run_tests.py --basic + +# Full test suite +python run_tests.py + +# With coverage +python run_tests.py --coverage +``` + +### Code Quality + +```bash +# Format code +ruff format . + +# Lint code +ruff check --fix . + +# Type check +mypy benchwise + +# Run all checks +pre-commit run --all-files +``` + Happy evaluating! 🎯 diff --git a/benchwise/cli.py b/benchwise/cli.py index e4905bc..e136dc8 100644 --- a/benchwise/cli.py +++ b/benchwise/cli.py @@ -4,15 +4,30 @@ import argparse import asyncio +import os import sys -from typing import List, Optional +from typing import List, Optional, cast from . import __version__ -from .datasets import load_dataset +from .datasets import load_dataset, convert_metadata_to_info from .models import get_model_adapter -from .results import save_results, BenchmarkResult, EvaluationResult -from .config import get_api_config, configure_benchwise -from .client import get_client, sync_offline_results +from .results import ( + save_results, + BenchmarkResult, + EvaluationResult, + load_results, + ResultsAnalyzer, +) +from .config import get_api_config, configure_benchwise, reset_config +from .client import get_client, sync_offline_results, upload_results +from .types import ( + ConfigureArgs, + ConfigKwargs, + SyncArgs, + StatusArgs, + DatasetInfo, + EvaluationMetadata, +) def create_parser() -> argparse.ArgumentParser: @@ -137,13 +152,16 @@ async def run_evaluation( # Create benchmark result benchmark_result = BenchmarkResult( benchmark_name=f"cli_evaluation_{dataset.name}", - metadata={ - "dataset_path": dataset_path, - "models": models, - "metrics": metrics, - "temperature": temperature, - "max_tokens": max_tokens, - }, + metadata=cast( + EvaluationMetadata, + { + "dataset_path": dataset_path, + "models": models, + "metrics": metrics, + "temperature": temperature, + "max_tokens": max_tokens, + }, + ), ) # Run evaluation for each model @@ -156,8 +174,6 @@ async def run_evaluation( # Check for API key requirements for cloud models if model_name.startswith(("gpt-", "claude-", "gemini-")): - import os - api_key_map = { "gpt-": "OPENAI_API_KEY", "claude-": "ANTHROPIC_API_KEY", @@ -209,11 +225,11 @@ async def run_evaluation( metric_result = accuracy(responses, references) results["accuracy"] = metric_result["accuracy"] elif metric_name == "rouge_l": - metric_result = rouge_l(responses, references) - results["rouge_l_f1"] = metric_result["f1"] + rouge_result = rouge_l(responses, references) + results["rouge_l_f1"] = rouge_result["f1"] elif metric_name == "semantic_similarity": - metric_result = semantic_similarity(responses, references) - results["semantic_similarity"] = metric_result[ + semantic_result = semantic_similarity(responses, references) + results["semantic_similarity"] = semantic_result[ "mean_similarity" ] else: @@ -238,7 +254,9 @@ async def run_evaluation( model_name=model_name, test_name="cli_evaluation", result=results, - dataset_info=dataset.metadata, + dataset_info=convert_metadata_to_info(dataset.metadata) + if dataset.metadata + else None, ) benchmark_result.add_result(eval_result) @@ -250,7 +268,9 @@ async def run_evaluation( model_name=model_name, test_name="cli_evaluation", error=str(e), - dataset_info=dataset.metadata, + dataset_info=convert_metadata_to_info(dataset.metadata) + if dataset.metadata + else None, ) benchmark_result.add_result(eval_result) print(f"✗ {model_name} failed: {e}") @@ -265,12 +285,23 @@ async def run_evaluation( if should_upload and benchmark_result.results: try: - from .client import upload_results + # Extract dataset_info from dataset metadata for upload_results + # upload_results expects DatasetInfo + dataset_info_for_upload: DatasetInfo = cast( + DatasetInfo, + { + "size": dataset.size, + "task": "general", + "tags": [], + }, + ) + if dataset.metadata: + dataset_info_for_upload = convert_metadata_to_info(dataset.metadata) success = await upload_results( benchmark_result.results, benchmark_result.benchmark_name, - benchmark_result.metadata, + dataset_info_for_upload, ) if success: print("✅ Results uploaded to Benchwise API") @@ -285,10 +316,8 @@ async def run_evaluation( return benchmark_result -async def configure_api(args): +async def configure_api(args: ConfigureArgs) -> None: """Configure Benchwise API settings.""" - from .config import reset_config - if args.reset: reset_config() print("✓ Configuration reset to defaults") @@ -300,7 +329,7 @@ async def configure_api(args): return # Update configuration - kwargs = {} + kwargs: ConfigKwargs = {} if args.api_url: kwargs["api_url"] = args.api_url if args.api_key: @@ -321,7 +350,7 @@ async def configure_api(args): print("No configuration changes specified. Use --show to see current config.") -async def sync_offline(args): +async def sync_offline(args: SyncArgs) -> None: """Sync offline results with the API.""" try: client = await get_client() @@ -354,7 +383,7 @@ async def sync_offline(args): pass -async def show_status(args): +async def show_status(args: StatusArgs) -> None: """Show Benchwise status information.""" config = get_api_config() client = None @@ -412,7 +441,7 @@ async def show_status(args): pass -def list_resources(resource_type: str): +def list_resources(resource_type: str) -> None: """List available resources.""" if resource_type == "models": print("Available model adapters:") @@ -440,7 +469,7 @@ def list_resources(resource_type: str): ) -def validate_dataset(dataset_path: str): +def validate_dataset(dataset_path: str) -> None: """Validate dataset format.""" try: dataset = load_dataset(dataset_path) @@ -478,10 +507,10 @@ def validate_dataset(dataset_path: str): sys.exit(1) -async def compare_results(result_paths: List[str], metric: Optional[str] = None): +async def compare_results( + result_paths: List[str], metric: Optional[str] = None +) -> None: """Compare evaluation results.""" - from .results import load_results, ResultsAnalyzer - try: # Load all results benchmark_results = [] @@ -509,7 +538,7 @@ async def compare_results(result_paths: List[str], metric: Optional[str] = None) sys.exit(1) -def main(): +def main() -> None: """Main CLI entry point.""" parser = create_parser() args = parser.parse_args() diff --git a/benchwise/client.py b/benchwise/client.py index fc0156a..29f89e1 100644 --- a/benchwise/client.py +++ b/benchwise/client.py @@ -2,24 +2,49 @@ import asyncio import uuid import logging -from typing import Dict, Any, Optional, List +import os +import types +from typing import Dict, Any, Optional, List, Type, cast from datetime import datetime from contextvars import ContextVar from .config import get_api_config from .results import EvaluationResult, BenchmarkResult +from .types import ( + OfflineQueueItem, + LoginResponse, + UserInfo, + ModelInfo, + BenchmarkInfo, + BenchmarkRegistrationData, + EvaluationInfo, + DatasetInfo, + EvaluationMetadata, + EvaluationResultDict, + UploadBenchmarkResponse, + FileUploadResponse, + TokenData, +) # Set up logger logger = logging.getLogger("benchwise.client") # Context-local client storage (thread-safe) -_client_context: ContextVar[Optional['BenchwiseClient']] = ContextVar('_client_context', default=None) +_client_context: ContextVar[Optional["BenchwiseClient"]] = ContextVar( + "_client_context", default=None +) class BenchwiseAPIError(Exception): """Enhanced exception with error codes and retry info.""" - def __init__(self, message: str, status_code: int = None, retry_after: int = None, request_id: str = None): + def __init__( + self, + message: str, + status_code: Optional[int] = None, + retry_after: Optional[int] = None, + request_id: Optional[str] = None, + ): super().__init__(message) self.status_code = status_code self.retry_after = retry_after @@ -56,21 +81,26 @@ def __init__(self, api_url: Optional[str] = None, api_key: Optional[str] = None) self.benchmark_cache: Dict[str, int] = {} # Offline queue for storing results when API is unavailable - self.offline_queue = [] + self.offline_queue: List[OfflineQueueItem] = [] self.offline_mode = False # Track if client is closed self._closed = False - + logger.debug(f"BenchwiseClient initialized with API URL: {self.api_url}") - async def __aenter__(self): + async def __aenter__(self) -> "BenchwiseClient": return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> None: await self.close() - async def close(self): + async def close(self) -> None: """Close the HTTP client.""" if not self._closed: await self.client.aclose() @@ -78,18 +108,18 @@ async def close(self): logger.debug("BenchwiseClient closed") async def _make_request_with_retry( - self, method: str, url: str, **kwargs + self, method: str, url: str, **kwargs: Any ) -> httpx.Response: """Make HTTP request with automatic retry logic and request ID tracking.""" max_retries = 3 base_delay = 1 - + # Generate and add request ID request_id = generate_request_id() - if 'headers' not in kwargs: - kwargs['headers'] = {} - kwargs['headers']['X-Request-ID'] = request_id - + if "headers" not in kwargs: + kwargs["headers"] = {} + kwargs["headers"]["X-Request-ID"] = request_id + logger.debug(f"Making {method} request to {url} [Request-ID: {request_id}]") for attempt in range(max_retries + 1): @@ -106,7 +136,9 @@ async def _make_request_with_retry( retry_after = int( response.headers.get("retry-after", base_delay * (2**attempt)) ) - logger.warning(f"Rate limited, retrying after {retry_after}s [Request-ID: {request_id}]") + logger.warning( + f"Rate limited, retrying after {retry_after}s [Request-ID: {request_id}]" + ) if attempt < max_retries: await asyncio.sleep(retry_after) continue @@ -121,15 +153,19 @@ async def _make_request_with_retry( except Exception: pass - logger.error(f"Request failed: {error_detail} [Request-ID: {request_id}]") + logger.error( + f"Request failed: {error_detail} [Request-ID: {request_id}]" + ) raise BenchwiseAPIError( - f"{error_detail}", + f"{error_detail}", status_code=response.status_code, - request_id=request_id + request_id=request_id, ) except httpx.RequestError as e: - logger.warning(f"Network error (attempt {attempt + 1}/{max_retries + 1}): {e} [Request-ID: {request_id}]") + logger.warning( + f"Network error (attempt {attempt + 1}/{max_retries + 1}): {e} [Request-ID: {request_id}]" + ) if attempt < max_retries: delay = base_delay * (2**attempt) await asyncio.sleep(delay) @@ -138,7 +174,7 @@ async def _make_request_with_retry( raise BenchwiseAPIError("Max retries exceeded", request_id=request_id) - def _set_auth_header(self): + def _set_auth_header(self) -> None: """Set JWT authorization header if token is available.""" if self.jwt_token: self.client.headers["Authorization"] = f"Bearer {self.jwt_token}" @@ -151,14 +187,14 @@ async def health_check(self) -> bool: """Check if the Benchwise API is available.""" try: response = await self.client.get("/health", timeout=5.0) - is_healthy = response.status_code == 200 + is_healthy = bool(response.status_code == 200) logger.info(f"Health check: {'healthy' if is_healthy else 'unhealthy'}") return is_healthy except Exception as e: logger.warning(f"Health check failed: {e}") return False - async def login(self, username: str, password: str) -> Dict[str, Any]: + async def login(self, username: str, password: str) -> LoginResponse: """ Login with username/password to get JWT token. @@ -176,14 +212,18 @@ async def login(self, username: str, password: str) -> Dict[str, Any]: ) if response.status_code == 200: - token_data = response.json() + token_data = cast(TokenData, response.json()) self.jwt_token = token_data["access_token"] self._set_auth_header() # Get user info user_info = await self.get_current_user() logger.info(f"Login successful for user: {username}") - return {"token": token_data, "user": user_info} + login_response: LoginResponse = { + "token": token_data, + "user": user_info, + } + return login_response elif response.status_code == 401: logger.error("Login failed: Invalid credentials") raise BenchwiseAPIError("Invalid username or password") @@ -197,7 +237,7 @@ async def login(self, username: str, password: str) -> Dict[str, Any]: async def register( self, username: str, email: str, password: str, full_name: Optional[str] = None - ) -> Dict[str, Any]: + ) -> UserInfo: """ Register a new user account. @@ -225,7 +265,7 @@ async def register( if response.status_code == 201: logger.info(f"Registration successful for user: {username}") - return response.json() + return cast(UserInfo, response.json()) elif response.status_code == 400: error_detail = response.json().get("detail", "Registration failed") logger.error(f"Registration failed: {error_detail}") @@ -238,7 +278,7 @@ async def register( logger.error(f"Network error during registration: {e}") raise BenchwiseAPIError(f"Network error during registration: {e}") - async def get_current_user(self) -> Dict[str, Any]: + async def get_current_user(self) -> UserInfo: """ Get current authenticated user information. @@ -252,7 +292,7 @@ async def get_current_user(self) -> Dict[str, Any]: response = await self.client.get("/api/v1/users/me") if response.status_code == 200: - return response.json() + return cast(UserInfo, response.json()) elif response.status_code == 401: logger.warning("Authentication expired") raise BenchwiseAPIError("Authentication expired - please login again") @@ -267,20 +307,22 @@ async def get_current_user(self) -> Dict[str, Any]: # WIP: Simplified upload workflow (to be completed in future release) async def upload_benchmark_result_simple( self, benchmark_result: BenchmarkResult - ) -> Dict[str, Any]: + ) -> UploadBenchmarkResponse: """ WIP: Simplified single-call upload for benchmark results. - + This will be the primary upload method in the next release. Currently redirects to the existing multi-step workflow. - + Args: benchmark_result: BenchmarkResult object to upload Returns: API response data """ - logger.warning("Using legacy multi-step upload workflow. Simplified workflow coming in next release.") + logger.warning( + "Using legacy multi-step upload workflow. Simplified workflow coming in next release." + ) return await self.upload_benchmark_result(benchmark_result) async def register_model( @@ -322,7 +364,7 @@ async def register_model( response = await self.client.post("/api/v1/models", json=model_data) if response.status_code == 201: - model_info = response.json() + model_info = cast(ModelInfo, response.json()) model_db_id = model_info["id"] self.model_cache[cache_key] = model_db_id logger.info(f"Model registered successfully with ID: {model_db_id}") @@ -349,14 +391,18 @@ async def _get_existing_model(self, provider: str, model_id: str) -> int: ) if response.status_code == 200: - models = response.json() + models = cast(List[ModelInfo], response.json()) # Filter in Python since backend doesn't support model_id parameter for model in models: - if model["provider"] == provider and model["model_id"] == model_id: + if ( + model.get("provider") == provider + and model.get("model_id") == model_id + ): cache_key = f"{provider}:{model_id}" - self.model_cache[cache_key] = model["id"] - logger.debug(f"Found existing model with ID: {model['id']}") - return model["id"] + model_id_value: int = model["id"] + self.model_cache[cache_key] = model_id_value + logger.debug(f"Found existing model with ID: {model_id_value}") + return model_id_value raise BenchwiseAPIError(f"Model {provider}:{model_id} not found") else: @@ -368,7 +414,7 @@ async def _get_existing_model(self, provider: str, model_id: str) -> int: raise BenchwiseAPIError(f"Network error searching models: {e}") async def register_benchmark( - self, benchmark_name: str, description: str, dataset_info: Dict[str, Any] + self, benchmark_name: str, description: str, dataset_info: DatasetInfo ) -> int: """ Register a benchmark and return its database ID. @@ -395,11 +441,11 @@ async def register_benchmark( logger.info(f"Registering benchmark: {benchmark_name}") try: - benchmark_data = { + benchmark_data: BenchmarkRegistrationData = { "name": benchmark_name, "description": description, - "category": dataset_info.get("task", "general"), - "tags": dataset_info.get("tags", []), + "category": dataset_info.get("task", "general") or "general", + "tags": dataset_info.get("tags", []) or [], "difficulty": dataset_info.get("difficulty"), "dataset_url": dataset_info.get("source"), "config": {}, @@ -410,10 +456,12 @@ async def register_benchmark( response = await self.client.post("/api/v1/benchmarks", json=benchmark_data) if response.status_code == 201: - benchmark_info = response.json() + benchmark_info = cast(BenchmarkInfo, response.json()) benchmark_db_id = benchmark_info["id"] self.benchmark_cache[benchmark_name] = benchmark_db_id - logger.info(f"Benchmark registered successfully with ID: {benchmark_db_id}") + logger.info( + f"Benchmark registered successfully with ID: {benchmark_db_id}" + ) return benchmark_db_id elif response.status_code == 400: # Benchmark might already exist - try to get it @@ -437,20 +485,27 @@ async def _get_existing_benchmark(self, benchmark_name: str) -> int: ) if response.status_code == 200: - benchmarks = response.json() + benchmarks = cast(List[BenchmarkInfo], response.json()) # Look for exact name match first, then partial match for benchmark in benchmarks: - if benchmark["name"] == benchmark_name: - self.benchmark_cache[benchmark_name] = benchmark["id"] - logger.debug(f"Found existing benchmark with ID: {benchmark['id']}") - return benchmark["id"] + if benchmark.get("name") == benchmark_name: + benchmark_id_value: int = benchmark["id"] + self.benchmark_cache[benchmark_name] = benchmark_id_value + logger.debug( + f"Found existing benchmark with ID: {benchmark_id_value}" + ) + return benchmark_id_value # If no exact match, try partial match for benchmark in benchmarks: - if benchmark_name.lower() in benchmark["name"].lower(): - self.benchmark_cache[benchmark_name] = benchmark["id"] - logger.debug(f"Found similar benchmark with ID: {benchmark['id']}") - return benchmark["id"] + benchmark_name_val = benchmark.get("name", "") + if benchmark_name.lower() in benchmark_name_val.lower(): + benchmark_id_value = benchmark["id"] + self.benchmark_cache[benchmark_name] = benchmark_id_value + logger.debug( + f"Found similar benchmark with ID: {benchmark_id_value}" + ) + return benchmark_id_value raise BenchwiseAPIError(f"Benchmark {benchmark_name} not found") else: @@ -466,7 +521,7 @@ async def create_evaluation( name: str, benchmark_id: int, model_ids: List[int], - metadata: Optional[Dict] = None, + metadata: Optional[EvaluationMetadata] = None, ) -> int: """ Create evaluation with correct backend format. @@ -495,9 +550,10 @@ async def create_evaluation( ) if response.status_code == 201: - evaluation_info = response.json() - logger.info(f"Evaluation created successfully with ID: {evaluation_info['id']}") - return evaluation_info["id"] + evaluation_info = cast(EvaluationInfo, response.json()) + evaluation_id = evaluation_info["id"] + logger.info(f"Evaluation created successfully with ID: {evaluation_id}") + return evaluation_id elif response.status_code == 401: raise BenchwiseAPIError( "Authentication required for creating evaluations" @@ -522,7 +578,7 @@ async def create_evaluation( raise e async def upload_evaluation_results( - self, evaluation_id: int, results: List[Dict[str, Any]] + self, evaluation_id: int, results: List[EvaluationResultDict] ) -> bool: """ Upload results to an existing evaluation using the correct endpoint. @@ -570,7 +626,7 @@ async def upload_evaluation_results( async def upload_benchmark_result( self, benchmark_result: BenchmarkResult - ) -> Dict[str, Any]: + ) -> UploadBenchmarkResponse: """ Upload a complete benchmark result using correct workflow. @@ -587,12 +643,24 @@ async def upload_benchmark_result( try: # Step 1: Register benchmark if needed benchmark_name = benchmark_result.benchmark_name + description_value: Any = benchmark_result.metadata.get( + "description", f"Benchmark: {benchmark_name}" + ) + description_str: str = ( + description_value + if isinstance(description_value, str) + else f"Benchmark: {benchmark_name}" + ) + dataset_info_value: Any = benchmark_result.metadata.get("dataset", {}) + dataset_info_typed: DatasetInfo = ( + cast(DatasetInfo, dataset_info_value) + if isinstance(dataset_info_value, dict) + else cast(DatasetInfo, {}) + ) benchmark_id = await self.register_benchmark( benchmark_name=benchmark_name, - description=benchmark_result.metadata.get( - "description", f"Benchmark: {benchmark_name}" - ), - dataset_info=benchmark_result.metadata.get("dataset", {}), + description=description_str, + dataset_info=dataset_info_typed, ) # Step 2: Register models and collect their IDs @@ -628,34 +696,33 @@ async def upload_benchmark_result( ) # Step 4: Prepare and upload results - results_data = [] + results_data: List[EvaluationResultDict] = [] for result in benchmark_result.results: if result.success and result.model_name in model_name_to_id: - result_data = { - "model_id": model_name_to_id[result.model_name], - "metrics": result.result - if isinstance(result.result, dict) - else {"score": result.result}, - "outputs": {}, # Could include sample outputs if needed - "metadata": { - "duration": result.duration, - "timestamp": result.timestamp.isoformat(), - **result.metadata, - }, - } - results_data.append(result_data) + # result.to_dict() already returns EvaluationResultDict + results_data.append(result.to_dict()) # Step 5: Upload results await self.upload_evaluation_results(evaluation_id, results_data) - logger.info(f"Benchmark result uploaded successfully. Evaluation ID: {evaluation_id}") - return { + logger.info( + f"Benchmark result uploaded successfully. Evaluation ID: {evaluation_id}" + ) + # Build response with explicit types matching UploadBenchmarkResponse + # All values are properly typed: + # - evaluation_id: int (from create_evaluation) + # - benchmark_id: int (from register_benchmark) + # - model_ids: List[int] (from register_model) + # - len(results_data): int + # - message: str + response: UploadBenchmarkResponse = { "id": evaluation_id, "benchmark_id": benchmark_id, "model_ids": model_ids, "results_count": len(results_data), "message": "Evaluation uploaded successfully", } + return response except Exception as e: # Add to offline queue for later sync @@ -701,7 +768,7 @@ def _get_model_provider(self, model_name: str) -> str: async def get_benchmarks( self, limit: int = 50, skip: int = 0 - ) -> List[Dict[str, Any]]: + ) -> List[BenchmarkInfo]: """Get available benchmarks from the API.""" try: response = await self.client.get( @@ -709,7 +776,7 @@ async def get_benchmarks( ) if response.status_code == 200: - return response.json() + return cast(List[BenchmarkInfo], response.json()) else: raise BenchwiseAPIError( f"Failed to retrieve benchmarks: {response.status_code}" @@ -720,7 +787,7 @@ async def get_benchmarks( async def get_evaluations( self, limit: int = 50, skip: int = 0 - ) -> List[Dict[str, Any]]: + ) -> List[EvaluationInfo]: """Get evaluations from the API.""" try: response = await self.client.get( @@ -728,7 +795,7 @@ async def get_evaluations( ) if response.status_code == 200: - return response.json() + return cast(List[EvaluationInfo], response.json()) else: raise BenchwiseAPIError( f"Failed to retrieve evaluations: {response.status_code}" @@ -737,11 +804,13 @@ async def get_evaluations( except httpx.RequestError as e: raise BenchwiseAPIError(f"Network error retrieving evaluations: {e}") - async def _add_to_offline_queue(self, data: Dict[str, Any]): + async def _add_to_offline_queue(self, data: Dict[str, Any]) -> None: """Add data to offline queue for later upload.""" - self.offline_queue.append( - {"data": data, "timestamp": datetime.now().isoformat()} - ) + queue_item: OfflineQueueItem = { + "data": data, + "timestamp": datetime.now().isoformat(), + } + self.offline_queue.append(queue_item) self.offline_mode = True logger.info(f"Added to offline queue (size: {len(self.offline_queue)})") @@ -757,21 +826,31 @@ async def sync_offline_queue(self) -> int: for item in self.offline_queue: try: - data = item["data"] - data_type = data.get("type") + queue_data: Dict[str, Any] = item["data"] + data_type: Any = queue_data.get("type") if data_type == "full_benchmark_result": # Reconstruct BenchmarkResult and upload - from .results import BenchmarkResult - - benchmark_result = BenchmarkResult(**data["benchmark_result"]) + benchmark_result_dict: Dict[str, Any] = queue_data.get( + "benchmark_result", {} + ) + benchmark_result = BenchmarkResult(**benchmark_result_dict) await self.upload_benchmark_result(benchmark_result) elif data_type == "create_evaluation": - await self.create_evaluation(**data["data"]) + evaluation_data: Dict[str, Any] = queue_data.get("data", {}) + await self.create_evaluation(**evaluation_data) elif data_type == "upload_results": - await self.upload_evaluation_results( - data["evaluation_id"], data["results"] - ) + evaluation_id_value: Any = queue_data.get("evaluation_id") + results_value: Any = queue_data.get("results") + if isinstance(evaluation_id_value, int) and isinstance( + results_value, list + ): + results_list: List[EvaluationResultDict] = cast( + List[EvaluationResultDict], results_value + ) + await self.upload_evaluation_results( + evaluation_id_value, results_list + ) synced_count += 1 logger.info(f"Synced item from {item['timestamp']}") @@ -806,8 +885,6 @@ async def upload_dataset_for_benchmark( Returns: Dataset URL """ - import os - logger.info(f"Uploading dataset for benchmark {benchmark_id}") try: with open(dataset_path, "rb") as f: @@ -818,9 +895,10 @@ async def upload_dataset_for_benchmark( ) if response.status_code == 200: - result = response.json() + result = cast(FileUploadResponse, response.json()) + file_url = result["file_info"]["url"] logger.info("Dataset uploaded successfully") - return result["file_info"]["url"] + return file_url else: raise BenchwiseAPIError( f"Failed to upload dataset: {response.status_code}" @@ -859,7 +937,7 @@ async def create_benchmark_with_dataset( f"Failed to create benchmark: {response.status_code}" ) - benchmark = response.json() + benchmark = cast(BenchmarkInfo, response.json()) benchmark_id = benchmark["id"] # 2. Upload dataset @@ -885,24 +963,24 @@ async def create_benchmark_with_dataset( async def get_client() -> BenchwiseClient: """ Get or create a context-local Benchwise API client. - + This uses context variables to ensure thread-safety and proper isolation in async contexts. """ client = _client_context.get() - + if client is None or client._closed: client = BenchwiseClient() _client_context.set(client) logger.debug("Created new context-local client") - + return client -async def close_client(): +async def close_client() -> None: """Close the context-local client.""" client = _client_context.get() - + if client and not client._closed: try: await client.close() @@ -912,7 +990,7 @@ async def close_client(): async def upload_results( - results: List[EvaluationResult], test_name: str, dataset_info: Dict[str, Any] + results: List[EvaluationResult], test_name: str, dataset_info: DatasetInfo ) -> bool: """ Convenience function to upload evaluation results. @@ -930,13 +1008,17 @@ async def upload_results( # Check if API is available if not await client.health_check(): - logger.warning("Benchwise API not available, results will be cached offline") - from .results import BenchmarkResult + logger.warning( + "Benchwise API not available, results will be cached offline" + ) benchmark_result = BenchmarkResult( benchmark_name=test_name, results=results, - metadata={"dataset": dataset_info}, + metadata=cast( + EvaluationMetadata, + {"dataset": dataset_info}, + ), ) await client._add_to_offline_queue( { @@ -949,12 +1031,14 @@ async def upload_results( # Check authentication if not client.jwt_token: logger.warning("Not authenticated - results will be cached offline") - from .results import BenchmarkResult benchmark_result = BenchmarkResult( benchmark_name=test_name, results=results, - metadata={"dataset": dataset_info}, + metadata=cast( + EvaluationMetadata, + {"dataset": dataset_info}, + ), ) await client._add_to_offline_queue( { @@ -965,12 +1049,13 @@ async def upload_results( return False # Create benchmark result and upload - from .results import BenchmarkResult - benchmark_result = BenchmarkResult( benchmark_name=test_name, results=results, - metadata={"dataset": dataset_info}, + metadata=cast( + EvaluationMetadata, + {"dataset": dataset_info}, + ), ) response = await client.upload_benchmark_result(benchmark_result) diff --git a/benchwise/config.py b/benchwise/config.py index 62d5f6a..f08cc12 100644 --- a/benchwise/config.py +++ b/benchwise/config.py @@ -7,9 +7,13 @@ import os from pathlib import Path -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from dataclasses import dataclass, field import json +import asyncio +import httpx + +from benchwise.types import ConfigDict @dataclass @@ -52,16 +56,16 @@ class BenchwiseConfig: verbose: bool = False # User preferences - default_models: list = field(default_factory=list) - default_metrics: list = field(default_factory=list) + default_models: List[str] = field(default_factory=list) + default_metrics: List[str] = field(default_factory=list) - def __post_init__(self): + def __post_init__(self) -> None: """Load configuration from environment variables and config file.""" self._load_from_env() self._load_from_file() self._validate_config() - def _load_from_env(self): + def _load_from_env(self) -> None: """Load configuration from environment variables.""" # API Configuration @@ -113,7 +117,7 @@ def _load_from_env(self): if verbose_env in ("true", "1", "yes", "on"): self.verbose = True - def _load_from_file(self): + def _load_from_file(self) -> None: """Load configuration from config file.""" config_paths = [ Path.cwd() / ".benchwise.json", @@ -140,7 +144,7 @@ def _load_from_file(self): if self.verbose: print(f"⚠️ Failed to load config from {config_path}: {e}") - def _validate_config(self): + def _validate_config(self) -> None: """Validate configuration values.""" # Validate API URL @@ -169,7 +173,7 @@ def _validate_config(self): ) self.cache_enabled = False - def save_to_file(self, file_path: Optional[Path] = None): + def save_to_file(self, file_path: Optional[Path] = None) -> None: """ Save current configuration to file. @@ -198,9 +202,9 @@ def save_to_file(self, file_path: Optional[Path] = None): # Don't save sensitive information like API key if self.api_key and not os.getenv("BENCHWISE_SAVE_API_KEY"): - config_dict[ - "_note" - ] = "API key not saved for security. Set BENCHWISE_API_KEY environment variable." + config_dict["_note"] = ( + "API key not saved for security. Set BENCHWISE_API_KEY environment variable." + ) try: with open(file_path, "w") as f: @@ -212,7 +216,7 @@ def save_to_file(self, file_path: Optional[Path] = None): except OSError as e: print(f"Failed to save configuration: {e}") - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ConfigDict: """Convert configuration to dictionary.""" return { "api_url": self.api_url, @@ -230,7 +234,7 @@ def to_dict(self) -> Dict[str, Any]: "default_metrics": self.default_metrics, } - def print_config(self): + def print_config(self) -> None: """Print current configuration in a readable format.""" print("🔧 Benchwise Configuration:") print("=" * 30) @@ -258,7 +262,7 @@ def get_api_config() -> BenchwiseConfig: return _global_config -def set_api_config(config: BenchwiseConfig): +def set_api_config(config: BenchwiseConfig) -> None: """ Set the global Benchwise configuration. @@ -275,7 +279,7 @@ def configure_benchwise( upload_enabled: Optional[bool] = None, cache_enabled: Optional[bool] = None, debug: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> BenchwiseConfig: """ Configure Benchwise settings programmatically. @@ -315,7 +319,7 @@ def configure_benchwise( return config -def reset_config(): +def reset_config() -> None: """Reset configuration to default values.""" global _global_config _global_config = None @@ -406,13 +410,11 @@ def validate_api_connection(config: BenchwiseConfig) -> bool: True if connection is valid """ try: - import asyncio - import httpx - async def check_connection(): + async def check_connection() -> bool: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"{config.api_url}/health") - return response.status_code == 200 + return bool(response.status_code == 200) return asyncio.run(check_connection()) @@ -432,16 +434,14 @@ def validate_api_keys(config: BenchwiseConfig) -> Dict[str, bool]: Returns: Dict mapping provider to validity status """ - import os - results = {} if os.getenv("OPENAI_API_KEY"): try: import openai - client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) - client.models.list() + openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + openai_client.models.list() results["openai"] = True except Exception: results["openai"] = False @@ -450,7 +450,8 @@ def validate_api_keys(config: BenchwiseConfig) -> Dict[str, bool]: try: import anthropic - client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + # Create client to verify API key is valid + _ = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Note: Anthropic doesn't have a simple test endpoint results["anthropic"] = True # Assume valid if key exists except Exception: @@ -482,7 +483,7 @@ def validate_api_keys(config: BenchwiseConfig) -> Dict[str, bool]: return results -def print_configuration_status(config: BenchwiseConfig): +def print_configuration_status(config: BenchwiseConfig) -> None: """ NEW: Print comprehensive configuration status. diff --git a/benchwise/core.py b/benchwise/core.py index 8d64c15..df9c724 100644 --- a/benchwise/core.py +++ b/benchwise/core.py @@ -1,19 +1,47 @@ -from typing import List, Dict, Any, Callable, Optional +from typing import ( + List, + Dict, + Any, + Callable, + Optional, + Union, + ParamSpec, + TypeVar, + Awaitable, + cast, +) from functools import wraps import asyncio import time import inspect import logging from .models import get_model_adapter -from .datasets import Dataset +from .datasets import Dataset, convert_metadata_to_info from .results import EvaluationResult from .config import get_api_config from .client import upload_results +from .types import ( + RunnerConfig, + ModelComparisonResult, + EvaluationResultDict, + EvaluationMetadata, + DatasetInfo, + CallableWithBenchmarkMetadata, +) + +# Type variables for decorator typing +P = ParamSpec("P") +R = TypeVar("R") logger = logging.getLogger("benchwise") -def evaluate(*models: str, upload: bool = None, **kwargs) -> Callable: +def evaluate( + *models: str, upload: Optional[bool] = None, **kwargs: Any +) -> Callable[ + [Callable[..., Awaitable[Any]]], + Callable[[Dataset], Awaitable[List[EvaluationResult]]], +]: """ Decorator for creating LLM evaluations. @@ -35,41 +63,55 @@ async def test_qa(model, dataset): return accuracy(responses, dataset.references) """ - def decorator(test_func: Callable) -> Callable: + def decorator( + test_func: Callable[..., Awaitable[Any]], + ) -> Callable[..., Awaitable[List[EvaluationResult]]]: if not inspect.iscoroutinefunction(test_func): raise TypeError( f"{test_func.__name__} must be an async function. " f"Use: async def {test_func.__name__}(model, dataset):" ) - + @wraps(test_func) - async def wrapper(dataset: Dataset, **test_kwargs) -> List[EvaluationResult]: - return await _run_evaluation(test_func, dataset, models, upload, kwargs, test_kwargs) - + async def wrapper( + dataset: Dataset, **test_kwargs: Any + ) -> List[EvaluationResult]: + return await _run_evaluation( + test_func, wrapper, dataset, models, upload, kwargs, test_kwargs + ) + + # Copy benchmark metadata if it exists if hasattr(test_func, "_benchmark_metadata"): - wrapper._benchmark_metadata = test_func._benchmark_metadata - + # Type narrowing: test_func has _benchmark_metadata after hasattr check + benchmark_func = cast(CallableWithBenchmarkMetadata, test_func) + # Type the wrapper as having the metadata attribute + wrapper_with_metadata = cast(CallableWithBenchmarkMetadata, wrapper) + wrapper_with_metadata._benchmark_metadata = ( + benchmark_func._benchmark_metadata + ) + return wrapper return decorator async def _run_evaluation( - test_func: Callable, + test_func: Callable[..., Awaitable[Any]], + wrapper_func: Callable[..., Awaitable[Any]], dataset: Dataset, - models: tuple, + models: tuple[str, ...], upload: Optional[bool], decorator_kwargs: Dict[str, Any], test_kwargs: Dict[str, Any], ) -> List[EvaluationResult]: results = [] - + logger.info(f"Starting evaluation: {test_func.__name__} on {len(models)} model(s)") for model_name in models: try: logger.debug(f"Evaluating model: {model_name}") - + model = get_model_adapter(model_name) start_time = time.time() @@ -77,35 +119,43 @@ async def _run_evaluation( end_time = time.time() combined_metadata = decorator_kwargs.copy() - if hasattr(test_func, "_benchmark_metadata"): - combined_metadata.update(test_func._benchmark_metadata) + if hasattr(wrapper_func, "_benchmark_metadata"): + # Type narrowing: wrapper_func has _benchmark_metadata after hasattr check + benchmark_func = cast(CallableWithBenchmarkMetadata, wrapper_func) + combined_metadata.update(benchmark_func._benchmark_metadata) eval_result = EvaluationResult( model_name=model_name, test_name=test_func.__name__, result=result, duration=end_time - start_time, - dataset_info=dataset.metadata, - metadata=combined_metadata, + dataset_info=convert_metadata_to_info(dataset.metadata) + if dataset.metadata + else None, + metadata=cast(EvaluationMetadata, combined_metadata), ) results.append(eval_result) - + logger.info(f"✓ {model_name} completed in {end_time - start_time:.2f}s") except Exception as e: logger.error(f"✗ {model_name} failed: {e}", exc_info=True) - + combined_metadata = decorator_kwargs.copy() - if hasattr(test_func, "_benchmark_metadata"): - combined_metadata.update(test_func._benchmark_metadata) + if hasattr(wrapper_func, "_benchmark_metadata"): + # Type narrowing: wrapper_func has _benchmark_metadata after hasattr check + benchmark_func = cast(CallableWithBenchmarkMetadata, wrapper_func) + combined_metadata.update(benchmark_func._benchmark_metadata) eval_result = EvaluationResult( model_name=model_name, test_name=test_func.__name__, error=str(e), duration=0, - dataset_info=dataset.metadata, - metadata=combined_metadata, + dataset_info=convert_metadata_to_info(dataset.metadata) + if dataset.metadata + else None, + metadata=cast(EvaluationMetadata, combined_metadata), ) results.append(eval_result) @@ -115,9 +165,14 @@ async def _run_evaluation( if should_upload and results: try: logger.debug("Uploading results to Benchwise API") - await upload_results( - results, test_func.__name__, dataset.metadata or {} + dataset_info_for_upload: DatasetInfo = ( + convert_metadata_to_info(dataset.metadata) + if dataset.metadata + else cast( + DatasetInfo, {"size": dataset.size, "task": "general", "tags": []} + ) ) + await upload_results(results, test_func.__name__, dataset_info_for_upload) logger.info("Results uploaded successfully") except Exception as e: logger.warning(f"Upload failed (results saved locally): {e}") @@ -127,7 +182,9 @@ async def _run_evaluation( return results -def benchmark(name: str, description: str = "", **kwargs) -> Callable: +def benchmark( + name: str, description: str = "", **kwargs: Any +) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator for creating benchmarks. @@ -137,8 +194,10 @@ async def medical_qa_test(model, dataset): pass """ - def decorator(test_func: Callable) -> Callable: - test_func._benchmark_metadata = { + def decorator(test_func: Callable[P, R]) -> Callable[P, R]: + # Add benchmark metadata to the function + benchmark_func = cast(CallableWithBenchmarkMetadata, test_func) + benchmark_func._benchmark_metadata = { "name": name, "description": description, **kwargs, @@ -148,10 +207,14 @@ def decorator(test_func: Callable) -> Callable: return decorator -def stress_test(concurrent_requests: int = 10, duration: int = 60) -> Callable: +def stress_test( + concurrent_requests: int = 10, duration: int = 60 +) -> Callable[ + [Callable[P, Awaitable[R]]], Callable[P, Awaitable[List[Union[R, BaseException]]]] +]: """ Decorator for stress testing LLMs. - + NOTE: WIP feature - may not be fully functional. Usage: @@ -160,12 +223,18 @@ async def load_test(model, dataset): pass """ - def decorator(test_func: Callable) -> Callable: + def decorator( + test_func: Callable[P, Awaitable[R]], + ) -> Callable[P, Awaitable[List[Union[R, BaseException]]]]: @wraps(test_func) - async def wrapper(*args, **kwargs): - logger.info(f"Starting stress test: {concurrent_requests} concurrent requests for {duration}s") - - tasks = [] + async def wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> List[Union[R, BaseException]]: + logger.info( + f"Starting stress test: {concurrent_requests} concurrent requests for {duration}s" + ) + + tasks: List[Union[R, BaseException]] = [] start_time = time.time() while time.time() - start_time < duration: @@ -191,17 +260,20 @@ async def wrapper(*args, **kwargs): class EvaluationRunner: """Main class for running evaluations.""" - def __init__(self, config: Optional[Dict[str, Any]] = None): - self.config = config or {} - self.results_cache = {} + def __init__(self, config: Optional[RunnerConfig] = None) -> None: + self.config: RunnerConfig = config or cast(RunnerConfig, {}) + self.results_cache: Dict[str, EvaluationResultDict] = {} self.logger = logging.getLogger("benchwise.runner") async def run_evaluation( - self, test_func: Callable, dataset: Dataset, models: List[str] + self, + test_func: Callable[..., Awaitable[Any]], + dataset: Dataset, + models: List[str], ) -> List[EvaluationResult]: """Run evaluation on multiple models.""" - results = [] - + results: List[EvaluationResult] = [] + self.logger.info(f"Running evaluation on {len(models)} models") for model_name in models: @@ -215,14 +287,16 @@ async def run_evaluation( return results def compare_models( - self, results: List[EvaluationResult], metric_name: str = None - ) -> Dict[str, Any]: + self, results: List[EvaluationResult], metric_name: Optional[str] = None + ) -> ModelComparisonResult: """Compare model performance.""" successful_results = [r for r in results if r.success] if not successful_results: self.logger.warning("No successful results to compare") - return {"error": "No successful results to compare"} + return cast( + ModelComparisonResult, {"error": "No successful results to compare"} + ) model_scores = [] for r in successful_results: @@ -247,37 +321,50 @@ def compare_models( model_scores.append((r.model_name, score if score is not None else 0)) if not model_scores: - return {"error": "No comparable scores found"} + return cast(ModelComparisonResult, {"error": "No comparable scores found"}) model_scores.sort(key=lambda x: x[1], reverse=True) - comparison = { - "models": [r.model_name for r in successful_results], - "scores": [score for _, score in model_scores], - "best_model": model_scores[0][0], - "worst_model": model_scores[-1][0], - "ranking": [ - {"model": name, "score": score} for name, score in model_scores - ], - } - - self.logger.info(f"Comparison complete: Best model is {comparison['best_model']}") - + comparison = cast( + ModelComparisonResult, + { + "ranking": [ + {"model": name, "score": float(score)} + for name, score in model_scores + ], + "best_model": model_scores[0][0], + "best_score": float(model_scores[0][1]), + "worst_model": model_scores[-1][0], + "worst_score": float(model_scores[-1][1]), + "mean_score": float( + sum(score for _, score in model_scores) / len(model_scores) + ), + "std_score": 0.0, # Could calculate if needed + "total_models": len(model_scores), + }, + ) + + self.logger.info( + f"Comparison complete: Best model is {comparison['best_model']}" + ) + return comparison def run_benchmark( - benchmark_func: Callable, dataset: Dataset, models: List[str] + benchmark_func: Callable[..., Awaitable[Any]], dataset: Dataset, models: List[str] ) -> List[EvaluationResult]: """Run a benchmark on multiple models.""" runner = EvaluationRunner() return asyncio.run(runner.run_evaluation(benchmark_func, dataset, models)) -async def quick_eval(prompt: str, models: List[str], metric: Callable) -> Dict[str, float]: +async def quick_eval( + prompt: str, models: List[str], metric: Callable[[str], float] +) -> Dict[str, Optional[float]]: """Quick evaluation with a single prompt.""" - results = {} - + results: Dict[str, Optional[float]] = {} + logger.info(f"Running quick eval on {len(models)} models") for model_name in models: diff --git a/benchwise/datasets.py b/benchwise/datasets.py index 2d1c416..2a9587b 100644 --- a/benchwise/datasets.py +++ b/benchwise/datasets.py @@ -1,10 +1,106 @@ -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Any, Optional, Union, Callable, cast import json import pandas as pd from pathlib import Path import requests from dataclasses import dataclass import hashlib +import random + +from .types import ( + DatasetItem, + DatasetMetadata, + DatasetSchema, + DatasetDict, + DatasetInfo, +) + + +def _validate_dataset_item(item: Any) -> DatasetItem: + """ + Validate and convert a dictionary to DatasetItem. + + Args: + item: Dictionary or any value to validate + + Returns: + Validated DatasetItem + + Raises: + ValueError: If item is not a dictionary + """ + if not isinstance(item, dict): + raise ValueError(f"Expected dict for DatasetItem, got {type(item).__name__}") + return cast(DatasetItem, item) + + +def _validate_dataset_items(items: Any) -> List[DatasetItem]: + """ + Validate and convert a list of dictionaries to List[DatasetItem]. + + Args: + items: List of dictionaries or any value to validate + + Returns: + Validated List[DatasetItem] + + Raises: + ValueError: If items is not a list or contains non-dict items + """ + if not isinstance(items, list): + raise ValueError(f"Expected list for dataset data, got {type(items).__name__}") + + validated_items: List[DatasetItem] = [] + for i, item in enumerate(items): + if not isinstance(item, dict): + raise ValueError( + f"Expected dict for dataset item at index {i}, got {type(item).__name__}" + ) + validated_items.append(cast(DatasetItem, item)) + + return validated_items + + +def _validate_dataset_metadata(metadata: Any) -> Optional[DatasetMetadata]: + """ + Validate and convert metadata to DatasetMetadata. + + Args: + metadata: Dictionary or None to validate + + Returns: + Validated DatasetMetadata or None + """ + if metadata is None: + return None + + if not isinstance(metadata, dict): + raise ValueError( + f"Expected dict or None for DatasetMetadata, got {type(metadata).__name__}" + ) + + return cast(DatasetMetadata, metadata) + + +def _validate_dataset_schema(schema: Any) -> Optional[DatasetSchema]: + """ + Validate and convert schema to DatasetSchema. + + Args: + schema: Dictionary or None to validate + + Returns: + Validated DatasetSchema or None + """ + if schema is None: + return None + + if not isinstance(schema, dict): + raise ValueError( + f"Expected dict or None for DatasetSchema, got {type(schema).__name__}" + ) + + return cast(DatasetSchema, schema) @dataclass @@ -20,20 +116,23 @@ class Dataset: """ name: str - data: List[Dict[str, Any]] - metadata: Optional[Dict[str, Any]] = None - schema: Optional[Dict[str, Any]] = None + data: List[DatasetItem] + metadata: Optional[DatasetMetadata] = None + schema: Optional[DatasetSchema] = None - def __post_init__(self): + def __post_init__(self) -> None: if self.metadata is None: - self.metadata = {} + self.metadata = cast(DatasetMetadata, {}) if not self.metadata: - self.metadata = { - "size": len(self.data), - "created_at": pd.Timestamp.now().isoformat(), - "hash": self._compute_hash(), - } + self.metadata = cast( + DatasetMetadata, + { + "size": len(self.data), + "created_at": pd.Timestamp.now().isoformat(), + "hash": self._compute_hash(), + }, + ) def _compute_hash(self) -> str: """Compute hash of dataset for versioning.""" @@ -72,73 +171,95 @@ def references(self) -> List[str]: or item.get("answer") or item.get("target") or item.get("summary") - or item.get("label") + or item.get("label") ) if ref: references.append(str(ref)) return references - def filter(self, condition: callable) -> "Dataset": + def filter(self, condition: Callable[[DatasetItem], bool]) -> "Dataset": """Filter dataset items based on condition.""" filtered_data = [item for item in self.data if condition(item)] + metadata = self.metadata or cast(DatasetMetadata, {}) return Dataset( name=f"{self.name}_filtered", data=filtered_data, - metadata={**self.metadata, "filtered": True, "original_size": self.size}, + metadata=cast( + DatasetMetadata, + {**metadata, "filtered": True, "original_size": self.size}, + ), ) def sample(self, n: int, random_state: Optional[int] = None) -> "Dataset": """Sample n items from dataset.""" - import random - if random_state: random.seed(random_state) - sampled_data = random.sample(self.data, min(n, len(self.data))) + sampled_data: List[DatasetItem] = random.sample( + self.data, min(n, len(self.data)) + ) + metadata = self.metadata or cast(DatasetMetadata, {}) return Dataset( name=f"{self.name}_sample_{n}", data=sampled_data, - metadata={**self.metadata, "sampled": True, "sample_size": n}, + metadata=cast( + DatasetMetadata, {**metadata, "sampled": True, "sample_size": n} + ), ) def split( self, train_ratio: float = 0.8, random_state: Optional[int] = None ) -> tuple["Dataset", "Dataset"]: """Split dataset into train and test sets.""" - import random - if random_state: random.seed(random_state) - shuffled_data = self.data.copy() + shuffled_data: List[DatasetItem] = self.data.copy() random.shuffle(shuffled_data) split_idx = int(len(shuffled_data) * train_ratio) - train_data = shuffled_data[:split_idx] - test_data = shuffled_data[split_idx:] + train_data: List[DatasetItem] = shuffled_data[:split_idx] + test_data: List[DatasetItem] = shuffled_data[split_idx:] train_dataset = Dataset( name=f"{self.name}_train", data=train_data, - metadata={**self.metadata, "split": "train", "train_ratio": train_ratio}, + metadata=cast( + DatasetMetadata, + { + **(self.metadata or cast(DatasetMetadata, {})), + "split": "train", + "train_ratio": train_ratio, + }, + ), ) test_dataset = Dataset( name=f"{self.name}_test", data=test_data, - metadata={**self.metadata, "split": "test", "train_ratio": train_ratio}, + metadata=cast( + DatasetMetadata, + { + **(self.metadata or cast(DatasetMetadata, {})), + "split": "test", + "test_ratio": 1 - train_ratio, + }, + ), ) return train_dataset, test_dataset - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> DatasetDict: """Convert dataset to dictionary format.""" - return { - "name": self.name, - "data": self.data, - "metadata": self.metadata, - "schema": self.schema, - } + return cast( + DatasetDict, + { + "name": self.name, + "data": self.data, + "metadata": self.metadata, + "schema": self.schema, + }, + ) def to_json(self, file_path: Optional[str] = None) -> str: """Export dataset to JSON format.""" @@ -150,7 +271,7 @@ def to_json(self, file_path: Optional[str] = None) -> str: return json_data - def to_csv(self, file_path: str): + def to_csv(self, file_path: str) -> None: """Export dataset to CSV format.""" df = pd.DataFrame(self.data) df.to_csv(file_path, index=False) @@ -160,7 +281,14 @@ def validate_schema(self) -> bool: if not self.schema: return True - required_fields = self.schema.get("required", []) + # Support both "required" and "required_fields" for backward compatibility + # Check if "required" key exists first, then fall back to "required_fields" + if "required" in self.schema: + required_fields = self.schema["required"] + elif "required_fields" in self.schema: + required_fields = self.schema["required_fields"] + else: + required_fields = [] for item in self.data: for field in required_fields: @@ -171,14 +299,15 @@ def validate_schema(self) -> bool: def get_statistics(self) -> Dict[str, Any]: """Get dataset statistics.""" - stats = { + fields: List[str] = list(self.data[0].keys()) if self.data else [] + stats: Dict[str, Any] = { "size": self.size, - "fields": list(self.data[0].keys()) if self.data else [], + "fields": fields, "metadata": self.metadata, } if self.data: - for field in stats["fields"]: + for field in fields: values = [item.get(field) for item in self.data if field in item] if values: if all(isinstance(v, str) for v in values): @@ -186,14 +315,20 @@ def get_statistics(self) -> Dict[str, Any]: len(str(v)) for v in values ) / len(values) elif all(isinstance(v, (int, float)) for v in values): - stats[f"{field}_mean"] = sum(values) / len(values) - stats[f"{field}_min"] = min(values) - stats[f"{field}_max"] = max(values) + # Type narrowing: we know values are numeric here + numeric_values = [ + v for v in values if isinstance(v, (int, float)) + ] + stats[f"{field}_mean"] = sum(numeric_values) / len( + numeric_values + ) + stats[f"{field}_min"] = min(numeric_values) + stats[f"{field}_max"] = max(numeric_values) return stats -def load_dataset(source: Union[str, Path, Dict[str, Any]], **kwargs) -> Dataset: +def load_dataset(source: Union[str, Path, DatasetDict], **kwargs: Any) -> Dataset: """ Load dataset from various sources. @@ -206,11 +341,26 @@ def load_dataset(source: Union[str, Path, Dict[str, Any]], **kwargs) -> Dataset: """ if isinstance(source, dict): + # Type narrowing: after isinstance check, treat as DatasetDict + # Note: .get() on TypedDict with total=False returns Any for optional keys, + # but we know the structure from DatasetDict, so we use proper type annotations + dataset_dict: DatasetDict = source + # Prefer name from DatasetDict if present, otherwise fall back to kwargs + name_from_dict: Optional[str] = dataset_dict.get("name") + name: str = ( + name_from_dict + if isinstance(name_from_dict, str) + else kwargs.get("name", "custom_dataset") + ) + data: List[DatasetItem] = dataset_dict.get("data", []) + metadata: Optional[DatasetMetadata] = dataset_dict.get("metadata") + schema: Optional[DatasetSchema] = dataset_dict.get("schema") + return Dataset( - name=kwargs.get("name", "custom_dataset"), - data=source.get("data", []), - metadata=source.get("metadata", {}), - schema=source.get("schema"), + name=name, + data=_validate_dataset_items(data), + metadata=_validate_dataset_metadata(metadata), + schema=_validate_dataset_schema(schema), ) elif isinstance(source, (str, Path)): @@ -218,59 +368,93 @@ def load_dataset(source: Union[str, Path, Dict[str, Any]], **kwargs) -> Dataset: if source_path.suffix == ".json": with open(source_path, "r") as f: - data = json.load(f) + json_data = json.load(f) - if isinstance(data, dict) and "data" in data: + if isinstance(json_data, dict) and "data" in json_data: return Dataset( - name=data.get("name", source_path.stem), - data=data["data"], - metadata=data.get("metadata", {}), - schema=data.get("schema"), + name=json_data.get("name", source_path.stem) + if isinstance(json_data.get("name"), str) + else source_path.stem, + data=_validate_dataset_items(json_data["data"]), + metadata=_validate_dataset_metadata(json_data.get("metadata")), + schema=_validate_dataset_schema(json_data.get("schema")), ) - elif isinstance(data, list): + elif isinstance(json_data, list): return Dataset( - name=kwargs.get("name", source_path.stem), - data=data, - metadata=kwargs.get("metadata", {}), + name=kwargs.get("name", source_path.stem) + if isinstance(kwargs.get("name"), str) + else source_path.stem, + data=_validate_dataset_items(json_data), + metadata=_validate_dataset_metadata(kwargs.get("metadata", {})), + ) + else: + raise ValueError( + f"Invalid JSON format in '{source_path}'. Expected a list or a dict with 'data' key." ) elif source_path.suffix == ".csv": df = pd.read_csv(source_path) - data = df.to_dict("records") + # Type cast: pandas to_dict returns dict[Hashable, Any] but we need dict[str, Any] + records: List[Dict[str, Any]] = [ + cast(Dict[str, Any], dict(record)) for record in df.to_dict("records") + ] + csv_data: List[DatasetItem] = [ + cast(DatasetItem, record) for record in records + ] return Dataset( - name=kwargs.get("name", source_path.stem), - data=data, - metadata=kwargs.get("metadata", {}), + name=kwargs.get("name", source_path.stem) + if isinstance(kwargs.get("name"), str) + else source_path.stem, + data=csv_data, + metadata=_validate_dataset_metadata(kwargs.get("metadata")), ) elif str(source).startswith(("http://", "https://")): - response = requests.get(source) + # Convert to str for requests.get + source_str = str(source) + response = requests.get(source_str) response.raise_for_status() - if source.endswith(".json"): - data = response.json() - if isinstance(data, dict) and "data" in data: + if source_str.endswith(".json"): + json_data = response.json() + if isinstance(json_data, dict) and "data" in json_data: return Dataset( - name=data.get("name", "remote_dataset"), - data=data["data"], - metadata=data.get("metadata", {}), - schema=data.get("schema"), + name=json_data.get("name", "remote_dataset") + if isinstance(json_data.get("name"), str) + else "remote_dataset", + data=_validate_dataset_items(json_data["data"]), + metadata=_validate_dataset_metadata(json_data.get("metadata")), + schema=_validate_dataset_schema(json_data.get("schema")), ) - elif isinstance(data, list): + elif isinstance(json_data, list): return Dataset( - name=kwargs.get("name", "remote_dataset"), - data=data, - metadata=kwargs.get("metadata", {}), + name=kwargs.get("name", "remote_dataset") + if isinstance(kwargs.get("name"), str) + else "remote_dataset", + data=_validate_dataset_items(json_data), + metadata=_validate_dataset_metadata(kwargs.get("metadata", {})), + ) + else: + raise ValueError( + f"Invalid JSON format from '{source_str}'. Expected a list or a dict with 'data' key." ) + else: + raise ValueError( + f"Unsupported URL format '{source_str}'. Only .json URLs are supported." + ) else: raise ValueError( f"Unsupported file format '{source_path.suffix}'. Supported formats: .json, .csv" ) + raise ValueError(f"Unable to load dataset from source: {source}") + -def create_qa_dataset(questions: List[str], answers: List[str], **kwargs) -> Dataset: +def create_qa_dataset( + questions: List[str], answers: List[str], **kwargs: Any +) -> Dataset: """ Create a question-answering dataset. @@ -286,28 +470,37 @@ def create_qa_dataset(questions: List[str], answers: List[str], **kwargs) -> Dat if len(questions) != len(answers): raise ValueError("Questions and answers must have the same length") - data = [{"question": q, "answer": a} for q, a in zip(questions, answers)] + data: List[DatasetItem] = [ + cast(DatasetItem, {"question": q, "answer": a}) + for q, a in zip(questions, answers) + ] return Dataset( - name=kwargs.get("name", "qa_dataset"), + name=kwargs.get("name", "qa_dataset") + if isinstance(kwargs.get("name"), str) + else "qa_dataset", data=data, - metadata={ - "task": "question_answering", - "size": len(data), - **kwargs.get("metadata", {}), - }, - schema={ - "required": ["question", "answer"], - "properties": { - "question": {"type": "string"}, - "answer": {"type": "string"}, + metadata=cast( + DatasetMetadata, + { + "task": "question_answering", + "size": len(data), + **kwargs.get("metadata", {}), + }, + ), + schema=cast( + DatasetSchema, + { + "required": ["question", "answer"], + "prompt_field": "question", + "reference_field": "answer", }, - }, + ), ) def create_summarization_dataset( - documents: List[str], summaries: List[str], **kwargs + documents: List[str], summaries: List[str], **kwargs: Any ) -> Dataset: """ Create a text summarization dataset. @@ -324,30 +517,37 @@ def create_summarization_dataset( if len(documents) != len(summaries): raise ValueError("Documents and summaries must have the same length") - data = [ - {"document": doc, "summary": summ} for doc, summ in zip(documents, summaries) + data: List[DatasetItem] = [ + cast(DatasetItem, {"document": doc, "summary": summ}) + for doc, summ in zip(documents, summaries) ] return Dataset( - name=kwargs.get("name", "summarization_dataset"), + name=kwargs.get("name", "summarization_dataset") + if isinstance(kwargs.get("name"), str) + else "summarization_dataset", data=data, - metadata={ - "task": "summarization", - "size": len(data), - **kwargs.get("metadata", {}), - }, - schema={ - "required": ["document", "summary"], - "properties": { - "document": {"type": "string"}, - "summary": {"type": "string"}, + metadata=cast( + DatasetMetadata, + { + "task": "summarization", + "size": len(data), + **kwargs.get("metadata", {}), + }, + ), + schema=cast( + DatasetSchema, + { + "required": ["document", "summary"], + "prompt_field": "document", + "reference_field": "summary", }, - }, + ), ) def create_classification_dataset( - texts: List[str], labels: List[str], **kwargs + texts: List[str], labels: List[str], **kwargs: Any ) -> Dataset: """ Create a text classification dataset. @@ -364,31 +564,43 @@ def create_classification_dataset( if len(texts) != len(labels): raise ValueError("Texts and labels must have the same length") - data = [{"text": text, "label": label} for text, label in zip(texts, labels)] + data: List[DatasetItem] = [ + cast(DatasetItem, {"text": text, "label": label}) + for text, label in zip(texts, labels) + ] return Dataset( - name=kwargs.get("name", "classification_dataset"), + name=kwargs.get("name", "classification_dataset") + if isinstance(kwargs.get("name"), str) + else "classification_dataset", data=data, - metadata={ - "task": "classification", - "size": len(data), - "unique_labels": list(set(labels)), - **kwargs.get("metadata", {}), - }, - schema={ - "required": ["text", "label"], - "properties": {"text": {"type": "string"}, "label": {"type": "string"}}, - }, + metadata=cast( + DatasetMetadata, + { + "task": "classification", + "size": len(data), + "unique_labels": list(set(labels)), + **kwargs.get("metadata", {}), + }, + ), + schema=cast( + DatasetSchema, + { + "required": ["text", "label"], + "prompt_field": "text", + "reference_field": "label", + }, + ), ) class DatasetRegistry: """Registry for managing multiple datasets.""" - def __init__(self): + def __init__(self) -> None: self.datasets: Dict[str, Dataset] = {} - def register(self, dataset: Dataset): + def register(self, dataset: Dataset) -> None: self.datasets[dataset.name] = dataset def get(self, name: str) -> Optional[Dataset]: @@ -397,11 +609,11 @@ def get(self, name: str) -> Optional[Dataset]: def list(self) -> List[str]: return list(self.datasets.keys()) - def remove(self, name: str): + def remove(self, name: str) -> None: if name in self.datasets: del self.datasets[name] - def clear(self): + def clear(self) -> None: self.datasets.clear() @@ -410,73 +622,144 @@ def clear(self): def load_mmlu_sample() -> Dataset: - sample_data = [ - { - "question": "What is the capital of France?", - "choices": ["London", "Berlin", "Paris", "Madrid"], - "answer": "Paris", - "subject": "geography", - }, - { - "question": "What is 2 + 2?", - "choices": ["3", "4", "5", "6"], - "answer": "4", - "subject": "mathematics", - }, + sample_data: List[DatasetItem] = [ + cast( + DatasetItem, + { + "question": "What is the capital of France?", + "choices": ["London", "Berlin", "Paris", "Madrid"], + "answer": "Paris", + "subject": "geography", + }, + ), + cast( + DatasetItem, + { + "question": "What is 2 + 2?", + "choices": ["3", "4", "5", "6"], + "answer": "4", + "subject": "mathematics", + }, + ), ] return Dataset( name="mmlu_sample", data=sample_data, - metadata={ - "task": "multiple_choice_qa", - "source": "MMLU", - "description": "Sample from Massive Multitask Language Understanding", - }, + metadata=cast( + DatasetMetadata, + { + "task": "multiple_choice_qa", + "source": "MMLU", + "description": "Sample from Massive Multitask Language Understanding", + }, + ), ) def load_hellaswag_sample() -> Dataset: """Load a sample of HellaSwag dataset.""" - sample_data = [ - { - "context": "A woman is outside with a bucket and a dog. The dog is running around trying to avoid a bath. She", - "endings": [ - "rinses the bucket off with soap and blow dry the dog.", - "uses a hose to keep the dog from getting soapy.", - "gets the dog wet, then it runs away again.", - "gets into the bath tub with the dog.", - ], - "label": 2, - } + sample_data: List[DatasetItem] = [ + cast( + DatasetItem, + { + "context": "A woman is outside with a bucket and a dog. The dog is running around trying to avoid a bath. She", + "endings": [ + "rinses the bucket off with soap and blow dry the dog.", + "uses a hose to keep the dog from getting soapy.", + "gets the dog wet, then it runs away again.", + "gets into the bath tub with the dog.", + ], + "label": 2, + }, + ) ] return Dataset( name="hellaswag_sample", data=sample_data, - metadata={ - "task": "sentence_completion", - "source": "HellaSwag", - "description": "Commonsense reasoning benchmark", - }, + metadata=cast( + DatasetMetadata, + { + "task": "sentence_completion", + "source": "HellaSwag", + "description": "Commonsense reasoning benchmark", + }, + ), ) def load_gsm8k_sample() -> Dataset: """Load a sample of GSM8K (Grade School Math 8K) dataset.""" - sample_data = [ - { - "question": "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and bakes 4 into muffins for her friends every day. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much money does she make every day at the farmers' market?", - "answer": "Janet sells 16 - 3 - 4 = 9 duck eggs every day. She makes 9 * $2 = $18 every day at the farmers' market.", - } + sample_data: List[DatasetItem] = [ + cast( + DatasetItem, + { + "question": "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and bakes 4 into muffins for her friends every day. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much money does she make every day at the farmers' market?", + "answer": "Janet sells 16 - 3 - 4 = 9 duck eggs every day. She makes 9 * $2 = $18 every day at the farmers' market.", + }, + ) ] return Dataset( name="gsm8k_sample", data=sample_data, - metadata={ - "task": "math_word_problems", - "source": "GSM8K", - "description": "Grade school math word problems", - }, + metadata=cast( + DatasetMetadata, + { + "task": "math_word_problems", + "source": "GSM8K", + "description": "Grade school math word problems", + }, + ), ) + + +def convert_metadata_to_info(metadata: DatasetMetadata) -> DatasetInfo: + """ + Convert DatasetMetadata to DatasetInfo for evaluation results. + + This function properly converts dataset metadata (which is stored with the dataset) + to dataset info (which is used in evaluation results). It handles missing fields + and ensures type safety. + + Args: + metadata: Dataset metadata to convert + + Returns: + DatasetInfo with converted fields + """ + # Extract fields that exist in DatasetMetadata + size: int = metadata.get("size", 0) + tags: List[str] = metadata.get("tags", []) + source: Optional[str] = metadata.get("source") + name: Optional[str] = metadata.get("name") + description: Optional[str] = metadata.get("description") + version: Optional[str] = metadata.get("version") + created_at: Optional[str] = metadata.get("created_at") + + # Extract fields that might exist but aren't in DatasetMetadata TypedDict + # These could be present at runtime even if not in the type definition + metadata_dict: Dict[str, Any] = cast(Dict[str, Any], metadata) + hash_value: Optional[str] = metadata_dict.get("hash") + task: Optional[str] = metadata_dict.get("task") + difficulty: Optional[str] = metadata_dict.get("difficulty") + + # Build DatasetInfo with proper types + info: DatasetInfo = { + "size": size, + "tags": tags, + "source": source, + "name": name, + "description": description, + "version": version, + "created_at": created_at, + "hash": hash_value, + "task": task if task else "general", + } + + # Add difficulty if available + if difficulty: + info["difficulty"] = difficulty + + return info diff --git a/benchwise/exceptions.py b/benchwise/exceptions.py index 498bd02..659799d 100644 --- a/benchwise/exceptions.py +++ b/benchwise/exceptions.py @@ -4,50 +4,62 @@ Provides specific exception types for better error handling. """ +from typing import Optional + class BenchwiseError(Exception): """Base exception for all Benchwise errors.""" + pass class AuthenticationError(BenchwiseError): """Raised when authentication fails.""" + pass class RateLimitError(BenchwiseError): """Raised when API rate limit is exceeded.""" - - def __init__(self, message: str = "Rate limit exceeded", retry_after: int = None): + + def __init__( + self, message: str = "Rate limit exceeded", retry_after: Optional[int] = None + ) -> None: super().__init__(message) self.retry_after = retry_after class ValidationError(BenchwiseError): """Raised when input validation fails.""" + pass class NetworkError(BenchwiseError): """Raised when network requests fail.""" + pass class ConfigurationError(BenchwiseError): """Raised when configuration is invalid or missing.""" + pass class DatasetError(BenchwiseError): """Raised when dataset operations fail.""" + pass class ModelError(BenchwiseError): """Raised when model operations fail.""" + pass class MetricError(BenchwiseError): """Raised when metric calculation fails.""" + pass diff --git a/benchwise/logging.py b/benchwise/logging.py index 3f6b6a5..4843e53 100644 --- a/benchwise/logging.py +++ b/benchwise/logging.py @@ -16,20 +16,20 @@ def setup_logging( ) -> logging.Logger: """ Setup logging for Benchwise. - + Args: level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) format: Custom log format string filename: Optional file to write logs to - + Returns: Configured logger instance """ - + # Default format if format is None: format = "[%(asctime)s] %(levelname)s [%(name)s] %(message)s" - + # Configure root logger logging.basicConfig( level=getattr(logging, level.upper()), @@ -37,41 +37,41 @@ def setup_logging( datefmt="%Y-%m-%d %H:%M:%S", handlers=[ logging.StreamHandler(sys.stdout), - ] + ], ) - + # Add file handler if filename provided if filename: file_handler = logging.FileHandler(filename) file_handler.setFormatter(logging.Formatter(format)) logging.getLogger("benchwise").addHandler(file_handler) - + # Get benchwise logger logger = logging.getLogger("benchwise") logger.setLevel(getattr(logging, level.upper())) - + logger.debug(f"Logging initialized at {level} level") - + return logger def get_logger(name: str = "benchwise") -> logging.Logger: """ Get a logger instance for Benchwise. - + Args: name: Logger name (default: "benchwise") - + Returns: Logger instance """ return logging.getLogger(name) -def set_log_level(level: str): +def set_log_level(level: str) -> None: """ Change the log level for all Benchwise loggers. - + Args: level: New log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) """ diff --git a/benchwise/metrics.py b/benchwise/metrics.py index e616d91..ebe385d 100644 --- a/benchwise/metrics.py +++ b/benchwise/metrics.py @@ -1,9 +1,20 @@ -from typing import List, Dict, Any, Tuple, Optional +from typing import List, Dict, Any, Tuple, Optional, Callable, cast import numpy as np +from benchwise.types import ( + RougeScores, + BleuScores, + BertScoreResults, + AccuracyResults, + SemanticSimilarityResults, + PerplexityResults, + FactualCorrectnessResults, + CoherenceResults, + SafetyResults, +) from rouge_score import rouge_scorer from sacrebleu import BLEU import bert_score -from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction import nltk import re import string @@ -26,7 +37,7 @@ def _bootstrap_confidence_interval( ) -> Tuple[float, float]: """Calculate bootstrap confidence interval for a list of scores.""" if len(scores) < 2: - return (np.mean(scores), np.mean(scores)) + return (float(np.mean(scores)), float(np.mean(scores))) bootstrap_means = [] for _ in range(n_bootstrap): @@ -38,8 +49,8 @@ def _bootstrap_confidence_interval( upper_percentile = (1 - alpha / 2) * 100 return ( - np.percentile(bootstrap_means, lower_percentile), - np.percentile(bootstrap_means, upper_percentile), + float(np.percentile(bootstrap_means, lower_percentile)), + float(np.percentile(bootstrap_means, upper_percentile)), ) @@ -65,7 +76,7 @@ def rouge_l( use_stemmer: bool = True, alpha: float = 0.5, return_confidence: bool = True, -) -> Dict[str, float]: +) -> RougeScores: """ Calculate enhanced ROUGE-L scores for predictions vs references. @@ -96,7 +107,13 @@ def rouge_l( scorer = rouge_scorer.RougeScorer( ["rougeL", "rouge1", "rouge2"], use_stemmer=use_stemmer ) - scores = {"precision": [], "recall": [], "f1": [], "rouge1_f1": [], "rouge2_f1": []} + scores: Dict[str, List[float]] = { + "precision": [], + "recall": [], + "f1": [], + "rouge1_f1": [], + "rouge2_f1": [], + } for pred, ref in zip(predictions, references): # Handle empty strings gracefully @@ -130,15 +147,15 @@ def rouge_l( scores["rouge1_f1"].append(score["rouge1"].fmeasure) scores["rouge2_f1"].append(score["rouge2"].fmeasure) - result = { - "precision": np.mean(scores["precision"]), - "recall": np.mean(scores["recall"]), - "f1": np.mean(scores["f1"]), - "rouge1_f1": np.mean(scores["rouge1_f1"]), - "rouge2_f1": np.mean(scores["rouge2_f1"]), - "std_precision": np.std(scores["precision"]), - "std_recall": np.std(scores["recall"]), - "std_f1": np.std(scores["f1"]), + result: RougeScores = { + "precision": float(np.mean(scores["precision"])), + "recall": float(np.mean(scores["recall"])), + "f1": float(np.mean(scores["f1"])), + "rouge1_f1": float(np.mean(scores["rouge1_f1"])), + "rouge2_f1": float(np.mean(scores["rouge2_f1"])), + "std_precision": float(np.std(scores["precision"])), + "std_recall": float(np.std(scores["recall"])), + "std_f1": float(np.std(scores["f1"])), "scores": scores, } @@ -166,7 +183,7 @@ def bleu_score( smooth_method: str = "exp", return_confidence: bool = True, max_n: int = 4, -) -> Dict[str, float]: +) -> BleuScores: """ Calculate enhanced BLEU scores for predictions vs references. @@ -204,7 +221,9 @@ def bleu_score( # Calculate sentence-level BLEU with improved handling sentence_scores = [] - ngram_precisions = {f"bleu_{i}": [] for i in range(1, max_n + 1)} + ngram_precisions: Dict[str, List[float]] = { + f"bleu_{i}": [] for i in range(1, max_n + 1) + } for pred, ref in zip(predictions, references): try: @@ -256,49 +275,48 @@ def bleu_score( for i in range(1, max_n + 1): ngram_precisions[f"bleu_{i}"].append(0.0) - result = { + # Build result dict dynamically, then cast to BleuScores + result_dict: Dict[str, Any] = { "corpus_bleu": corpus_bleu, - "sentence_bleu": np.mean(sentence_scores), - "std_sentence_bleu": np.std(sentence_scores), - "median_sentence_bleu": np.median(sentence_scores), + "sentence_bleu": float(np.mean(sentence_scores)), + "std_sentence_bleu": float(np.std(sentence_scores)), + "median_sentence_bleu": float(np.median(sentence_scores)), "scores": sentence_scores, } # Add n-gram precision scores for key, scores in ngram_precisions.items(): if scores: # Only add if we have scores - result[key] = np.mean(scores) - result[f"{key}_std"] = np.std(scores) + result_dict[key] = float(np.mean(scores)) + result_dict[f"{key}_std"] = float(np.std(scores)) # Add confidence intervals if requested if return_confidence and len(sentence_scores) > 1: try: - result[ - "sentence_bleu_confidence_interval" - ] = _bootstrap_confidence_interval(sentence_scores) + result_dict["sentence_bleu_confidence_interval"] = ( + _bootstrap_confidence_interval(sentence_scores) + ) except Exception as e: warnings.warn(f"Could not calculate BLEU confidence intervals: {e}") - return result + return cast(BleuScores, result_dict) -def _get_smoothing_function(smooth_method: str): +def _get_smoothing_function(smooth_method: str) -> Optional[Callable[..., Any]]: """Get NLTK smoothing function based on method name.""" - from nltk.translate.bleu_score import SmoothingFunction - smoothing = SmoothingFunction() if smooth_method == "exp": - return smoothing.method1 + return smoothing.method1 # type: ignore[no-any-return] elif smooth_method == "floor": - return smoothing.method2 + return smoothing.method2 # type: ignore[no-any-return] elif smooth_method == "add-k": - return smoothing.method3 + return smoothing.method3 # type: ignore[no-any-return] else: return None -def _get_weights(n: int) -> tuple: +def _get_weights(n: int) -> Tuple[float, ...]: """Get n-gram weights for BLEU calculation.""" weights = [0.0] * 4 weights[n - 1] = 1.0 @@ -311,7 +329,7 @@ def bert_score_metric( model_type: str = "distilbert-base-uncased", return_confidence: bool = True, batch_size: int = 64, -) -> Dict[str, float]: +) -> BertScoreResults: """ Calculate enhanced BERTScore for predictions vs references. @@ -331,12 +349,15 @@ def bert_score_metric( ) if not predictions or not references: - return { - "precision": 0.0, - "recall": 0.0, - "f1": 0.0, - "scores": {"precision": [], "recall": [], "f1": []}, - } + return cast( + BertScoreResults, + { + "precision": 0.0, + "recall": 0.0, + "f1": 0.0, + "scores": {"precision": [], "recall": [], "f1": []}, + }, + ) try: # Handle empty strings gracefully @@ -385,16 +406,16 @@ def bert_score_metric( R_scores[idx] = r F1_scores[idx] = f1 - result = { - "precision": np.mean(P_scores), - "recall": np.mean(R_scores), - "f1": np.mean(F1_scores), - "std_precision": np.std(P_scores), - "std_recall": np.std(R_scores), - "std_f1": np.std(F1_scores), - "min_f1": np.min(F1_scores), - "max_f1": np.max(F1_scores), - "median_f1": np.median(F1_scores), + result_dict: Dict[str, Any] = { + "precision": float(np.mean(P_scores)), + "recall": float(np.mean(R_scores)), + "f1": float(np.mean(F1_scores)), + "std_precision": float(np.std(P_scores)), + "std_recall": float(np.std(R_scores)), + "std_f1": float(np.std(F1_scores)), + "min_f1": float(np.min(F1_scores)), + "max_f1": float(np.max(F1_scores)), + "median_f1": float(np.median(F1_scores)), "model_used": model_type, "scores": {"precision": P_scores, "recall": R_scores, "f1": F1_scores}, } @@ -402,36 +423,39 @@ def bert_score_metric( # Add confidence intervals if requested if return_confidence and len(F1_scores) > 1: try: - result["f1_confidence_interval"] = _bootstrap_confidence_interval( + result_dict["f1_confidence_interval"] = _bootstrap_confidence_interval( F1_scores ) - result[ - "precision_confidence_interval" - ] = _bootstrap_confidence_interval(P_scores) - result["recall_confidence_interval"] = _bootstrap_confidence_interval( - R_scores + result_dict["precision_confidence_interval"] = ( + _bootstrap_confidence_interval(P_scores) + ) + result_dict["recall_confidence_interval"] = ( + _bootstrap_confidence_interval(R_scores) ) except Exception as e: warnings.warn( f"Could not calculate BERTScore confidence intervals: {e}" ) - return result + return cast(BertScoreResults, result_dict) except Exception as e: warnings.warn(f"BERTScore calculation failed: {e}") # Return fallback scores - return { - "precision": 0.0, - "recall": 0.0, - "f1": 0.0, - "error": str(e), - "scores": { - "precision": [0.0] * len(predictions), - "recall": [0.0] * len(predictions), - "f1": [0.0] * len(predictions), + return cast( + BertScoreResults, + { + "precision": 0.0, + "recall": 0.0, + "f1": 0.0, + "error": str(e), + "scores": { + "precision": [0.0] * len(predictions), + "recall": [0.0] * len(predictions), + "f1": [0.0] * len(predictions), + }, }, - } + ) def accuracy( @@ -442,7 +466,7 @@ def accuracy( fuzzy_match: bool = False, fuzzy_threshold: float = 0.8, return_confidence: bool = True, -) -> Dict[str, float]: +) -> AccuracyResults: """ Calculate enhanced exact match accuracy with multiple matching strategies. @@ -464,7 +488,7 @@ def accuracy( ) if not predictions or not references: - return {"accuracy": 0.0, "correct": 0, "total": 0} + return cast(AccuracyResults, {"accuracy": 0.0, "correct": 0, "total": 0}) correct_exact = 0 correct_fuzzy = 0 @@ -524,28 +548,28 @@ def accuracy( exact_accuracy = correct_exact / total if total > 0 else 0.0 fuzzy_accuracy = correct_fuzzy / total if total > 0 else 0.0 - result = { + result_dict: Dict[str, Any] = { "accuracy": exact_accuracy, "exact_accuracy": exact_accuracy, "fuzzy_accuracy": fuzzy_accuracy if fuzzy_match else exact_accuracy, "correct": correct_exact, "correct_fuzzy": correct_fuzzy if fuzzy_match else correct_exact, "total": total, - "mean_score": np.mean(individual_scores), - "std_score": np.std(individual_scores), + "mean_score": float(np.mean(individual_scores)), + "std_score": float(np.std(individual_scores)), "individual_scores": individual_scores, "match_types": match_types, } if return_confidence and len(individual_scores) > 1: try: - result["accuracy_confidence_interval"] = _bootstrap_confidence_interval( - individual_scores + result_dict["accuracy_confidence_interval"] = ( + _bootstrap_confidence_interval(individual_scores) ) except Exception as e: warnings.warn(f"Could not calculate accuracy confidence intervals: {e}") - return result + return cast(AccuracyResults, result_dict) def semantic_similarity( @@ -555,7 +579,7 @@ def semantic_similarity( batch_size: int = 32, return_confidence: bool = True, similarity_threshold: float = 0.5, -) -> Dict[str, float]: +) -> SemanticSimilarityResults: """ Calculate enhanced semantic similarity using sentence embeddings. @@ -576,7 +600,7 @@ def semantic_similarity( ) if not predictions or not references: - return {"mean_similarity": 0.0, "scores": []} + return cast(SemanticSimilarityResults, {"mean_similarity": 0.0, "scores": []}) try: from sentence_transformers import SentenceTransformer, util @@ -646,37 +670,38 @@ def semantic_similarity( # Calculate enhanced statistics similarities_array = np.array(similarities) - result = { - "mean_similarity": np.mean(similarities), - "median_similarity": np.median(similarities), - "std_similarity": np.std(similarities), - "min_similarity": np.min(similarities), - "max_similarity": np.max(similarities), - "similarity_above_threshold": np.sum(similarities_array >= similarity_threshold) - / len(similarities), + result_dict: Dict[str, Any] = { + "mean_similarity": float(np.mean(similarities)), + "median_similarity": float(np.median(similarities)), + "std_similarity": float(np.std(similarities)), + "min_similarity": float(np.min(similarities)), + "max_similarity": float(np.max(similarities)), + "similarity_above_threshold": float( + np.sum(similarities_array >= similarity_threshold) / len(similarities) + ), "scores": similarities, "model_used": model_type, } - result["percentile_25"] = np.percentile(similarities, 25) - result["percentile_75"] = np.percentile(similarities, 75) - result["percentile_90"] = np.percentile(similarities, 90) + result_dict["percentile_25"] = float(np.percentile(similarities, 25)) + result_dict["percentile_75"] = float(np.percentile(similarities, 75)) + result_dict["percentile_90"] = float(np.percentile(similarities, 90)) # Add confidence intervals if requested if return_confidence and len(similarities) > 1: try: - result["similarity_confidence_interval"] = _bootstrap_confidence_interval( - similarities + result_dict["similarity_confidence_interval"] = ( + _bootstrap_confidence_interval(similarities) ) except Exception as e: warnings.warn( f"Could not calculate semantic similarity confidence intervals: {e}" ) - return result + return cast(SemanticSimilarityResults, result_dict) -def perplexity(predictions: List[str], model_name: str = "gpt2") -> Dict[str, float]: +def perplexity(predictions: List[str], model_name: str = "gpt2") -> PerplexityResults: """ Calculate perplexity of generated text. @@ -711,11 +736,14 @@ def perplexity(predictions: List[str], model_name: str = "gpt2") -> Dict[str, fl perplexity = torch.exp(loss).item() perplexities.append(perplexity) - return { - "mean_perplexity": np.mean(perplexities), - "median_perplexity": np.median(perplexities), - "scores": perplexities, - } + return cast( + PerplexityResults, + { + "mean_perplexity": float(np.mean(perplexities)), + "median_perplexity": float(np.median(perplexities)), + "scores": perplexities, + }, + ) def factual_correctness( @@ -725,7 +753,7 @@ def factual_correctness( use_named_entities: bool = True, return_confidence: bool = True, detailed_analysis: bool = True, -) -> Dict[str, Any]: +) -> FactualCorrectnessResults: """ Evaluate factual correctness of predictions using enhanced fact-checking methods. @@ -746,7 +774,7 @@ def factual_correctness( ) if not predictions or not references: - return {"mean_correctness": 0.0, "scores": []} + return cast(FactualCorrectnessResults, {"mean_correctness": 0.0, "scores": []}) correctness_scores = [] detailed_results = [] @@ -785,16 +813,16 @@ def factual_correctness( # Calculate overall correctness score overall_score = np.mean(list(factual_analysis.values())) - correctness_scores.append(overall_score) + correctness_scores.append(float(overall_score)) detailed_results.append(factual_analysis) # Compile results - result = { - "mean_correctness": np.mean(correctness_scores), - "median_correctness": np.median(correctness_scores), - "std_correctness": np.std(correctness_scores), - "min_correctness": np.min(correctness_scores), - "max_correctness": np.max(correctness_scores), + result_dict: Dict[str, Any] = { + "mean_correctness": float(np.mean(correctness_scores)), + "median_correctness": float(np.median(correctness_scores)), + "std_correctness": float(np.std(correctness_scores)), + "min_correctness": float(np.min(correctness_scores)), + "max_correctness": float(np.max(correctness_scores)), "scores": correctness_scores, } @@ -802,37 +830,40 @@ def factual_correctness( if detailed_analysis: # Aggregate component scores components = ["entity_overlap", "keyword_overlap", "semantic_overlap"] - result["components"] = {} + result_dict["components"] = {} for component in components: component_scores = [ detail.get(component, 0.0) for detail in detailed_results ] if component_scores: - result["components"][component] = { - "mean": np.mean(component_scores), - "std": np.std(component_scores), + result_dict["components"][component] = { + "mean": float(np.mean(component_scores)), + "std": float(np.std(component_scores)), "scores": component_scores, } - result["detailed_results"] = detailed_results + result_dict["detailed_results"] = detailed_results # Add confidence intervals if requested if return_confidence and len(correctness_scores) > 1: try: - result["correctness_confidence_interval"] = _bootstrap_confidence_interval( - correctness_scores + result_dict["correctness_confidence_interval"] = ( + _bootstrap_confidence_interval(correctness_scores) ) except Exception as e: warnings.warn( f"Could not calculate factual correctness confidence intervals: {e}" ) - return result + return cast(FactualCorrectnessResults, result_dict) def _analyze_factual_correctness( - prediction: str, reference: str, nlp_model=None, use_named_entities: bool = True + prediction: str, + reference: str, + nlp_model: Any = None, + use_named_entities: bool = True, ) -> Dict[str, float]: """ Analyze factual correctness using multiple approaches. @@ -868,7 +899,7 @@ def _analyze_factual_correctness( } -def _calculate_entity_overlap(prediction: str, reference: str, nlp_model) -> float: +def _calculate_entity_overlap(prediction: str, reference: str, nlp_model: Any) -> float: """ Calculate overlap between named entities in prediction and reference. """ @@ -913,8 +944,7 @@ def _calculate_enhanced_keyword_overlap(prediction: str, reference: str) -> floa } # Extract important words from reference - important_ref_words = set() - " ".join(ref_words) + important_ref_words: set[str] = set() for pattern_type, pattern in important_patterns.items(): matches = re.findall(pattern, reference, re.IGNORECASE) @@ -1011,7 +1041,7 @@ def coherence_score( predictions: List[str], return_confidence: bool = True, detailed_analysis: bool = True, -) -> Dict[str, Any]: +) -> CoherenceResults: """ Evaluate text coherence using enhanced linguistic and statistical metrics. @@ -1024,10 +1054,10 @@ def coherence_score( Dictionary with enhanced coherence scores and analysis """ if not predictions: - return {"mean_coherence": 1.0, "scores": []} + return cast(CoherenceResults, {"mean_coherence": 1.0, "scores": []}) coherence_scores = [] - component_scores = { + component_scores: Dict[str, List[float]] = { "sentence_consistency": [], "lexical_diversity": [], "flow_continuity": [], @@ -1046,7 +1076,7 @@ def coherence_score( # Calculate overall coherence score overall_coherence = np.mean(list(coherence_components.values())) - coherence_scores.append(overall_coherence) + coherence_scores.append(float(overall_coherence)) # Store component scores for component, score in coherence_components.items(): @@ -1054,36 +1084,36 @@ def coherence_score( component_scores[component].append(score) # Compile results - result = { - "mean_coherence": np.mean(coherence_scores), - "median_coherence": np.median(coherence_scores), - "std_coherence": np.std(coherence_scores), - "min_coherence": np.min(coherence_scores), - "max_coherence": np.max(coherence_scores), + result_dict: Dict[str, Any] = { + "mean_coherence": float(np.mean(coherence_scores)), + "median_coherence": float(np.median(coherence_scores)), + "std_coherence": float(np.std(coherence_scores)), + "min_coherence": float(np.min(coherence_scores)), + "max_coherence": float(np.max(coherence_scores)), "scores": coherence_scores, } # Add detailed component analysis if requested if detailed_analysis: - result["components"] = {} + result_dict["components"] = {} for component, scores in component_scores.items(): if scores: # Only add if we have scores - result["components"][component] = { - "mean": np.mean(scores), - "std": np.std(scores), + result_dict["components"][component] = { + "mean": float(np.mean(scores)), + "std": float(np.std(scores)), "scores": scores, } # Add confidence intervals if requested if return_confidence and len(coherence_scores) > 1: try: - result["coherence_confidence_interval"] = _bootstrap_confidence_interval( - coherence_scores + result_dict["coherence_confidence_interval"] = ( + _bootstrap_confidence_interval(coherence_scores) ) except Exception as e: warnings.warn(f"Could not calculate coherence confidence intervals: {e}") - return result + return cast(CoherenceResults, result_dict) def _analyze_text_coherence(text: str) -> Dict[str, float]: @@ -1104,12 +1134,13 @@ def _analyze_text_coherence(text: str) -> Dict[str, float]: # 1. Sentence consistency (length and structure) sentence_lengths = [len(s.split()) for s in sentences] if len(sentence_lengths) > 1: - length_cv = ( - np.std(sentence_lengths) / np.mean(sentence_lengths) + length_cv: float = ( + float(np.std(sentence_lengths) / np.mean(sentence_lengths)) if np.mean(sentence_lengths) > 0 - else 1 + else 1.0 ) - sentence_consistency = max(0, 1 - (length_cv / 2)) # Normalize to 0-1 + cv_value: float = length_cv / 2.0 + sentence_consistency = max(0.0, 1.0 - cv_value) # Normalize to 0-1 else: sentence_consistency = 1.0 if sentence_lengths else 0.0 @@ -1252,18 +1283,18 @@ def _calculate_topic_consistency(sentences: List[str]) -> float: overlaps.append(overlap) # Return average overlap - return np.mean(overlaps) if overlaps else 1.0 + return float(np.mean(overlaps)) if overlaps else 1.0 def safety_score( predictions: List[str], - safety_categories: List[str] = None, + safety_categories: Optional[List[str]] = None, use_perspective_api: bool = False, - perspective_api_key: str = None, + perspective_api_key: Optional[str] = None, return_confidence: bool = True, detailed_categories: bool = True, context_aware: bool = True, -) -> Dict[str, float]: +) -> SafetyResults: """ Evaluate safety of generated text with enhanced detection methods. @@ -1280,7 +1311,7 @@ def safety_score( Dictionary with enhanced safety scores and category breakdowns """ if not predictions: - return {"mean_safety": 1.0, "scores": []} + return cast(SafetyResults, {"mean_safety": 1.0, "scores": []}) if safety_categories is None: safety_categories = [ @@ -1341,8 +1372,8 @@ def safety_score( } safety_scores = [] - category_scores = {cat: [] for cat in safety_categories} - violation_details = [] + category_scores: Dict[str, List[float]] = {cat: [] for cat in safety_categories} + violation_details: List[List[str]] = [] for i, text in enumerate(predictions): if not text.strip(): @@ -1404,11 +1435,11 @@ def safety_score( violation_details.append(violations) # Compile results - result = { - "mean_safety": np.mean(safety_scores), - "median_safety": np.median(safety_scores), - "std_safety": np.std(safety_scores), - "min_safety": np.min(safety_scores), + result_dict: Dict[str, Any] = { + "mean_safety": float(np.mean(safety_scores)), + "median_safety": float(np.median(safety_scores)), + "std_safety": float(np.std(safety_scores)), + "min_safety": float(np.min(safety_scores)), "unsafe_count": len( [s for s in safety_scores if s < 0.8] ), # Threshold for "unsafe" @@ -1418,25 +1449,25 @@ def safety_score( # Add detailed category scores if requested if detailed_categories: - result["category_scores"] = {} + result_dict["category_scores"] = {} for cat in safety_categories: if category_scores[cat]: # Only add if we have scores - result["category_scores"][cat] = { - "mean": np.mean(category_scores[cat]), - "violation_rate": 1.0 - np.mean(category_scores[cat]), + result_dict["category_scores"][cat] = { + "mean": float(np.mean(category_scores[cat])), + "violation_rate": float(1.0 - np.mean(category_scores[cat])), "scores": category_scores[cat], } # Add confidence intervals if requested if return_confidence and len(safety_scores) > 1: try: - result["safety_confidence_interval"] = _bootstrap_confidence_interval( + result_dict["safety_confidence_interval"] = _bootstrap_confidence_interval( safety_scores ) except Exception as e: warnings.warn(f"Could not calculate safety confidence intervals: {e}") - return result + return cast(SafetyResults, result_dict) def _check_keyword_in_context( @@ -1482,15 +1513,17 @@ def _check_keyword_in_context( class MetricCollection: """Collection of evaluation metrics that can be run together.""" - def __init__(self): - self.metrics = {} + def __init__(self) -> None: + self.metrics: Dict[str, Tuple[Callable[..., Any], Dict[str, Any]]] = {} - def add_metric(self, name: str, metric_func: callable, **kwargs): + def add_metric( + self, name: str, metric_func: Callable[..., Any], **kwargs: Any + ) -> None: """Add a metric to the collection.""" self.metrics[name] = (metric_func, kwargs) def evaluate( - self, predictions: List[str], references: List[str] = None + self, predictions: List[str], references: Optional[List[str]] = None ) -> Dict[str, Any]: """Run all metrics in the collection.""" results = {} diff --git a/benchwise/models.py b/benchwise/models.py index cd5c88a..78a50b9 100644 --- a/benchwise/models.py +++ b/benchwise/models.py @@ -1,16 +1,18 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional +from benchwise.types import ModelConfig, PricingInfo + class ModelAdapter(ABC): """Abstract base class for model adapters.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: self.model_name = model_name - self.config = config or {} + self.config: ModelConfig = config or {} @abstractmethod - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate responses for a list of prompts.""" pass @@ -28,7 +30,7 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: class OpenAIAdapter(ModelAdapter): """Adapter for OpenAI models.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: super().__init__(model_name, config) try: import openai @@ -42,16 +44,16 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): ) # Model pricing (per 1K tokens) - self.pricing = { + self.pricing: Dict[str, PricingInfo] = { "gpt-4": {"input": 0.03, "output": 0.06}, "gpt-4-turbo": {"input": 0.01, "output": 0.03}, "gpt-3.5-turbo": {"input": 0.001, "output": 0.002}, "gpt-4o": {"input": 0.005, "output": 0.015}, } - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate responses using OpenAI API.""" - responses = [] + responses: List[str] = [] # Default parameters - exclude api_key from generation params generation_params = { @@ -85,15 +87,15 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: model_pricing = self.pricing.get( self.model_name, {"input": 0.01, "output": 0.03} ) - input_cost = (input_tokens / 1000) * model_pricing["input"] - output_cost = (output_tokens / 1000) * model_pricing["output"] + input_cost = (input_tokens / 1000) * float(model_pricing["input"]) + output_cost = (output_tokens / 1000) * float(model_pricing["output"]) return input_cost + output_cost class AnthropicAdapter(ModelAdapter): """Adapter for Anthropic Claude models.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: super().__init__(model_name, config) try: import anthropic @@ -107,16 +109,16 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): ) # Model pricing (per 1K tokens) - self.pricing = { + self.pricing: Dict[str, PricingInfo] = { "claude-3-opus": {"input": 0.015, "output": 0.075}, "claude-3-sonnet": {"input": 0.003, "output": 0.015}, "claude-3-haiku": {"input": 0.00025, "output": 0.00125}, "claude-3.5-sonnet": {"input": 0.003, "output": 0.015}, } - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate responses using Anthropic API.""" - responses = [] + responses: List[str] = [] # Default parameters - exclude api_key from generation params generation_params = { @@ -150,15 +152,15 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: model_pricing = self.pricing.get( self.model_name, {"input": 0.003, "output": 0.015} ) - input_cost = (input_tokens / 1000) * model_pricing["input"] - output_cost = (output_tokens / 1000) * model_pricing["output"] + input_cost = (input_tokens / 1000) * float(model_pricing["input"]) + output_cost = (output_tokens / 1000) * float(model_pricing["output"]) return input_cost + output_cost class GoogleAdapter(ModelAdapter): """Adapter for Google Gemini models.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: super().__init__(model_name, config) try: import google.generativeai as genai @@ -172,9 +174,9 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): "Google Generative AI package not installed. Please install it with: pip install 'benchwise[llm-apis]' or pip install google-generativeai" ) - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate responses using Google Gemini API.""" - responses = [] + responses: List[str] = [] for prompt in prompts: try: @@ -206,7 +208,7 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: class HuggingFaceAdapter(ModelAdapter): """Adapter for Hugging Face models.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: super().__init__(model_name, config) try: from transformers import AutoTokenizer, AutoModelForCausalLM @@ -218,9 +220,9 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): "Transformers package not installed. Please install it with: pip install 'benchwise[transformers]' or pip install transformers torch" ) - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate responses using Hugging Face models.""" - responses = [] + responses: List[str] = [] for prompt in prompts: try: @@ -251,10 +253,10 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: class MockAdapter(ModelAdapter): """Mock adapter for testing without API dependencies.""" - def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): + def __init__(self, model_name: str, config: Optional[ModelConfig] = None) -> None: super().__init__(model_name, config) - async def generate(self, prompts: List[str], **kwargs) -> List[str]: + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: """Generate mock responses.""" return [ f"Mock response from {self.model_name} for: {prompt[:50]}..." @@ -271,7 +273,7 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: def get_model_adapter( - model_name: str, config: Optional[Dict[str, Any]] = None + model_name: str, config: Optional[ModelConfig] = None ) -> ModelAdapter: """Factory function to get the appropriate model adapter.""" diff --git a/benchwise/py.typed b/benchwise/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/benchwise/results.py b/benchwise/results.py index 397a8b8..e0c7044 100644 --- a/benchwise/results.py +++ b/benchwise/results.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union, cast from dataclasses import dataclass, field from datetime import datetime import json @@ -7,6 +7,19 @@ import numpy as np import hashlib +from .types import ( + DatasetInfo, + EvaluationMetadata, + EvaluationResultDict, + BenchmarkSummary, + BenchmarkResultDict, + ModelComparisonResult, + CrossBenchmarkComparison, + ModelPerformanceAnalysis, + CachedResultInfo, + BenchmarkComparisonInfo, +) + @dataclass class EvaluationResult: @@ -28,9 +41,11 @@ class EvaluationResult: test_name: str result: Any = None duration: float = 0.0 - dataset_info: Optional[Dict[str, Any]] = None + dataset_info: Optional[DatasetInfo] = None error: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: EvaluationMetadata = field( + default_factory=lambda: cast(EvaluationMetadata, {}) + ) timestamp: datetime = field(default_factory=datetime.now) @property @@ -43,21 +58,24 @@ def failed(self) -> bool: """Whether the evaluation failed.""" return self.error is not None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> EvaluationResultDict: """Convert result to dictionary format.""" - return { - "model_name": self.model_name, - "test_name": self.test_name, - "result": self.result, - "duration": self.duration, - "dataset_info": self.dataset_info, - "error": self.error, - "metadata": self.metadata, - "timestamp": self.timestamp.isoformat(), - "success": self.success, - } + return cast( + EvaluationResultDict, + { + "model_name": self.model_name, + "test_name": self.test_name, + "result": self.result, + "duration": self.duration, + "dataset_info": self.dataset_info, + "error": self.error, + "metadata": self.metadata, + "timestamp": self.timestamp.isoformat(), + "success": self.success, + }, + ) - def get_score(self, metric_name: str = None) -> Union[float, Any]: + def get_score(self, metric_name: Optional[str] = None) -> Union[float, Any]: """ Extract a specific score from the result. @@ -90,10 +108,12 @@ class BenchmarkResult: benchmark_name: str results: List[EvaluationResult] = field(default_factory=list) - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: EvaluationMetadata = field( + default_factory=lambda: cast(EvaluationMetadata, {}) + ) timestamp: datetime = field(default_factory=datetime.now) - def add_result(self, result: EvaluationResult): + def add_result(self, result: EvaluationResult) -> None: """Add an evaluation result to the benchmark.""" self.results.append(result) @@ -119,7 +139,9 @@ def success_rate(self) -> float: return 0.0 return len(self.successful_results) / len(self.results) - def get_best_model(self, metric_name: str = None) -> Optional[EvaluationResult]: + def get_best_model( + self, metric_name: Optional[str] = None + ) -> Optional[EvaluationResult]: """ Get the best performing model result. @@ -135,7 +157,9 @@ def get_best_model(self, metric_name: str = None) -> Optional[EvaluationResult]: return max(successful_results, key=lambda r: r.get_score(metric_name) or 0) - def get_worst_model(self, metric_name: str = None) -> Optional[EvaluationResult]: + def get_worst_model( + self, metric_name: Optional[str] = None + ) -> Optional[EvaluationResult]: """ Get the worst performing model result. @@ -153,7 +177,9 @@ def get_worst_model(self, metric_name: str = None) -> Optional[EvaluationResult] successful_results, key=lambda r: r.get_score(metric_name) or float("inf") ) - def compare_models(self, metric_name: str = None) -> Dict[str, Any]: + def compare_models( + self, metric_name: Optional[str] = None + ) -> ModelComparisonResult: """ Compare all models in the benchmark. @@ -165,7 +191,9 @@ def compare_models(self, metric_name: str = None) -> Dict[str, Any]: """ successful_results = self.successful_results if not successful_results: - return {"error": "No successful results to compare"} + return cast( + ModelComparisonResult, {"error": "No successful results to compare"} + ) scores = [result.get_score(metric_name) for result in successful_results] model_names = [result.model_name for result in successful_results] @@ -178,22 +206,25 @@ def compare_models(self, metric_name: str = None) -> Dict[str, Any]: ] if not valid_scores: - return {"error": "No valid scores found"} + return cast(ModelComparisonResult, {"error": "No valid scores found"}) sorted_results = sorted(valid_scores, key=lambda x: x[1], reverse=True) - return { - "ranking": [ - {"model": name, "score": score} for name, score in sorted_results - ], - "best_model": sorted_results[0][0], - "best_score": sorted_results[0][1], - "worst_model": sorted_results[-1][0], - "worst_score": sorted_results[-1][1], - "mean_score": np.mean([score for _, score in valid_scores]), - "std_score": np.std([score for _, score in valid_scores]), - "total_models": len(valid_scores), - } + return cast( + ModelComparisonResult, + { + "ranking": [ + {"model": name, "score": score} for name, score in sorted_results + ], + "best_model": sorted_results[0][0], + "best_score": sorted_results[0][1], + "worst_model": sorted_results[-1][0], + "worst_score": sorted_results[-1][1], + "mean_score": float(np.mean([score for _, score in valid_scores])), + "std_score": float(np.std([score for _, score in valid_scores])), + "total_models": len(valid_scores), + }, + ) def get_model_result(self, model_name: str) -> Optional[EvaluationResult]: """Get result for a specific model.""" @@ -202,20 +233,24 @@ def get_model_result(self, model_name: str) -> Optional[EvaluationResult]: return result return None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> BenchmarkResultDict: """Convert benchmark result to dictionary format.""" - return { - "benchmark_name": self.benchmark_name, - "results": [result.to_dict() for result in self.results], - "metadata": self.metadata, - "timestamp": self.timestamp.isoformat(), - "summary": { - "total_models": len(self.results), - "successful_models": len(self.successful_results), - "failed_models": len(self.failed_results), - "success_rate": self.success_rate, - }, + summary: BenchmarkSummary = { + "total_models": len(self.results), + "successful_models": len(self.successful_results), + "failed_models": len(self.failed_results), + "success_rate": self.success_rate, } + return cast( + BenchmarkResultDict, + { + "benchmark_name": self.benchmark_name, + "results": [result.to_dict() for result in self.results], + "metadata": self.metadata, + "timestamp": self.timestamp.isoformat(), + "summary": summary, + }, + ) def to_dataframe(self) -> pd.DataFrame: """Convert results to pandas DataFrame for analysis.""" @@ -241,12 +276,12 @@ def to_dataframe(self) -> pd.DataFrame: return pd.DataFrame(data) - def save_to_json(self, file_path: Union[str, Path]): + def save_to_json(self, file_path: Union[str, Path]) -> None: """Save benchmark results to JSON file.""" with open(file_path, "w") as f: json.dump(self.to_dict(), f, indent=2, default=str) - def save_to_csv(self, file_path: Union[str, Path]): + def save_to_csv(self, file_path: Union[str, Path]) -> None: """Save benchmark results to CSV file.""" df = self.to_dataframe() df.to_csv(file_path, index=False) @@ -257,8 +292,8 @@ class ResultsAnalyzer: @staticmethod def compare_benchmarks( - benchmark_results: List[BenchmarkResult], metric_name: str = None - ) -> Dict[str, Any]: + benchmark_results: List[BenchmarkResult], metric_name: Optional[str] = None + ) -> CrossBenchmarkComparison: """ Compare results across multiple benchmarks. @@ -269,7 +304,12 @@ def compare_benchmarks( Returns: Dictionary with cross-benchmark comparison """ - comparison = {"benchmarks": [], "models": set(), "cross_benchmark_scores": {}} + comparison: CrossBenchmarkComparison = { + "benchmarks": [], + "models": [], + "cross_benchmark_scores": {}, + } + models_set: set[str] = set() for benchmark in benchmark_results: benchmark_info = { @@ -279,8 +319,10 @@ def compare_benchmarks( "success_rate": benchmark.success_rate, } - comparison["benchmarks"].append(benchmark_info) - comparison["models"].update(benchmark.model_names) + comparison["benchmarks"].append( + cast(BenchmarkComparisonInfo, benchmark_info) + ) + models_set.update(benchmark.model_names) # Collect scores for each model for result in benchmark.successful_results: @@ -294,14 +336,14 @@ def compare_benchmarks( benchmark.benchmark_name ] = score - comparison["models"] = list(comparison["models"]) + comparison["models"] = list(models_set) return comparison @staticmethod def analyze_model_performance( - results: List[EvaluationResult], metric_name: str = None - ) -> Dict[str, Any]: + results: List[EvaluationResult], metric_name: Optional[str] = None + ) -> ModelPerformanceAnalysis: """ Analyze performance of a single model across multiple evaluations. @@ -313,32 +355,38 @@ def analyze_model_performance( Dictionary with performance analysis """ if not results: - return {"error": "No results provided"} + return cast(ModelPerformanceAnalysis, {"error": "No results provided"}) model_name = results[0].model_name successful_results = [r for r in results if r.success] if not successful_results: - return {"error": "No successful results found"} + return cast( + ModelPerformanceAnalysis, {"error": "No successful results found"} + ) scores = [result.get_score(metric_name) for result in successful_results] valid_scores = [score for score in scores if score is not None] if not valid_scores: - return {"error": "No valid scores found"} - - return { - "model_name": model_name, - "total_evaluations": len(results), - "successful_evaluations": len(successful_results), - "success_rate": len(successful_results) / len(results), - "mean_score": np.mean(valid_scores), - "median_score": np.median(valid_scores), - "std_score": np.std(valid_scores), - "min_score": np.min(valid_scores), - "max_score": np.max(valid_scores), - "score_range": np.max(valid_scores) - np.min(valid_scores), - } + return cast(ModelPerformanceAnalysis, {"error": "No valid scores found"}) + + return cast( + ModelPerformanceAnalysis, + { + "model_name": model_name, + "total_evaluations": len(results), + "successful_evaluations": len(successful_results), + "failed_evaluations": len(results) - len(successful_results), + "success_rate": len(successful_results) / len(results), + "mean_score": float(np.mean(valid_scores)), + "median_score": float(np.median(valid_scores)), + "std_score": float(np.std(valid_scores)), + "min_score": float(np.min(valid_scores)), + "max_score": float(np.max(valid_scores)), + "scores": valid_scores, + }, + ) @staticmethod def generate_report( @@ -480,7 +528,7 @@ def _get_cache_key(self, model_name: str, test_name: str, dataset_hash: str) -> key_data = f"{model_name}_{test_name}_{dataset_hash}" return hashlib.md5(key_data.encode()).hexdigest() - def save_result(self, result: EvaluationResult, dataset_hash: str): + def save_result(self, result: EvaluationResult, dataset_hash: str) -> None: """Save evaluation result to cache.""" cache_key = self._get_cache_key( result.model_name, result.test_name, dataset_hash @@ -516,12 +564,12 @@ def load_result( except Exception: return None - def clear_cache(self): + def clear_cache(self) -> None: """Clear all cached results.""" for cache_file in self.cache_dir.glob("*.json"): cache_file.unlink() - def list_cached_results(self) -> List[Dict[str, Any]]: + def list_cached_results(self) -> List[CachedResultInfo]: """List all cached results.""" results = [] for cache_file in self.cache_dir.glob("*.json"): @@ -529,12 +577,15 @@ def list_cached_results(self) -> List[Dict[str, Any]]: with open(cache_file, "r") as f: data = json.load(f) results.append( - { - "file": cache_file.name, - "model_name": data.get("model_name"), - "test_name": data.get("test_name"), - "timestamp": data.get("timestamp"), - } + cast( + CachedResultInfo, + { + "file": cache_file.name, + "model_name": data.get("model_name"), + "test_name": data.get("test_name"), + "timestamp": data.get("timestamp"), + }, + ) ) except Exception: continue @@ -546,7 +597,7 @@ def list_cached_results(self) -> List[Dict[str, Any]]: def save_results( benchmark_result: BenchmarkResult, file_path: Union[str, Path], format: str = "json" -): +) -> None: """ Save benchmark results to file. diff --git a/benchwise/types.py b/benchwise/types.py new file mode 100644 index 0000000..fd24e5e --- /dev/null +++ b/benchwise/types.py @@ -0,0 +1,716 @@ +""" +Type definitions for BenchWise. + +This module contains TypedDict definitions, Protocols, Literal types, and type variables +used throughout the BenchWise codebase for improved type safety and IDE support. +""" + +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + TypeVar, + ParamSpec, + Tuple, + TypedDict, +) + +# Type Variables +T = TypeVar("T") +R = TypeVar("R") +P = ParamSpec("P") +ModelT = TypeVar("ModelT") +DatasetT = TypeVar("DatasetT") + +# Literal Types +HttpMethod = Literal["GET", "POST", "PUT", "DELETE", "PATCH"] +ModelProvider = Literal["openai", "anthropic", "google", "huggingface", "custom"] +ExportFormat = Literal["json", "csv", "markdown"] + + +# Model Configuration Types +class ModelConfig(TypedDict, total=False): + """Configuration options for model adapters.""" + + api_key: str + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + timeout: float + max_retries: int + + +class PricingInfo(TypedDict): + """Pricing information for a model.""" + + input: float # Cost per 1K input tokens + output: float # Cost per 1K output tokens + + +# Metric Return Types +class RougeScores(TypedDict, total=False): + """Return type for ROUGE metric scores.""" + + precision: float + recall: float + f1: float + rouge1_f1: float + rouge2_f1: float + rougeL_f1: float + std_precision: float + std_recall: float + std_f1: float + scores: Dict[str, List[float]] + # Optional confidence intervals + f1_confidence_interval: Tuple[float, float] + precision_confidence_interval: Tuple[float, float] + recall_confidence_interval: Tuple[float, float] + + +class BleuScores(TypedDict, total=False): + """Return type for BLEU metric scores.""" + + # Required fields + corpus_bleu: float + sentence_bleu: float + std_sentence_bleu: float + median_sentence_bleu: float + scores: List[float] + + # N-gram precision scores (dynamically added based on max_n) + bleu_1: float + bleu_1_std: float + bleu_2: float + bleu_2_std: float + bleu_3: float + bleu_3_std: float + bleu_4: float + bleu_4_std: float + + # Optional confidence interval + sentence_bleu_confidence_interval: Tuple[float, float] + + +class BertScoreResults(TypedDict, total=False): + """Return type for BERT-Score metric.""" + + # Main scores + precision: float + recall: float + f1: float + + # Standard deviations + std_precision: float + std_recall: float + std_f1: float + + # Additional statistics + min_f1: float + max_f1: float + median_f1: float + + # Metadata + model_used: str + + # Individual scores per sample + scores: Dict[str, List[float]] + + # Optional confidence intervals + f1_confidence_interval: Tuple[float, float] + precision_confidence_interval: Tuple[float, float] + recall_confidence_interval: Tuple[float, float] + + # Error field (when calculation fails) + error: str + + +class AccuracyResults(TypedDict, total=False): + """Return type for accuracy metric.""" + + # Main accuracy metrics + accuracy: float + exact_accuracy: float + fuzzy_accuracy: float + + # Counts + correct: int + correct_fuzzy: int + total: int + + # Statistical measures + mean_score: float + std_score: float + + # Individual scores and match information + individual_scores: List[float] + match_types: List[str] + + # Optional confidence interval + accuracy_confidence_interval: Tuple[float, float] + + +class SemanticSimilarityResults(TypedDict, total=False): + """Return type for semantic similarity metric.""" + + # Main similarity metrics + mean_similarity: float + median_similarity: float + std_similarity: float + min_similarity: float + max_similarity: float + + # Threshold-based metrics + similarity_above_threshold: float + + # Percentiles + percentile_25: float + percentile_75: float + percentile_90: float + + # Metadata + model_used: str + + # Individual scores + scores: List[float] + + # Optional confidence interval + similarity_confidence_interval: Tuple[float, float] + + +class PerplexityResults(TypedDict, total=False): + """Return type for perplexity metric.""" + + # Perplexity metrics + mean_perplexity: float + median_perplexity: float + + # Individual scores + scores: List[float] + + +class ComponentAnalysis(TypedDict, total=False): + """Component analysis for factual correctness.""" + + mean: float + std: float + scores: List[float] + + +class CoherenceResults(TypedDict, total=False): + """Return type for coherence score metric.""" + + # Main coherence metrics + mean_coherence: float + median_coherence: float + std_coherence: float + min_coherence: float + max_coherence: float + + # Individual scores + scores: List[float] + + # Optional detailed component analysis + components: Dict[str, ComponentAnalysis] + + # Optional confidence interval + coherence_confidence_interval: Tuple[float, float] + + +class SafetyCategoryScore(TypedDict, total=False): + """Per-category safety score analysis.""" + + mean: float + violation_rate: float + scores: List[float] + + +class SafetyResults(TypedDict, total=False): + """Return type for safety score metric.""" + + # Main safety metrics + mean_safety: float + median_safety: float + std_safety: float + min_safety: float + unsafe_count: int + + # Individual scores + scores: List[float] + + # Violation details per prediction + violation_details: List[List[str]] + + # Optional detailed category analysis + category_scores: Dict[str, SafetyCategoryScore] + + # Optional confidence interval + safety_confidence_interval: Tuple[float, float] + + +class DetailedFactualAnalysis(TypedDict, total=False): + """Detailed factual analysis for a single prediction-reference pair.""" + + entity_overlap: float + keyword_overlap: float + semantic_overlap: float + + +class FactualCorrectnessResults(TypedDict, total=False): + """Return type for factual correctness metric.""" + + # Main correctness metrics + mean_correctness: float + median_correctness: float + std_correctness: float + min_correctness: float + max_correctness: float + + # Individual scores + scores: List[float] + + # Optional detailed analysis + components: Dict[str, ComponentAnalysis] + detailed_results: List[DetailedFactualAnalysis] + + # Optional confidence interval + correctness_confidence_interval: Tuple[float, float] + + +# Dataset Types +class DatasetItem(TypedDict, total=False): + """A single item in a dataset.""" + + # Common field names + prompt: str + input: str + question: str + text: str + # Reference/target fields + reference: str + output: str + answer: str + target: str + summary: str + # Additional fields + id: str + metadata: "EvaluationMetadata" + + +class DatasetMetadata(TypedDict, total=False): + """Metadata for a dataset.""" + + name: str + description: str + source: str + version: str + size: int + created_at: str + tags: List[str] + + +class DatasetSchema(TypedDict, total=False): + """Schema definition for a dataset.""" + + prompt_field: str + reference_field: str + required: List[str] # Required fields in dataset items + required_fields: List[str] # Alias for backward compatibility + optional_fields: List[str] + + +class DatasetInfo(TypedDict, total=False): + """Information about a dataset used in evaluation.""" + + size: int + task: str + tags: List[str] + difficulty: Optional[str] + source: Optional[str] + name: Optional[str] + description: Optional[str] + version: Optional[str] + hash: Optional[str] + created_at: Optional[str] + + +class DatasetStatistics(TypedDict, total=False): + """Statistics about a dataset.""" + + size: int + fields: List[str] + metadata: Optional[DatasetMetadata] + + +class DatasetDict(TypedDict, total=False): + """Dictionary representation of a dataset.""" + + name: str + data: List[DatasetItem] + metadata: Optional[DatasetMetadata] + schema: Optional[DatasetSchema] + + +# Configuration Types +class ConfigDict(TypedDict, total=False): + """Configuration dictionary for BenchWise.""" + + api_url: str + api_key: Optional[str] + upload_enabled: bool + auto_sync: bool + cache_enabled: bool + cache_dir: str + timeout: float + max_retries: int + offline_mode: bool + debug: bool + verbose: bool + default_models: List[str] + default_metrics: List[str] + + +# Results Types +class EvaluationMetadata(TypedDict, total=False): + """Metadata for an evaluation result.""" + + temperature: float + max_tokens: int + model_version: str + dataset_hash: str + evaluation_id: Optional[int] + benchmark_id: Optional[int] + dataset: DatasetInfo # Dataset information for the evaluation + description: str # Description of the evaluation/benchmark + dataset_path: str # Path to the dataset file used in evaluation + models: List[str] # List of models evaluated + metrics: List[str] # List of metrics used in evaluation + # Allow additional metadata fields + # Note: This is intentionally flexible for user-defined metadata + + +class EvaluationResultDict(TypedDict, total=False): + """Serialized evaluation result.""" + + model_name: str + test_name: str + result: Any + duration: float + dataset_info: Optional[DatasetInfo] + error: Optional[str] + metadata: EvaluationMetadata + timestamp: str + success: bool + + +class BenchmarkSummary(TypedDict): + """Summary statistics for a benchmark.""" + + total_models: int + successful_models: int + failed_models: int + success_rate: float + + +class BenchmarkResultDict(TypedDict, total=False): + """Serialized benchmark result.""" + + benchmark_name: str + results: List[EvaluationResultDict] + metadata: EvaluationMetadata + timestamp: str + summary: BenchmarkSummary + + +class ModelRanking(TypedDict): + """Ranking entry for a model.""" + + model: str + score: float + + +class ModelComparisonResult(TypedDict, total=False): + """Result of model comparison.""" + + ranking: List[ModelRanking] + best_model: str + best_score: float + worst_model: str + worst_score: float + mean_score: float + std_score: float + total_models: int + error: Optional[str] + + +class ComparisonResult(TypedDict): + """Result of model comparison (legacy format).""" + + best_model: str + best_score: float + rankings: List[Tuple[str, float]] + scores: Dict[str, float] + + +# API Response Types +class TokenData(TypedDict, total=False): + """JWT token data from login.""" + + access_token: str + token_type: str + expires_in: Optional[int] + refresh_token: Optional[str] + + +class LoginResponse(TypedDict): + """Response from login endpoint.""" + + token: TokenData + user: "UserInfo" # Forward reference + + +class ModelInfo(TypedDict, total=False): + """Model information from API.""" + + id: int + name: str + provider: str + model_id: str # Provider-specific model identifier + description: Optional[str] + is_active: bool + pricing: Optional[PricingInfo] + metadata: Optional[EvaluationMetadata] + + +class BenchmarkRegistrationData(TypedDict, total=False): + """Data for registering a benchmark with the API.""" + + name: str + description: str + category: str + tags: List[str] + difficulty: Optional[str] + dataset_url: Optional[str] + config: Dict[str, Any] + metadata: DatasetInfo + is_public: bool + + +class BenchmarkInfo(TypedDict, total=False): + """Benchmark information from API.""" + + id: int + name: str + description: Optional[str] + category: Optional[str] + tags: List[str] + difficulty: Optional[str] + dataset_url: Optional[str] + config: Dict[str, Any] # API config can be arbitrary + metadata: Optional[DatasetInfo] + is_public: bool + created_at: Optional[str] + + +class EvaluationInfo(TypedDict, total=False): + """Evaluation information from API.""" + + id: int + benchmark_id: int + model_id: int + test_name: str + status: str + results: Optional[Dict[str, Any]] # Results can be arbitrary + metadata: Optional[EvaluationMetadata] + created_at: Optional[str] + + +class UserInfo(TypedDict, total=False): + """User information from API.""" + + id: int + username: str + email: str + full_name: Optional[str] + is_active: bool + + +class UploadBenchmarkResponse(TypedDict): + """Response from upload benchmark result endpoint.""" + + id: int + benchmark_id: int + model_ids: List[int] + results_count: int + message: str + + +class FileUploadResponse(TypedDict, total=False): + """Response from file upload endpoint.""" + + file_info: Dict[str, str] # Contains url and other file metadata + + +# Protocols +class SupportsGenerate(Protocol): + """Protocol for objects that support text generation.""" + + async def generate(self, prompts: List[str], **kwargs: Any) -> List[str]: + """Generate text completions for the given prompts.""" + ... + + def get_token_count(self, text: str) -> int: + """Get the token count for the given text.""" + ... + + def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float: + """Estimate the cost for the given token counts.""" + ... + + +class SupportsCache(Protocol): + """Protocol for objects that support caching.""" + + def save(self, key: str, value: Any) -> None: + """Save a value to the cache.""" + ... + + def load(self, key: str) -> Optional[Any]: + """Load a value from the cache.""" + ... + + def exists(self, key: str) -> bool: + """Check if a key exists in the cache.""" + ... + + +class SupportsMetrics(Protocol): + """Protocol for objects that support metric evaluation.""" + + def evaluate( + self, predictions: List[str], references: List[str], **kwargs: Any + ) -> Dict[str, float]: + """Evaluate predictions against references.""" + ... + + +class BenchmarkMetadataDict(TypedDict, total=False): + """Metadata attached to benchmark functions.""" + + name: str + description: str + + +class CallableWithBenchmarkMetadata(Protocol): + """Protocol for callables that may have benchmark metadata attached.""" + + _benchmark_metadata: Dict[str, Any] + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Call the function.""" + ... + + +class ConfigureArgs(Protocol): + """Arguments for configuring Benchwise.""" + + reset: bool + show: bool + api_url: str | None + api_key: str | None + upload: str | None + + +class SyncArgs(Protocol): + """Arguments for sync command.""" + + dry_run: bool + + +class StatusArgs(Protocol): + """Arguments for status command.""" + + api: bool + auth: bool + + +class ConfigKwargs(TypedDict, total=False): + """Kwargs for configure_benchwise function.""" + + api_url: str + api_key: str + upload_enabled: bool + + +class OfflineQueueItem(TypedDict): + """Item in offline queue.""" + + data: Dict[str, Any] # Can contain different operation types + timestamp: str + + +class RunnerConfig(TypedDict, total=False): + """Configuration for EvaluationRunner.""" + + cache_enabled: bool + upload_enabled: bool + timeout: float + max_retries: int + debug: bool + verbose: bool + + +class CacheEntry(TypedDict, total=False): + """Entry in results cache.""" + + result: EvaluationResultDict + dataset_hash: str + timestamp: str + + +class CachedResultInfo(TypedDict, total=False): + """Information about a cached result.""" + + file: str + model_name: Optional[str] + test_name: Optional[str] + timestamp: Optional[str] + dataset_hash: Optional[str] + + +class BenchmarkComparisonInfo(TypedDict, total=False): + """Information about a benchmark in cross-benchmark comparison.""" + + name: str + timestamp: str + models: List[str] + success_rate: float + + +class CrossBenchmarkComparison(TypedDict, total=False): + """Result of comparing multiple benchmarks.""" + + benchmarks: List[BenchmarkComparisonInfo] + models: List[str] + cross_benchmark_scores: Dict[str, Dict[str, Optional[float]]] + + +class ModelPerformanceAnalysis(TypedDict, total=False): + """Performance analysis for a single model.""" + + model_name: str + total_evaluations: int + successful_evaluations: int + failed_evaluations: int + success_rate: float + mean_score: float + std_score: float + min_score: float + max_score: float + median_score: float + scores: List[float] + error: Optional[str] diff --git a/demo.py b/demo.py index f7072e6..39342c4 100644 --- a/demo.py +++ b/demo.py @@ -1,5 +1,11 @@ import asyncio -from benchwise import evaluate, benchmark, create_qa_dataset, accuracy, semantic_similarity +from benchwise import ( + evaluate, + benchmark, + create_qa_dataset, + accuracy, + semantic_similarity, +) # Create your dataset qa_dataset = create_qa_dataset( @@ -8,18 +14,19 @@ "Who wrote '1984'?", "What is the speed of light?", "Explain photosynthesis in one sentence.", - "What causes rainbows?" + "What causes rainbows?", ], answers=[ "Tokyo", "George Orwell", "299,792,458 meters per second", "Photosynthesis is the process by which plants convert sunlight into energy.", - "Rainbows are caused by light refraction and reflection in water droplets." + "Rainbows are caused by light refraction and reflection in water droplets.", ], - name="general_knowledge_qa" + name="general_knowledge_qa", ) + @benchmark("General Knowledge QA", "Tests basic factual knowledge") @evaluate("gpt-3.5-turbo", "gemini-2.5-flash-lite") async def test_general_knowledge(model, dataset): @@ -31,9 +38,10 @@ async def test_general_knowledge(model, dataset): return { "accuracy": acc["accuracy"], "semantic_similarity": similarity["mean_similarity"], - "total_questions": len(responses) + "total_questions": len(responses), } + # Run the evaluation async def main(): results = await test_general_knowledge(qa_dataset) @@ -47,4 +55,5 @@ async def main(): else: print(f"{result.model_name}: FAILED - {result.error}") -asyncio.run(main()) \ No newline at end of file + +asyncio.run(main()) diff --git a/docs/test_load_dataset.py b/docs/test_load_dataset.py index b0ef5da..6e7b5f5 100644 --- a/docs/test_load_dataset.py +++ b/docs/test_load_dataset.py @@ -5,6 +5,7 @@ # Assuming data.json is in the same directory as this script for testing purposes data_file_path = "data.json" + def test_load_dataset_from_json(): # Load the dataset dataset = load_dataset(data_file_path) @@ -17,6 +18,7 @@ def test_load_dataset_from_json(): print("Successfully loaded dataset and assertions passed!") + if __name__ == "__main__": # Create a dummy data.json file for testing if it doesn't exist if not os.path.exists(data_file_path): diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..f311680 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,73 @@ +[mypy] +python_version = 3.12 +files = benchwise + +# Strict type checking +strict = True +disallow_untyped_defs = True +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +warn_unreachable = True +no_implicit_optional = True +strict_optional = True +strict_equality = True + +# Show error codes for easier suppression +show_error_codes = True + +# Third-party library ignores (no stubs available) +[mypy-rouge_score.*] +ignore_missing_imports = True + +[mypy-bert_score.*] +ignore_missing_imports = True + +[mypy-nltk.*] +ignore_missing_imports = True + +[mypy-transformers.*] +ignore_missing_imports = True +follow_imports = skip + +[mypy-torch.*] +ignore_missing_imports = True + +[mypy-sentence_transformers.*] +ignore_missing_imports = True + +[mypy-sklearn.*] +ignore_missing_imports = True + +[mypy-httpx.*] +ignore_missing_imports = True + +# Note: pandas and requests have type stubs installed (pandas-stubs, types-requests) + +[mypy-openai.*] +ignore_missing_imports = True + +[mypy-anthropic.*] +ignore_missing_imports = True + +[mypy-google.generativeai.*] +ignore_missing_imports = True +follow_imports = skip + +[mypy-fuzzywuzzy.*] +ignore_missing_imports = True + +[mypy-sacrebleu.*] +ignore_missing_imports = True + +[mypy-spacy.*] +ignore_missing_imports = True + +[mypy-google.*] +ignore_missing_imports = True diff --git a/mypy_baseline.txt b/mypy_baseline.txt new file mode 100644 index 0000000..9d9e1f7 Binary files /dev/null and b/mypy_baseline.txt differ diff --git a/pyproject.toml b/pyproject.toml index 0076abb..c280da3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,8 @@ lint = [ "ruff>=0.1.6", "pre-commit>=3.0.0", "mypy>=1.0.0", + "pandas-stubs>=2.0.0", + "types-requests>=2.28.0", ] dev = [ diff --git a/tests/test_config.py b/tests/test_config.py index a11b874..2cf169d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -106,9 +106,10 @@ def test_load_from_json_file(self): try: # Mock the config file paths - with patch.object(Path, "exists", return_value=True), patch( - "builtins.open", create=True - ) as mock_open: + with ( + patch.object(Path, "exists", return_value=True), + patch("builtins.open", create=True) as mock_open, + ): import json mock_open.return_value.__enter__.return_value.read.return_value = ( diff --git a/tests/test_docs_examples.py b/tests/test_docs_examples.py index d998cf2..4747fec 100644 --- a/tests/test_docs_examples.py +++ b/tests/test_docs_examples.py @@ -17,16 +17,16 @@ def extract_code_blocks_from_md(markdown_file: Path) -> List[tuple]: Extract all Python code blocks from a markdown file. Returns list of (code, block_number, line_number) tuples. """ - with open(markdown_file, 'r', encoding='utf-8') as f: + with open(markdown_file, "r", encoding="utf-8") as f: content = f.read() - pattern = r'```python\n(.*?)```' + pattern = r"```python\n(.*?)```" matches = re.finditer(pattern, content, re.DOTALL) code_blocks = [] for i, match in enumerate(matches, 1): code = match.group(1) - line_number = content[:match.start()].count('\n') + 1 + line_number = content[: match.start()].count("\n") + 1 code_blocks.append((code, i, line_number)) return code_blocks @@ -34,12 +34,12 @@ def extract_code_blocks_from_md(markdown_file: Path) -> List[tuple]: def get_doc_files() -> List[Path]: """Get all markdown documentation files with code examples.""" - docs_dir = Path(__file__).parent.parent / 'docs' / 'docs' / 'examples' + docs_dir = Path(__file__).parent.parent / "docs" / "docs" / "examples" if not docs_dir.exists(): return [] - return sorted(docs_dir.glob('*.md')) + return sorted(docs_dir.glob("*.md")) def prepare_code_for_testing(code: str) -> str: @@ -67,11 +67,11 @@ def prepare_code_for_testing(code: str) -> str: if 'load_dataset("data/' in modified_code: modified_code = modified_code.replace( 'load_dataset("data/qa_1000.json")', - 'create_qa_dataset(questions=["Q1?"], answers=["A1"], name="test")' + 'create_qa_dataset(questions=["Q1?"], answers=["A1"], name="test")', ) modified_code = modified_code.replace( 'load_dataset("data/news_articles.json")', - 'create_summarization_dataset(documents=["Doc1"], summaries=["Sum1"], name="news")' + 'create_summarization_dataset(documents=["Doc1"], summaries=["Sum1"], name="news")', ) return modified_code @@ -87,8 +87,11 @@ def prepare_code_for_testing(code: str) -> str: test_params.append((doc_file.name, block_num, line_num, code)) -@pytest.mark.parametrize("filename,block_num,line_num,code", test_params, - ids=[f"{f}:block_{b}:L{l}" for f, b, l, _ in test_params]) +@pytest.mark.parametrize( + "filename,block_num,line_num,code", + test_params, + ids=[f"{f}:block_{b}:L{line}" for f, b, line, _ in test_params], +) def test_documentation_code_syntax(filename, block_num, line_num, code): """ Test that all code examples in documentation have valid Python syntax. @@ -107,8 +110,11 @@ def test_documentation_code_syntax(filename, block_num, line_num, code): @pytest.mark.slow @pytest.mark.mock -@pytest.mark.parametrize("filename,block_num,line_num,code", test_params, - ids=[f"{f}:block_{b}:L{l}" for f, b, l, _ in test_params]) +@pytest.mark.parametrize( + "filename,block_num,line_num,code", + test_params, + ids=[f"{f}:block_{b}:L{line}" for f, b, line, _ in test_params], +) def test_documentation_code_execution(filename, block_num, line_num, code): """ Test that code examples can be executed without errors (using mock models). @@ -117,11 +123,13 @@ def test_documentation_code_execution(filename, block_num, line_num, code): and will be skipped. """ # Skip examples that are just function definitions without execution - if '@evaluate(' in code and 'asyncio.run' not in code: + if "@evaluate(" in code and "asyncio.run" not in code: pytest.skip("Incomplete example (defines functions only)") # Skip examples that require external data files - if 'load_dataset("data/' in code and 'create_' not in prepare_code_for_testing(code): + if 'load_dataset("data/' in code and "create_" not in prepare_code_for_testing( + code + ): pytest.skip("Requires external data files") # Prepare code with mock models @@ -129,7 +137,7 @@ def test_documentation_code_execution(filename, block_num, line_num, code): # Execute the code try: - exec_globals = {'__name__': '__main__'} + exec_globals = {"__name__": "__main__"} exec(prepared_code, exec_globals) except Exception as e: pytest.fail( @@ -141,11 +149,11 @@ def test_documentation_code_execution(filename, block_num, line_num, code): @pytest.mark.smoke def test_documentation_examples_exist(): """Verify that documentation example files exist and contain code blocks.""" - docs_dir = Path(__file__).parent.parent / 'docs' / 'docs' / 'examples' + docs_dir = Path(__file__).parent.parent / "docs" / "docs" / "examples" assert docs_dir.exists(), f"Documentation examples directory not found: {docs_dir}" - doc_files = list(docs_dir.glob('*.md')) + doc_files = list(docs_dir.glob("*.md")) assert len(doc_files) > 0, "No documentation markdown files found" total_blocks = 0 @@ -154,9 +162,11 @@ def test_documentation_examples_exist(): total_blocks += len(blocks) assert total_blocks > 0, "No Python code blocks found in documentation" - print(f"\nFound {len(doc_files)} documentation files with {total_blocks} code blocks") + print( + f"\nFound {len(doc_files)} documentation files with {total_blocks} code blocks" + ) -if __name__ == '__main__': +if __name__ == "__main__": # Run just the smoke test - pytest.main([__file__, '-k', 'test_documentation_examples_exist', '-v']) + pytest.main([__file__, "-k", "test_documentation_examples_exist", "-v"]) diff --git a/tests/test_integration.py b/tests/test_integration.py index f88e8be..a64f87b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -109,8 +109,9 @@ def test_model_factory_integration(self): assert adapter.__class__.__name__ == expected_type assert adapter.model_name == model_name - with patch("transformers.AutoTokenizer"), patch( - "transformers.AutoModelForCausalLM" + with ( + patch("transformers.AutoTokenizer"), + patch("transformers.AutoModelForCausalLM"), ): adapter = get_model_adapter("test/unknown-model") assert adapter.__class__.__name__ == "HuggingFaceAdapter" diff --git a/tests/test_memory_large_datasets.py b/tests/test_memory_large_datasets.py index 2a896fb..b5142bd 100644 --- a/tests/test_memory_large_datasets.py +++ b/tests/test_memory_large_datasets.py @@ -29,9 +29,9 @@ async def test_large_dataset_memory_usage(self): current_memory = self.get_memory_usage() memory_increase = current_memory - initial_memory - assert ( - memory_increase < 100 - ), f"Memory usage too high: {memory_increase}MB for {size} items" + assert memory_increase < 100, ( + f"Memory usage too high: {memory_increase}MB for {size} items" + ) sampled = dataset.sample(100) filtered = dataset.filter(lambda x: len(x["question"]) > 10) @@ -57,9 +57,9 @@ async def memory_test_evaluation(model, dataset): generation_memory = after_generation - before_generation # Memory increase should be reasonable - assert ( - generation_memory < 50 - ), f"Generation used too much memory: {generation_memory}MB" + assert generation_memory < 50, ( + f"Generation used too much memory: {generation_memory}MB" + ) return {"response_count": len(responses), "memory_used": generation_memory} @@ -69,9 +69,9 @@ async def memory_test_evaluation(model, dataset): total_memory_increase = final_memory - initial_memory # Total memory increase should be reasonable - assert ( - total_memory_increase < 100 - ), f"Total memory increase too high: {total_memory_increase}MB" + assert total_memory_increase < 100, ( + f"Total memory increase too high: {total_memory_increase}MB" + ) assert len(results) == 1 assert results[0].success @@ -100,9 +100,9 @@ async def test_dataset_chunking_memory_efficiency(self): # Memory shouldn't grow significantly per chunk current_memory = self.get_memory_usage() memory_per_chunk = (current_memory - initial_memory) / processed_chunks - assert ( - memory_per_chunk < 10 - ), f"Memory per chunk too high: {memory_per_chunk}MB" + assert memory_per_chunk < 10, ( + f"Memory per chunk too high: {memory_per_chunk}MB" + ) del chunk_dataset, chunk_data, prompts gc.collect() @@ -135,9 +135,9 @@ def dataset_generator(size): memory_used = current_memory - initial_memory max_memory_used = max(max_memory_used, memory_used) - assert ( - memory_used < 50 - ), f"Streaming memory too high: {memory_used}MB at {processed_items} items" + assert memory_used < 50, ( + f"Streaming memory too high: {memory_used}MB at {processed_items} items" + ) assert processed_items == 5000 assert max_memory_used < 50, f"Max memory usage too high: {max_memory_used}MB" @@ -164,9 +164,9 @@ async def cleanup_test(model, dataset): # Memory should return close to baseline current_memory = self.get_memory_usage() memory_diff = current_memory - baseline_memory - assert ( - memory_diff < 30 - ), f"Memory not cleaned up properly: {memory_diff}MB after iteration {i}" + assert memory_diff < 30, ( + f"Memory not cleaned up properly: {memory_diff}MB after iteration {i}" + ) async def test_large_dataset_file_operations(self, tmp_path): initial_memory = self.get_memory_usage() @@ -184,9 +184,9 @@ async def test_large_dataset_file_operations(self, tmp_path): # Memory shouldn't increase significantly during file operations after_save_memory = self.get_memory_usage() save_memory_increase = after_save_memory - initial_memory - assert ( - save_memory_increase < 100 - ), f"Save operation used too much memory: {save_memory_increase}MB" + assert save_memory_increase < 100, ( + f"Save operation used too much memory: {save_memory_increase}MB" + ) # Test loading from file del large_dataset @@ -198,9 +198,9 @@ async def test_large_dataset_file_operations(self, tmp_path): # Memory after loading should be reasonable after_load_memory = self.get_memory_usage() load_memory_increase = after_load_memory - initial_memory - assert ( - load_memory_increase < 150 - ), f"Load operation used too much memory: {load_memory_increase}MB" + assert load_memory_increase < 150, ( + f"Load operation used too much memory: {load_memory_increase}MB" + ) # Verify file sizes are reasonable json_size = json_file.stat().st_size / 1024 / 1024 # MB diff --git a/tests/test_models.py b/tests/test_models.py index 53fd01c..bdaf260 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -59,9 +59,10 @@ def test_mock_cost_estimate(self): class TestGetModelAdapter: def test_get_gpt_adapter(self): - adapter = get_model_adapter("gpt-3.5-turbo") - assert isinstance(adapter, OpenAIAdapter) - assert adapter.model_name == "gpt-3.5-turbo" + with patch("openai.AsyncOpenAI"): + adapter = get_model_adapter("gpt-3.5-turbo") + assert isinstance(adapter, OpenAIAdapter) + assert adapter.model_name == "gpt-3.5-turbo" def test_get_claude_adapter(self): adapter = get_model_adapter("claude-3-haiku") @@ -80,8 +81,9 @@ def test_get_mock_adapter(self): def test_get_huggingface_adapter_default(self): # Use a mock model name that won't trigger real HuggingFace download - with patch("transformers.AutoTokenizer"), patch( - "transformers.AutoModelForCausalLM" + with ( + patch("transformers.AutoTokenizer"), + patch("transformers.AutoModelForCausalLM"), ): adapter = get_model_adapter("test/unknown-model-name") assert isinstance(adapter, HuggingFaceAdapter) @@ -197,15 +199,17 @@ def test_google_import_error(self): class TestHuggingFaceAdapter: def test_huggingface_adapter_creation(self): - with patch("transformers.AutoTokenizer"), patch( - "transformers.AutoModelForCausalLM" + with ( + patch("transformers.AutoTokenizer"), + patch("transformers.AutoModelForCausalLM"), ): adapter = HuggingFaceAdapter("gpt2") assert adapter.model_name == "gpt2" def test_huggingface_cost_estimate(self): - with patch("transformers.AutoTokenizer"), patch( - "transformers.AutoModelForCausalLM" + with ( + patch("transformers.AutoTokenizer"), + patch("transformers.AutoModelForCausalLM"), ): adapter = HuggingFaceAdapter("gpt2") cost = adapter.get_cost_estimate(1000, 500) @@ -239,10 +243,11 @@ class TestModelNaming: def test_gpt_variants(self): models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o"] - for model in models: - adapter = get_model_adapter(model) - assert isinstance(adapter, OpenAIAdapter) - assert adapter.model_name == model + with patch("openai.AsyncOpenAI"): + for model in models: + adapter = get_model_adapter(model) + assert isinstance(adapter, OpenAIAdapter) + assert adapter.model_name == model def test_claude_variants(self): models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"] diff --git a/tests/test_results.py b/tests/test_results.py index 8a63d84..6646f62 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -184,9 +184,9 @@ def test_cache_save_and_load(self, temp_cache_dir, sample_evaluation_result): cache_files = list(Path(temp_cache_dir).glob("*.json")) assert len(cache_files) > 0, f"No cache files created in {temp_cache_dir}" - assert ( - loaded is not None - ), f"Failed to load cached result. Cache files: {cache_files}" + assert loaded is not None, ( + f"Failed to load cached result. Cache files: {cache_files}" + ) assert loaded.model_name == sample_evaluation_result.model_name assert loaded.test_name == sample_evaluation_result.test_name