From 4a6184e78f5336fe7aa4967f26cf4f36ffd8823e Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Sat, 22 Feb 2025 19:44:56 +0000 Subject: [PATCH 1/4] docs: Add comprehensive client guide and update server documentation - Add detailed client guide with examples in multiple languages - Add performance considerations and best practices - Add detailed configuration documentation - Add error handling documentation --- sparse_server/README.md | 68 ++++++++- sparse_server/docs/client_guide.md | 216 +++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 2 deletions(-) create mode 100644 sparse_server/docs/client_guide.md diff --git a/sparse_server/README.md b/sparse_server/README.md index be1475b..ddbb079 100644 --- a/sparse_server/README.md +++ b/sparse_server/README.md @@ -68,6 +68,8 @@ curl http://localhost:52309/v1/chat/completions \ ## Configuration +### Environment Variables + Server settings can be configured through environment variables or a `.env` file: ```env @@ -76,8 +78,70 @@ API_VERSION="1.0.0" HOST="0.0.0.0" PORT=52309 DEFAULT_MODEL="meta-llama/Llama-2-7b-chat-hf" -HEAVY_CONST=128 -GROUP_FACTOR=4 +HEAVY_CONST=128 # Sparse attention parameter +GROUP_FACTOR=4 # Sparse attention parameter +``` + +### Sparse Attention Parameters + +The server uses DoubleSparse's efficient attention mechanism with two key parameters: + +1. `HEAVY_CONST` (Token Sparsity): + - Controls how many tokens are kept for attention computation + - Higher values = more tokens = more accuracy but slower + - Lower values = fewer tokens = faster but potentially less accurate + - Default: 128 + +2. `GROUP_FACTOR` (Channel Sparsity): + - Controls channel grouping for attention computation + - Higher values = more sparsity = faster but potentially less accurate + - Lower values = less sparsity = more accurate but slower + - Default: 4 + +## Client Integration + +For detailed instructions on how to use the API from different programming languages and frameworks, see our [Client Guide](docs/client_guide.md). The guide includes: + +- Python examples using OpenAI's client library +- Python examples using standard libraries +- JavaScript/TypeScript examples +- cURL examples +- Error handling guidelines +- Best practices + +## Performance Considerations + +1. **Memory Usage**: + - The server uses DoubleSparse's efficient attention mechanism + - Memory usage scales with `HEAVY_CONST` and `GROUP_FACTOR` + - Monitor GPU memory usage to optimize these parameters + +2. **Throughput**: + - Higher `HEAVY_CONST` = lower throughput + - Higher `GROUP_FACTOR` = higher throughput + - Find the right balance for your use case + +3. **Latency**: + - Use streaming for better perceived latency + - First token latency depends on prompt length + - Subsequent tokens benefit from sparse attention + +## Error Handling + +The server implements comprehensive error handling: +- Invalid requests return 400 with details +- Server errors return 500 with stack traces +- Model loading issues return 503 + +Error responses follow this format: +```json +{ + "error": { + "message": "Error description", + "type": "error_type", + "code": 400 + } +} ``` ## License diff --git a/sparse_server/docs/client_guide.md b/sparse_server/docs/client_guide.md new file mode 100644 index 0000000..096c85c --- /dev/null +++ b/sparse_server/docs/client_guide.md @@ -0,0 +1,216 @@ +# DoubleSparse API Client Guide + +This guide explains how to interact with the DoubleSparse API from different clients and programming languages. 
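+
+Before working through the language-specific examples below, you can confirm the server is reachable with a quick request to the models endpoint. This is a minimal sketch that assumes the server is running on the default host and port from the README (`http://localhost:52309`); no API key is required.
+
+```python
+import requests
+
+# List the models served by the DoubleSparse server.
+resp = requests.get("http://localhost:52309/v1/models", timeout=10)
+resp.raise_for_status()
+for model in resp.json()["data"]:
+    print(model["id"])
+```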
+ +## API Endpoints + +The API follows OpenAI's format and provides these main endpoints: +- `GET /v1/models` - List available models +- `POST /v1/chat/completions` - Create chat completions + +## Python Examples + +### Using OpenAI's Client Library + +```python +from openai import OpenAI + +# Initialize client +client = OpenAI( + base_url="http://localhost:52309/v1", # Your server URL + api_key="not-needed" # API key isn't used but required by the client +) + +# Simple completion +response = client.chat.completions.create( + model="meta-llama/Llama-2-7b-chat-hf", + messages=[ + {"role": "user", "content": "Hello! How are you?"} + ], + temperature=0.7, + max_tokens=100 +) +print(response.choices[0].message.content) + +# Streaming completion +stream = client.chat.completions.create( + model="meta-llama/Llama-2-7b-chat-hf", + messages=[ + {"role": "user", "content": "Write a story about a robot."} + ], + temperature=0.7, + max_tokens=100, + stream=True +) +for chunk in stream: + if chunk.choices[0].delta.content is not None: + print(chunk.choices[0].delta.content, end="") +``` + +### Using Standard Python Libraries + +```python +import requests +import json +import sseclient + +def generate_text(prompt, stream=False): + url = "http://localhost:52309/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": "meta-llama/Llama-2-7b-chat-hf", + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.7, + "max_tokens": 100, + "stream": stream + } + + if not stream: + response = requests.post(url, headers=headers, json=data) + return response.json()["choices"][0]["message"]["content"] + else: + response = requests.post(url, headers=headers, json=data, stream=True) + client = sseclient.SSEClient(response) + for event in client.events(): + if event.data != "[DONE]": + chunk = json.loads(event.data) + if chunk["choices"][0]["delta"].get("content"): + yield chunk["choices"][0]["delta"]["content"] + +# Regular completion +text = generate_text("Hello! How are you?") +print(text) + +# Streaming completion +for chunk in generate_text("Write a story about a robot.", stream=True): + print(chunk, end="") +``` + +## cURL Examples + +### List Available Models +```bash +curl http://localhost:52309/v1/models +``` + +### Generate Completion +```bash +curl http://localhost:52309/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "messages": [{"role": "user", "content": "Hello! 
How are you?"}], + "temperature": 0.7, + "max_tokens": 100 + }' +``` + +### Stream Completion +```bash +curl http://localhost:52309/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "messages": [{"role": "user", "content": "Write a story about a robot."}], + "temperature": 0.7, + "max_tokens": 100, + "stream": true + }' +``` + +## JavaScript/TypeScript Example + +```typescript +async function generateText(prompt: string, stream = false) { + const response = await fetch('http://localhost:52309/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: 'meta-llama/Llama-2-7b-chat-hf', + messages: [{role: 'user', content: prompt}], + temperature: 0.7, + max_tokens: 100, + stream: stream + }) + }); + + if (!stream) { + const data = await response.json(); + return data.choices[0].message.content; + } else { + const reader = response.body?.getReader(); + const decoder = new TextDecoder(); + + while (reader) { + const {value, done} = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + const lines = chunk.split('\n'); + + for (const line of lines) { + if (line.startsWith('data: ')) { + const data = line.slice(6); + if (data === '[DONE]') break; + + try { + const parsed = JSON.parse(data); + const content = parsed.choices[0].delta.content; + if (content) { + yield content; + } + } catch (e) { + console.error('Error parsing chunk:', e); + } + } + } + } + } +} + +// Regular completion +const text = await generateText("Hello! How are you?"); +console.log(text); + +// Streaming completion +for await (const chunk of generateText("Write a story about a robot.", true)) { + process.stdout.write(chunk); +} +``` + +## Error Handling + +The API returns standard HTTP status codes: +- 200: Success +- 400: Bad Request (invalid parameters) +- 500: Server Error +- 503: Service Unavailable (model not loaded) + +Error responses follow this format: +```json +{ + "error": { + "message": "Error description", + "type": "error_type", + "code": 400 + } +} +``` + +## Configuration Parameters + +You can configure these parameters in your requests: +- `temperature` (0-1): Controls randomness. Higher values make output more random. +- `max_tokens` (int): Maximum number of tokens to generate. +- `top_p` (0-1): Nucleus sampling parameter. +- `stream` (boolean): Whether to stream the response. + +## Best Practices + +1. **Error Handling**: Always implement proper error handling in your client code. +2. **Streaming**: Use streaming for long responses to get faster initial output. +3. **Connection Management**: Implement retry logic for failed requests. +4. **Resource Management**: Close connections and streams properly. +5. **Rate Limiting**: Consider implementing rate limiting in your client code. 
\ No newline at end of file From 2c4362a2711376fa149b7296ea7db9436ba0b5d9 Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Sat, 22 Feb 2025 19:48:35 +0000 Subject: [PATCH 2/4] feat: Properly implement sparse attention in API server - Replace basic model with proper sparse attention implementation - Add support for different model architectures (LLaMA, Mistral) - Add channel configuration support - Update API to be fully OpenAI-compatible - Add proper token streaming implementation - Add configuration for sparse attention parameters --- sparse_server/app/config.py | 8 +- sparse_server/app/main.py | 289 +++++++++++++++++--------- sparse_server/models/model_manager.py | 165 +++++++++++---- 3 files changed, 329 insertions(+), 133 deletions(-) diff --git a/sparse_server/app/config.py b/sparse_server/app/config.py index 19175f6..ee9ec51 100644 --- a/sparse_server/app/config.py +++ b/sparse_server/app/config.py @@ -20,8 +20,12 @@ class Settings(BaseSettings): # Model Settings DEFAULT_MODEL: str = "meta-llama/Llama-2-7b-chat-hf" - HEAVY_CONST: int = 128 # Sparse attention parameter - GROUP_FACTOR: int = 4 # Sparse attention parameter + MODEL_ARCHITECTURE: str = "llama" # Model architecture (llama, mistral) + + # Sparse Attention Settings + HEAVY_CONST: int = 128 # Number of tokens to keep for attention + GROUP_FACTOR: int = 4 # Channel grouping factor + CHANNEL: str = "qk" # Channel selection (q, k, qk) # Generation Defaults DEFAULT_MAX_TOKENS: int = 100 diff --git a/sparse_server/app/main.py b/sparse_server/app/main.py index 1b89476..9a81683 100644 --- a/sparse_server/app/main.py +++ b/sparse_server/app/main.py @@ -3,6 +3,8 @@ """ import os import sys +import time +import uuid import logging from typing import Optional, List, Dict, Any from pydantic import BaseModel, Field @@ -12,11 +14,8 @@ from fastapi.responses import StreamingResponse, JSONResponse import json -# Add DoubleSparse to path -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -from models.model import Model -from transformers import AutoTokenizer +from .config import settings +from ..models.model_manager import ModelManager # Configure logging logging.basicConfig( @@ -26,22 +25,57 @@ logger = logging.getLogger(__name__) # API Models -class GenerateRequest(BaseModel): - prompt: str = Field(..., description="The input prompt for generation") - max_new_tokens: Optional[int] = Field(100, description="Maximum number of tokens to generate") +class ChatMessage(BaseModel): + role: str = Field(..., description="The role of the message sender") + content: str = Field(..., description="The content of the message") + +class ChatCompletionRequest(BaseModel): + model: str = Field(..., description="Model to use for completion") + messages: List[ChatMessage] = Field(..., description="Messages to generate completions for") temperature: Optional[float] = Field(0.7, description="Sampling temperature") top_p: Optional[float] = Field(0.95, description="Top-p sampling parameter") + max_tokens: Optional[int] = Field(100, description="Maximum number of tokens to generate") stream: Optional[bool] = Field(False, description="Whether to stream the response") -class GenerateResponse(BaseModel): - text: str = Field(..., description="Generated text") - usage: Dict[str, int] = Field(..., description="Token usage statistics") +class CompletionUsage(BaseModel): + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field(..., description="Number of tokens in the completion") + 
total_tokens: int = Field(..., description="Total number of tokens used") + +class ChatCompletionResponseChoice(BaseModel): + index: int = Field(..., description="Index of the choice") + message: ChatMessage = Field(..., description="The generated message") + finish_reason: Optional[str] = Field(None, description="Reason for finishing") + +class ChatCompletionResponse(BaseModel): + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field(..., description="Unix timestamp of creation") + model: str = Field(..., description="Model used for completion") + choices: List[ChatCompletionResponseChoice] = Field(..., description="Generated completions") + usage: CompletionUsage = Field(..., description="Token usage statistics") + +class DeltaMessage(BaseModel): + role: Optional[str] = Field(None, description="Role of the delta message") + content: Optional[str] = Field(None, description="Content of the delta message") + +class ChatCompletionStreamChoice(BaseModel): + index: int = Field(..., description="Index of the choice") + delta: DeltaMessage = Field(..., description="Delta message content") + finish_reason: Optional[str] = Field(None, description="Reason for finishing") + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion.chunk", description="Object type") + created: int = Field(..., description="Unix timestamp of creation") + model: str = Field(..., description="Model used for completion") + choices: List[ChatCompletionStreamChoice] = Field(..., description="Generated completion chunks") # Initialize FastAPI app app = FastAPI( - title="DoubleSparse API", - description="REST API for LLM inference using DoubleSparse with sparse attention", - version="1.0.0" + title=settings.API_TITLE, + description=settings.API_DESCRIPTION, + version=settings.API_VERSION ) # Add CORS middleware @@ -53,104 +87,169 @@ class GenerateResponse(BaseModel): allow_headers=["*"], ) -# Global variables for model and tokenizer -model = None -tokenizer = None +# Global model manager +model_manager = None @app.on_event("startup") async def startup_event(): - """Initialize model and tokenizer on startup.""" - global model, tokenizer - + """Initialize model manager on startup.""" + global model_manager try: - logger.info("Loading model and tokenizer...") - model_name = "meta-llama/Llama-2-7b-chat-hf" # Can be made configurable - - # Initialize model with sparse attention - model = Model.from_pretrained( - model_name, - heavy_const=128, # Sparse attention params - group_factor=4, - device="cuda" if torch.cuda.is_available() else "cpu" + model_manager = ModelManager( + model_name=settings.DEFAULT_MODEL, + heavy_const=settings.HEAVY_CONST, + group_factor=settings.GROUP_FACTOR, + channel=settings.CHANNEL, + architecture=settings.MODEL_ARCHITECTURE ) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - logger.info("Model and tokenizer loaded successfully") + model_manager.load_model() + logger.info(f"Model manager initialized with sparse attention (heavy_const={settings.HEAVY_CONST}, group_factor={settings.GROUP_FACTOR})") except Exception as e: - logger.error(f"Error loading model: {str(e)}") - raise RuntimeError(f"Failed to load model: {str(e)}") + logger.error(f"Error initializing model manager: {str(e)}") + raise RuntimeError(f"Failed to initialize model manager: {str(e)}") @app.get("/") async def root(): - 
"""Root endpoint returning API status.""" - return {"status": "ok", "model_loaded": model is not None} - -@app.post("/generate", response_model=GenerateResponse) -async def generate(request: GenerateRequest): - """Generate text from a prompt.""" - if model is None or tokenizer is None: - raise HTTPException(status_code=503, detail="Model not loaded") - + """Root endpoint with API information.""" + return { + "status": "ok", + "version": settings.API_VERSION, + "model": settings.DEFAULT_MODEL, + "sparse_attention": { + "heavy_const": settings.HEAVY_CONST, + "group_factor": settings.GROUP_FACTOR, + "channel": settings.CHANNEL + }, + "model_loaded": model_manager is not None + } + +@app.get("/v1/models") +async def list_models(): + """List available models.""" + return { + "object": "list", + "data": [{ + "id": settings.DEFAULT_MODEL, + "object": "model", + "created": int(time.time()), + "owned_by": "organization" + }] + } + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + """Create a chat completion.""" + if model_manager is None: + raise HTTPException(status_code=503, detail="Model not initialized") + try: - # Tokenize input - input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids - if torch.cuda.is_available(): - input_ids = input_ids.cuda() - - # Generate - outputs = model.generate( - input_ids, - max_new_tokens=request.max_new_tokens, + if request.stream: + return StreamingResponse( + stream_chat_completion(request), + media_type='text/event-stream' + ) + + # Generate completion + completion_text, usage = model_manager.generate( + messages=[msg.dict() for msg in request.messages], + max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p ) - - # Decode output - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Calculate usage - usage = { - "prompt_tokens": len(input_ids[0]), - "completion_tokens": len(outputs[0]) - len(input_ids[0]), - "total_tokens": len(outputs[0]) - } - - return GenerateResponse(text=generated_text, usage=usage) - + + response = ChatCompletionResponse( + id=f"chatcmpl-{str(uuid.uuid4())}", + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=completion_text), + finish_reason="stop" + ) + ], + usage=CompletionUsage(**usage) + ) + + return response + except Exception as e: - logger.error(f"Error during generation: {str(e)}") + logger.error(f"Error during chat completion: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) -@app.post("/generate/stream") -async def generate_stream(request: GenerateRequest): - """Stream generated text token by token.""" - if model is None or tokenizer is None: - raise HTTPException(status_code=503, detail="Model not loaded") - - async def stream_response(): - try: - # Tokenize input - input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids - if torch.cuda.is_available(): - input_ids = input_ids.cuda() - - # Stream generation - for token in model.generate_stream( - input_ids, - max_new_tokens=request.max_new_tokens, - temperature=request.temperature, - top_p=request.top_p - ): - text = tokenizer.decode(token, skip_special_tokens=True) - yield f"data: {json.dumps({'text': text})}\n\n" - - yield "data: [DONE]\n\n" - - except Exception as e: - logger.error(f"Error during streaming: {str(e)}") - yield f"data: {json.dumps({'error': str(e)})}\n\n" +async def 
stream_chat_completion(request: ChatCompletionRequest): + """Stream chat completion chunks.""" + completion_id = f"chatcmpl-{str(uuid.uuid4())}" - return StreamingResponse(stream_response(), media_type="text/event-stream") + try: + # Start with role + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + + # Stream the content + for text_chunk in model_manager.generate_stream( + messages=[msg.dict() for msg in request.messages], + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p + ): + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(content=text_chunk), + finish_reason=None + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + + # Send the final chunk + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"Error during streaming: {str(e)}") + error_chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="error" + ) + ] + ) + yield f"data: {error_chunk.json()}\n\n" + yield "data: [DONE]\n\n" if __name__ == "__main__": import uvicorn diff --git a/sparse_server/models/model_manager.py b/sparse_server/models/model_manager.py index 363950b..10401ff 100644 --- a/sparse_server/models/model_manager.py +++ b/sparse_server/models/model_manager.py @@ -3,14 +3,16 @@ """ import os import sys +import json import logging from typing import List, Optional, Iterator, Dict import torch -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig # Add DoubleSparse to path sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from models.model import Model +from evaluation.modify_llama import convert_kvcache_llama_heavy_recent, convert_llama_channel_config +from evaluation.modify_mistral import convert_kvcache_mistral_heavy_recent, convert_mistral_channel_config logger = logging.getLogger(__name__) @@ -20,12 +22,16 @@ def __init__( model_name: str = "meta-llama/Llama-2-7b-chat-hf", heavy_const: int = 128, group_factor: int = 4, - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = "cuda" if torch.cuda.is_available() else "cpu", + channel: str = "qk", + architecture: str = "llama" ): self.model_name = model_name self.device = device self.heavy_const = heavy_const self.group_factor = group_factor + self.channel = channel + self.architecture = architecture self.model = None self.tokenizer = None @@ -35,16 +41,62 @@ def load_model(self) -> None: try: logger.info(f"Loading model {self.model_name}...") - # Initialize model with sparse attention - self.model = Model.from_pretrained( - self.model_name, - heavy_const=self.heavy_const, - group_factor=self.group_factor, - device=self.device - ) + # Load model and config + kwargs = {"torch_dtype": torch.float16, "device_map": "auto"} + self.model = 
AutoModelForCausalLM.from_pretrained(self.model_name, **kwargs) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + config = AutoConfig.from_pretrained(self.model_name) + + # Load channel config + channel_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "config", + f"{self.model_name.split('/')[-1]}.json" + ) + + if not os.path.exists(channel_path): + logger.warning(f"Channel config not found at {channel_path}, using default config") + channel_config = {} + else: + with open(channel_path, "r") as f: + channel_config = json.load(f) - logger.info("Model loaded successfully") + # Convert model to use sparse attention + if self.architecture == "llama": + logger.info("Converting model to use LLaMA sparse attention...") + self.model = convert_kvcache_llama_heavy_recent( + self.model, + config, + self.heavy_const, + self.group_factor + ) + if channel_config: + self.model = convert_llama_channel_config( + self.model, + channel_config, + self.channel + ) + elif self.architecture == "mistral": + logger.info("Converting model to use Mistral sparse attention...") + self.model = convert_kvcache_mistral_heavy_recent( + self.model, + config, + self.heavy_const, + self.group_factor + ) + if channel_config: + self.model = convert_mistral_channel_config( + self.model, + channel_config, + self.channel + ) + else: + raise ValueError(f"Unsupported architecture: {self.architecture}") + + # Set model to evaluation mode + self.model.eval() + + logger.info(f"Model loaded successfully with sparse attention (heavy_const={self.heavy_const}, group_factor={self.group_factor})") except Exception as e: logger.error(f"Error loading model: {str(e)}") raise RuntimeError(f"Failed to load model: {str(e)}") @@ -79,17 +131,23 @@ def generate( prompt = self._prepare_prompt(messages) # Tokenize - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - if self.device == "cuda": - input_ids = input_ids.cuda() + inputs = self.tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(self.device) + attention_mask = inputs.attention_mask.to(self.device) if hasattr(inputs, 'attention_mask') else None # Generate - outputs = self.model.generate( - input_ids, - max_new_tokens=max_tokens, - temperature=temperature, - top_p=top_p - ) + with torch.no_grad(): + outputs = self.model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + do_sample=temperature > 0, + pad_token_id=self.tokenizer.pad_token_id, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id + ) # Get only the new tokens new_tokens = outputs[0][len(input_ids[0]):] @@ -119,21 +177,56 @@ def generate_stream( prompt = self._prepare_prompt(messages) # Tokenize - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - if self.device == "cuda": - input_ids = input_ids.cuda() + inputs = self.tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(self.device) + attention_mask = inputs.attention_mask.to(self.device) if hasattr(inputs, 'attention_mask') else None # Stream generation - prev_text = "" - for token in self.model.generate_stream( - input_ids, - max_new_tokens=max_tokens, - temperature=temperature, - top_p=top_p - ): - current_text = self.tokenizer.decode(token, skip_special_tokens=True) - # Yield only the new text - new_text = current_text[len(prev_text):] - if new_text: - yield new_text - prev_text = current_text \ No newline at end 
of file + generated_tokens = [] + past_key_values = None + + with torch.no_grad(): + for _ in range(max_tokens): + outputs = self.model( + input_ids if past_key_values is None else input_ids[:, -1:], + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True + ) + + next_token_logits = outputs.logits[:, -1, :] + if temperature > 0: + # Apply temperature and top_p sampling + probs = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1) + if top_p < 1.0: + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumsum_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + probs.masked_fill_(indices_to_remove, 0.0) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + generated_tokens.append(next_token[0].item()) + + # Update inputs and attention mask for next iteration + input_ids = torch.cat([input_ids, next_token], dim=-1) + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)) + ], dim=-1) + + # Update past key values + past_key_values = outputs.past_key_values + + # Decode and yield new token + current_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) + yield current_text + + # Check for EOS token + if next_token[0].item() == self.tokenizer.eos_token_id: + break \ No newline at end of file From 1a615ae344f29902265d9e9a54502c9d5c6a892e Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Sat, 22 Feb 2025 19:53:28 +0000 Subject: [PATCH 3/4] feat: Add Qwen2 architecture support - Add Qwen2 sparse attention implementation - Update configuration to support Qwen2 - Update documentation with Qwen2 support - Improve architecture selection documentation --- sparse_server/README.md | 2 ++ sparse_server/app/config.py | 2 +- sparse_server/models/model_manager.py | 17 ++++++++++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sparse_server/README.md b/sparse_server/README.md index ddbb079..0da9d05 100644 --- a/sparse_server/README.md +++ b/sparse_server/README.md @@ -78,8 +78,10 @@ API_VERSION="1.0.0" HOST="0.0.0.0" PORT=52309 DEFAULT_MODEL="meta-llama/Llama-2-7b-chat-hf" +MODEL_ARCHITECTURE="llama" # llama, mistral, or qwen2 HEAVY_CONST=128 # Sparse attention parameter GROUP_FACTOR=4 # Sparse attention parameter +CHANNEL="qk" # Channel selection (q, k, qk) ``` ### Sparse Attention Parameters diff --git a/sparse_server/app/config.py b/sparse_server/app/config.py index ee9ec51..f3c57dd 100644 --- a/sparse_server/app/config.py +++ b/sparse_server/app/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): # Model Settings DEFAULT_MODEL: str = "meta-llama/Llama-2-7b-chat-hf" - MODEL_ARCHITECTURE: str = "llama" # Model architecture (llama, mistral) + MODEL_ARCHITECTURE: str = "llama" # Model architecture (llama, mistral, qwen2) # Sparse Attention Settings HEAVY_CONST: int = 128 # Number of tokens to keep for attention diff --git a/sparse_server/models/model_manager.py b/sparse_server/models/model_manager.py index 10401ff..c89008c 100644 --- a/sparse_server/models/model_manager.py +++ b/sparse_server/models/model_manager.py @@ -13,6 +13,7 @@ 
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from evaluation.modify_llama import convert_kvcache_llama_heavy_recent, convert_llama_channel_config from evaluation.modify_mistral import convert_kvcache_mistral_heavy_recent, convert_mistral_channel_config +from evaluation.modify_qwen2 import convert_kvcache_qwen2_heavy_recent, convert_qwen2_channel_config logger = logging.getLogger(__name__) @@ -90,8 +91,22 @@ def load_model(self) -> None: channel_config, self.channel ) + elif self.architecture == "qwen2": + logger.info("Converting model to use Qwen2 sparse attention...") + self.model = convert_kvcache_qwen2_heavy_recent( + self.model, + config, + self.heavy_const, + self.group_factor + ) + if channel_config: + self.model = convert_qwen2_channel_config( + self.model, + channel_config, + self.channel + ) else: - raise ValueError(f"Unsupported architecture: {self.architecture}") + raise ValueError(f"Unsupported architecture: {self.architecture}. Supported: llama, mistral, qwen2") # Set model to evaluation mode self.model.eval() From 0d1db06734ac070464c0eee1ca34a75c1eede074 Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Sat, 22 Feb 2025 20:20:29 +0000 Subject: [PATCH 4/4] refactor: Simplify server to match perplexity_eval.py - Create single server script that matches perplexity_eval.py usage - Automatic architecture detection from model config - Same command-line arguments as perplexity_eval.py - Remove need for manual architecture configuration --- sparse_server.py | 379 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 sparse_server.py diff --git a/sparse_server.py b/sparse_server.py new file mode 100644 index 0000000..57ccd18 --- /dev/null +++ b/sparse_server.py @@ -0,0 +1,379 @@ +""" +OpenAI-compatible server for DoubleSparse with sparse attention. +Uses the same architecture and parameters as perplexity_eval.py. 
+""" +import os +import sys +import json +import time +import uuid +import logging +import argparse +from typing import List, Dict, Optional +import torch +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# API Models +class ChatMessage(BaseModel): + role: str = Field(..., description="The role of the message sender") + content: str = Field(..., description="The content of the message") + +class ChatCompletionRequest(BaseModel): + model: str = Field(..., description="Model to use for completion") + messages: List[ChatMessage] = Field(..., description="Messages to generate completions for") + temperature: Optional[float] = Field(0.7, description="Sampling temperature") + top_p: Optional[float] = Field(0.95, description="Top-p sampling parameter") + max_tokens: Optional[int] = Field(100, description="Maximum number of tokens to generate") + stream: Optional[bool] = Field(False, description="Whether to stream the response") + +class CompletionUsage(BaseModel): + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field(..., description="Number of tokens in the completion") + total_tokens: int = Field(..., description="Total number of tokens used") + +class ChatCompletionResponseChoice(BaseModel): + index: int = Field(..., description="Index of the choice") + message: ChatMessage = Field(..., description="The generated message") + finish_reason: Optional[str] = Field(None, description="Reason for finishing") + +class ChatCompletionResponse(BaseModel): + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field(..., description="Unix timestamp of creation") + model: str = Field(..., description="Model used for completion") + choices: List[ChatCompletionResponseChoice] = Field(..., description="Generated completions") + usage: CompletionUsage = Field(..., description="Token usage statistics") + +class DeltaMessage(BaseModel): + role: Optional[str] = Field(None, description="Role of the delta message") + content: Optional[str] = Field(None, description="Content of the delta message") + +class ChatCompletionStreamChoice(BaseModel): + index: int = Field(..., description="Index of the choice") + delta: DeltaMessage = Field(..., description="Delta message content") + finish_reason: Optional[str] = Field(None, description="Reason for finishing") + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion.chunk", description="Object type") + created: int = Field(..., description="Unix timestamp of creation") + model: str = Field(..., description="Model used for completion") + choices: List[ChatCompletionStreamChoice] = Field(..., description="Generated completion chunks") + +def create_app(model_path: str, heavy_const: int, group_factor: int, channel: str = "qk", offloading: bool = False): + """Create FastAPI app with the specified model and parameters.""" + # Initialize FastAPI app + app = FastAPI( + title="DoubleSparse API", + description="OpenAI-compatible REST API 
for LLM inference using DoubleSparse with sparse attention", + version="1.0.0" + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Load model and convert to sparse attention + kwargs = {"torch_dtype": torch.float16, "device_map": "auto"} + model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_path) + config = AutoConfig.from_pretrained(model_path) + + # Load channel config if available + channel_path = os.path.join("config", model_path.split('/')[-1] + ".json") + channel_config = None + if os.path.exists(channel_path): + with open(channel_path, "r") as f: + channel_config = json.load(f) + + # Detect architecture from config + if "Llama" in config.__class__.__name__: + from evaluation.modify_llama import convert_kvcache_llama_heavy_recent, convert_llama_channel_config + logger.info("Detected LLaMA architecture") + model = convert_kvcache_llama_heavy_recent(model, config, heavy_const, group_factor) + if channel_config: + model = convert_llama_channel_config(model, channel_config, channel) + elif "Mistral" in config.__class__.__name__: + from evaluation.modify_mistral import convert_kvcache_mistral_heavy_recent, convert_mistral_channel_config + logger.info("Detected Mistral architecture") + model = convert_kvcache_mistral_heavy_recent(model, config, heavy_const, group_factor) + if channel_config: + model = convert_mistral_channel_config(model, channel_config, channel) + elif "Qwen2" in config.__class__.__name__: + from evaluation.modify_qwen2 import convert_kvcache_qwen2_heavy_recent, convert_qwen2_channel_config + logger.info("Detected Qwen2 architecture") + model = convert_kvcache_qwen2_heavy_recent(model, config, heavy_const, group_factor) + if channel_config: + model = convert_qwen2_channel_config(model, channel_config, channel) + else: + raise ValueError(f"Unsupported model architecture: {config.__class__.__name__}") + + model.eval() + logger.info(f"Model loaded with sparse attention (heavy_const={heavy_const}, group_factor={group_factor})") + + def _prepare_prompt(messages: List[Dict[str, str]]) -> str: + """Convert chat messages to a single prompt string.""" + prompt = "" + for msg in messages: + role = msg["role"] + content = msg["content"] + if role == "system": + prompt += f"System: {content}\n" + elif role == "user": + prompt += f"User: {content}\n" + elif role == "assistant": + prompt += f"Assistant: {content}\n" + prompt += "Assistant: " + return prompt + + @app.get("/") + async def root(): + """Root endpoint with API information.""" + return { + "status": "ok", + "model": model_path, + "sparse_attention": { + "heavy_const": heavy_const, + "group_factor": group_factor, + "channel": channel + } + } + + @app.get("/v1/models") + async def list_models(): + """List available models.""" + return { + "object": "list", + "data": [{ + "id": model_path, + "object": "model", + "created": int(time.time()), + "owned_by": "organization" + }] + } + + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) + async def create_chat_completion(request: ChatCompletionRequest): + """Create a chat completion.""" + try: + if request.stream: + return StreamingResponse( + stream_chat_completion(request), + media_type='text/event-stream' + ) + + # Prepare input + prompt = _prepare_prompt([msg.dict() for msg in request.messages]) + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = 
inputs.input_ids.to(model.device) + attention_mask = inputs.attention_mask.to(model.device) if hasattr(inputs, 'attention_mask') else None + + # Generate + with torch.no_grad(): + outputs = model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + do_sample=request.temperature > 0, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id + ) + + # Get only the new tokens + new_tokens = outputs[0][len(input_ids[0]):] + completion_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + # Create response + response = ChatCompletionResponse( + id=f"chatcmpl-{str(uuid.uuid4())}", + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=completion_text), + finish_reason="stop" + ) + ], + usage=CompletionUsage( + prompt_tokens=len(input_ids[0]), + completion_tokens=len(new_tokens), + total_tokens=len(outputs[0]) + ) + ) + + return response + + except Exception as e: + logger.error(f"Error during chat completion: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + async def stream_chat_completion(request: ChatCompletionRequest): + """Stream chat completion chunks.""" + completion_id = f"chatcmpl-{str(uuid.uuid4())}" + + try: + # Start with role + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + + # Prepare input + prompt = _prepare_prompt([msg.dict() for msg in request.messages]) + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(model.device) + attention_mask = inputs.attention_mask.to(model.device) if hasattr(inputs, 'attention_mask') else None + + # Stream generation + generated_tokens = [] + past_key_values = None + + with torch.no_grad(): + for _ in range(request.max_tokens or 100): + outputs = model( + input_ids if past_key_values is None else input_ids[:, -1:], + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True + ) + + next_token_logits = outputs.logits[:, -1, :] + if request.temperature > 0: + probs = torch.nn.functional.softmax(next_token_logits / request.temperature, dim=-1) + if request.top_p < 1.0: + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumsum_probs > request.top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + probs.masked_fill_(indices_to_remove, 0.0) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + generated_tokens.append(next_token[0].item()) + + # Update inputs and attention mask + input_ids = torch.cat([input_ids, next_token], dim=-1) + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)) + ], dim=-1) + + # Update past key values + past_key_values = outputs.past_key_values + + # Decode and yield new token + current_text = tokenizer.decode(generated_tokens, 
skip_special_tokens=True) + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(content=current_text), + finish_reason=None + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + + # Check for EOS token + if next_token[0].item() == tokenizer.eos_token_id: + break + + # Send the final chunk + chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"Error during streaming: {str(e)}") + error_chunk = ChatCompletionStreamResponse( + id=completion_id, + created=int(time.time()), + model=request.model, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="error" + ) + ] + ) + yield f"data: {error_chunk.json()}\n\n" + yield "data: [DONE]\n\n" + + return app + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Start DoubleSparse API server.') + parser.add_argument('--model_path', type=str, default="meta-llama/Llama-2-7b-hf", help='Selected model') + parser.add_argument('--offloading', action='store_true', help='Whether to use offloading') + parser.add_argument('--channel', type=str, default="qk", choices=["q", "k", "qk"], help='Channel selection') + parser.add_argument('--heavy_const', type=int, default=128, help='Heavy constant') + parser.add_argument('--group_factor', type=int, default=2, help='Group factor') + parser.add_argument('--host', type=str, default="0.0.0.0", help='Server host') + parser.add_argument('--port', type=int, default=52309, help='Server port') + + args = parser.parse_args() + + # Create and start the app + app = create_app( + model_path=args.model_path, + heavy_const=args.heavy_const, + group_factor=args.group_factor, + channel=args.channel, + offloading=args.offloading + ) + + import uvicorn + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info" + ) \ No newline at end of file