Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ jobs:

- name: Run linting
run: |
uv run black --check src/
uv run isort --check-only src/
uv run ruff check src/

build:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -79,4 +78,5 @@ jobs:
path: dist/

- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@release/v1

2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ wheels/
.venv

.DS_Store

uv.lock
12 changes: 4 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ dependencies = [
"numpy>=1.20.0",
"pydantic>=2.0.0",
"tqdm>=4.0.0",
"pytest-asyncio>=0.24.0",
"ruff>=0.14.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -65,8 +67,6 @@ all = [
[dependency-groups]
dev = [
"pytest>=7.0",
"black>=22.0",
"isort>=5.0",
"mypy>=1.0",
"pytest-cov>=4.0",
"sphinx>=7.1.2",
Expand All @@ -92,13 +92,9 @@ path = "src/chatan/__init__.py"
[tool.hatch.build.targets.wheel]
packages = ["src/chatan"]

[tool.black]
[tool.ruff]
line-length = 88
target-version = ['py38']

[tool.isort]
profile = "black"
multi_line_output = 3
target-version = "py38"

[tool.mypy]
python_version = "3.8"
Expand Down
12 changes: 10 additions & 2 deletions src/chatan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@

from .dataset import dataset
from .evaluate import eval, evaluate
from .generator import generator
from .generator import generator, async_generator
from .sampler import sample
from .viewer import generate_with_viewer

__all__ = ["dataset", "generator", "sample", "generate_with_viewer", "evaluate", "eval"]
__all__ = [
"dataset",
"generator",
"async_generator",
"sample",
"generate_with_viewer",
"evaluate",
"eval",
]
2 changes: 1 addition & 1 deletion src/chatan/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Dataset creation and manipulation."""

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from datasets import Dataset as HFDataset
Expand Down
198 changes: 194 additions & 4 deletions src/chatan/generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""LLM generators with CPU fallback and aggressive memory management."""
"""LLM generators with CPU fallback, async support, and aggressive memory management."""

import asyncio
import gc
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterable, Optional

import anthropic
import openai
Expand All @@ -26,6 +26,15 @@ def generate(self, prompt: str, **kwargs) -> str:
pass


class AsyncBaseGenerator(ABC):
    """Abstract interface shared by all asynchronous LLM generators.

    Concrete subclasses wrap a provider-specific async client and must
    override :meth:`generate`.
    """

    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> str:
        """Produce text for ``prompt`` without blocking the event loop."""
        ...

class OpenAIGenerator(BaseGenerator):
"""OpenAI GPT generator."""

Expand All @@ -46,6 +55,33 @@ def generate(self, prompt: str, **kwargs) -> str:
return response.choices[0].message.content.strip()


class AsyncOpenAIGenerator(AsyncBaseGenerator):
    """Asynchronous generator backed by OpenAI's chat-completions API."""

    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs):
        # Older `openai` releases ship no AsyncOpenAI class; fail loudly at
        # construction time rather than at the first generate() call.
        client_cls = getattr(openai, "AsyncOpenAI", None)
        if client_cls is None:
            raise ImportError(
                "Async OpenAI client is not available. Upgrade the `openai` package "
                "to a version that provides `AsyncOpenAI`."
            )

        self.client = client_cls(api_key=api_key)
        self.model = model
        # Per-call kwargs are layered on top of these defaults in generate().
        self.default_kwargs = kwargs

    async def generate(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text."""
        call_kwargs = dict(self.default_kwargs)
        call_kwargs.update(kwargs)

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            **call_kwargs,
        )
        return response.choices[0].message.content.strip()

class AnthropicGenerator(BaseGenerator):
"""Anthropic Claude generator."""

Expand All @@ -67,6 +103,39 @@ def generate(self, prompt: str, **kwargs) -> str:
return response.content[0].text.strip()


class AsyncAnthropicGenerator(AsyncBaseGenerator):
    """Asynchronous generator backed by Anthropic's messages API."""

    def __init__(
        self,
        api_key: str,
        model: str = "claude-3-sonnet-20240229",
        **kwargs,
    ):
        # Guard against `anthropic` versions that predate the async client.
        client_cls = getattr(anthropic, "AsyncAnthropic", None)
        if client_cls is None:
            raise ImportError(
                "Async Anthropic client is not available. Upgrade the `anthropic` package "
                "to a version that provides `AsyncAnthropic`."
            )

        self.client = client_cls(api_key=api_key)
        self.model = model
        # Per-call kwargs are layered on top of these defaults in generate().
        self.default_kwargs = kwargs

    async def generate(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text."""
        call_kwargs = dict(self.default_kwargs)
        call_kwargs.update(kwargs)

        # The Anthropic API requires max_tokens; default to 1000 when the
        # caller did not specify one.
        response = await self.client.messages.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=call_kwargs.pop("max_tokens", 1000),
            **call_kwargs,
        )
        return response.content[0].text.strip()


class TransformersGenerator(BaseGenerator):
"""Local HuggingFace/transformers generator with aggressive memory management."""

Expand Down Expand Up @@ -174,7 +243,8 @@ def __del__(self):
if hasattr(self, "tokenizer") and self.tokenizer is not None:
del self.tokenizer
self._clear_cache()
except:
except Exception as e:
print(f"Cleanup failed with exception: {e}")
pass


Expand Down Expand Up @@ -202,6 +272,89 @@ def __call__(self, context: Dict[str, Any]) -> str:
return result.strip() if isinstance(result, str) else result


class AsyncGeneratorFunction:
    """Callable async generator function with high concurrency support.

    Wraps an :class:`AsyncBaseGenerator` together with a prompt template and
    optional per-column variables; ``stream`` fans many contexts out to the
    backend with bounded concurrency while yielding results in input order.
    """

    def __init__(
        self,
        generator: AsyncBaseGenerator,
        prompt_template: str,
        variables: Optional[Dict[str, Any]] = None,
    ):
        # NOTE: a variable's value may itself be a callable; it is invoked
        # with the per-row context at call time (see __call__).
        self.generator = generator
        self.prompt_template = prompt_template
        self.variables = variables or {}

    async def __call__(self, context: Dict[str, Any], **kwargs) -> str:
        """Generate content with context substitution asynchronously."""
        # Static variables override context entries of the same name;
        # callables are evaluated lazily against the current context.
        merged = dict(context)
        for key, value in self.variables.items():
            merged[key] = value(context) if callable(value) else value

        prompt = self.prompt_template.format(**merged)
        result = await self.generator.generate(prompt, **kwargs)
        # Backends may return non-string payloads; only strip real strings.
        return result.strip() if isinstance(result, str) else result

    async def stream(
        self,
        contexts: Iterable[Dict[str, Any]],
        *,
        concurrency: int = 5,
        return_exceptions: bool = False,
        **kwargs,
    ):
        """Asynchronously yield results for many contexts with bounded concurrency.

        Results are yielded in the order of ``contexts`` (not completion
        order).  If ``return_exceptions`` is True, a failed context yields its
        exception object in place of a result; otherwise the first failure
        cancels all outstanding work and is re-raised.
        """

        if concurrency < 1:
            raise ValueError("concurrency must be at least 1")

        # Materialize so the input can be indexed and counted once.
        contexts_list = list(contexts)
        if not contexts_list:
            return

        # The semaphore caps how many generate() calls run at the same time.
        semaphore = asyncio.Semaphore(concurrency)

        async def worker(index: int, ctx: Dict[str, Any]):
            # Returns (index, result, error) so the consumer below can
            # reassemble input order and decide how to surface failures.
            async with semaphore:
                try:
                    result = await self(ctx, **kwargs)
                    return index, result, None
                except Exception as exc:  # pragma: no cover - exercised via return_exceptions
                    return index, None, exc

        tasks = [
            asyncio.create_task(worker(index, ctx))
            for index, ctx in enumerate(contexts_list)
        ]

        # Completed results arrive out of order; buffer them until the next
        # expected index is available so callers see input order.
        next_index = 0
        buffer: Dict[int, Any] = {}

        try:
            for coro in asyncio.as_completed(tasks):
                index, value, error = await coro

                if error is not None and not return_exceptions:
                    # Fail fast: cancel everything still running, wait for the
                    # cancellations to settle, then propagate the error.
                    for task in tasks:
                        if not task.done():
                            task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)
                    raise error

                buffer[index] = error if error is not None else value

                # Drain every contiguous ready result starting at next_index.
                while next_index in buffer:
                    item = buffer.pop(next_index)
                    next_index += 1
                    yield item

        finally:
            # Also runs when the consumer abandons the stream early: make
            # sure no orphaned tasks keep running in the background.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)


class GeneratorClient:
"""Main interface for creating generators."""

Expand Down Expand Up @@ -251,6 +404,34 @@ def __call__(self, prompt_template: str, **variables) -> GeneratorFunction:
return GeneratorFunction(self._generator, prompt_template, variables)


class AsyncGeneratorClient:
    """Factory facade that binds prompt templates to an async backend."""

    def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs):
        name = provider.lower()
        try:
            if name == "openai":
                if api_key is None:
                    raise ValueError("API key is required for OpenAI")
                backend = AsyncOpenAIGenerator(api_key, **kwargs)
            elif name == "anthropic":
                if api_key is None:
                    raise ValueError("API key is required for Anthropic")
                backend = AsyncAnthropicGenerator(api_key, **kwargs)
            else:
                raise ValueError(f"Unsupported provider for async generator: {provider}")
        except Exception as e:
            # Re-wrap every setup failure so callers catch one error type.
            raise ValueError(
                f"Failed to initialize async generator for provider '{provider}'. "
                f"Check your configuration and try again. Original error: {str(e)}"
            ) from e
        self._generator = backend

    def __call__(self, prompt_template: str, **variables) -> AsyncGeneratorFunction:
        """Create an async generator function for ``prompt_template``."""
        return AsyncGeneratorFunction(self._generator, prompt_template, variables)


# Factory function
def generator(
provider: str = "openai", api_key: Optional[str] = None, **kwargs
Expand All @@ -259,3 +440,12 @@ def generator(
if provider.lower() in {"openai", "anthropic"} and api_key is None:
raise ValueError("API key is required")
return GeneratorClient(provider, api_key, **kwargs)


def async_generator(
    provider: str = "openai", api_key: Optional[str] = None, **kwargs
) -> AsyncGeneratorClient:
    """Create an async generator client.

    Args:
        provider: Backend name; "openai" and "anthropic" are supported.
        api_key: Provider credential; mandatory for the hosted providers.
        **kwargs: Extra options forwarded to the underlying client.

    Raises:
        ValueError: If a hosted provider is requested without an API key.
    """
    needs_key = provider.lower() in {"openai", "anthropic"}
    if needs_key and api_key is None:
        raise ValueError("API key is required")
    return AsyncGeneratorClient(provider, api_key, **kwargs)
2 changes: 1 addition & 1 deletion src/chatan/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import uuid
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from datasets import Dataset as HFDataset
Expand Down
12 changes: 8 additions & 4 deletions src/chatan/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def add_row(self, row: Dict[str, Any]):
try:
with open(self.data_file, "r") as f:
data = json.load(f)
except:
except Exception as e:
print(f"Add row exception: {e}")
data = {"rows": [], "completed": False, "current_row": None}

data["rows"].append(row)
Expand All @@ -79,7 +80,8 @@ def start_row(self, row_index: int):
try:
with open(self.data_file, "r") as f:
data = json.load(f)
except:
except Exception as e:
print(f"Start row exception: {e}")
data = {"rows": [], "completed": False, "current_row": None}

data["current_row"] = {"index": row_index, "cells": {}}
Expand All @@ -95,7 +97,8 @@ def update_cell(self, column: str, value: Any):
try:
with open(self.data_file, "r") as f:
data = json.load(f)
except:
except Exception as e:
print(f"Update cell exception: {e}")
data = {"rows": [], "completed": False, "current_row": None}

if data.get("current_row"):
Expand All @@ -112,7 +115,8 @@ def complete(self):
try:
with open(self.data_file, "r") as f:
data = json.load(f)
except:
except Exception as e:
print(f"Complete exception: {e}")
data = {"rows": [], "completed": False, "current_row": None}

data["completed"] = True
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
import tempfile
import os
from unittest.mock import Mock, patch
from unittest.mock import Mock
from datasets import Dataset as HFDataset

from chatan.dataset import Dataset, dataset
Expand Down
Loading