diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..1904eaf --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,2 @@ +OPENAI_API_KEY=your-openai-api-key-here +GOOGLE_API_KEY=your-google-api-key-here diff --git a/backend/embedding.py b/backend/embedding.py new file mode 100644 index 0000000..e7b5596 --- /dev/null +++ b/backend/embedding.py @@ -0,0 +1,128 @@ +""" +Embedding generation module for distributed file transfer system. +Generates vector embeddings from file content or raw text using sentence-transformers. +NO API KEY REQUIRED - runs locally! +""" + +import numpy as np +import os +from typing import Union +from pathlib import Path + +# Sentence transformer model (lazy loaded) +_embedding_model = None + + +def get_embeddings_model(): + """ + Get or initialize the sentence transformer model. + No API key needed - runs locally! + """ + global _embedding_model + if _embedding_model is None: + try: + from sentence_transformers import SentenceTransformer + # Use a small, fast model that runs locally + _embedding_model = SentenceTransformer('all-MiniLM-L6-v2') + print("Loaded local sentence transformer model: all-MiniLM-L6-v2 (no API key needed)") + except ImportError: + raise ImportError( + "sentence-transformers package is required. Install with: pip install sentence-transformers" + ) + return _embedding_model + +import numpy as np +import os +from typing import Union +from pathlib import Path + +# Gemini client (lazy loaded) +_genai = None + + +def get_genai_client(): + """ + Get or initialize the Google Generative AI client. + Requires GOOGLE_API_KEY environment variable to be set. + """ + global _genai + if _genai is None: + try: + import google.generativeai as genai + api_key = os.getenv('GOOGLE_API_KEY') + if not api_key: + raise ValueError( + "GOOGLE_API_KEY environment variable not set. 
" + "Set it with: export GOOGLE_API_KEY='your-key-here'" + ) + genai.configure(api_key=api_key) + _genai = genai + print("Google Generative AI client initialized successfully") + except ImportError: + raise ImportError( + "google-generativeai package is required. Install with: pip install google-generativeai" + ) + return _genai + + +def generate_embedding(input_data: Union[str, Path], is_file: bool = None) -> np.ndarray: + """ + Generate embedding vector from file content or text string. + Uses local sentence transformer model - no API key needed! + + Args: + input_data: Either a file path (str/Path) or text content (str) + is_file: If True, treat input_data as file path. If False, treat as text. + If None, auto-detect based on input type and file existence. + + Returns: + numpy array of shape (384,) containing the embedding vector + + Example: + # From file + embedding = generate_embedding("document.txt", is_file=True) + + # From text + embedding = generate_embedding("This is my text content", is_file=False) + + # Auto-detect + embedding = generate_embedding(Path("document.txt")) + """ + model = get_embeddings_model() + + # Auto-detect if is_file not specified + if is_file is None: + if isinstance(input_data, Path): + is_file = True + elif isinstance(input_data, str): + # Check if it's a valid file path + is_file = os.path.isfile(input_data) + else: + raise ValueError("input_data must be a string or Path object") + + # Read file content if needed + if is_file: + file_path = Path(input_data) + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + try: + # Try reading as text + with open(file_path, 'r', encoding='utf-8') as f: + text_content = f.read() + except UnicodeDecodeError: + # If binary file, convert to string representation + with open(file_path, 'rb') as f: + # For binary files, use filename + size as content + text_content = f"{file_path.name} (binary file, size: {file_path.stat().st_size} bytes)" + else: + text_content 
= str(input_data) + + if not text_content or not text_content.strip(): + raise ValueError("Cannot generate embedding from empty content") + + # Generate embedding using local model + embedding = model.encode(text_content, convert_to_numpy=True) + + # Ensure it's a 1D numpy array of float32 + return embedding.astype(np.float32) diff --git a/backend/main.py b/backend/main.py deleted file mode 100644 index aa63a42..0000000 --- a/backend/main.py +++ /dev/null @@ -1,203 +0,0 @@ -from fastapi import FastAPI, Depends, HTTPException, Query, Body -from sqlmodel import Session, select, or_ -from typing import List, Optional -from pathlib import Path -from datetime import datetime -from pydantic import BaseModel - -from models import FileRecord, FileSearchResponse -from database import create_db_and_tables, get_session - -app = FastAPI(title="File Sync API", version="1.0.0") - - -@app.on_event("startup") -def on_startup(): - """ - Initialize database on startup - """ - create_db_and_tables() - - -@app.get("/") -def read_root(): - """ - Health check endpoint - """ - return {"status": "File Sync API is running"} - - -@app.get("/files/search", response_model=List[FileSearchResponse]) -def search_files( - query: str = Query(..., description="Search query for file name (supports * wildcard)"), - session: Session = Depends(get_session) -): - """ - GET endpoint: Fuzzy search for files by name with wildcard support - - Returns ONLY metadata - no actual files are stored on the Pi. - Files will be retrieved via SCP from source devices when user confirms download. 
- - Parameters: - - query: The search term (supports * wildcard for fuzzy matching) - Examples: "*.txt", "report*", "*2024*", "document.pdf" - - Returns: - - List of matching file metadata records - """ - # Convert wildcard pattern to SQL LIKE pattern - # * becomes % in SQL LIKE syntax - like_pattern = query.replace("*", "%") - - # Fuzzy search on file_name (case-insensitive with LIKE) - statement = select(FileRecord).where(FileRecord.file_name.like(f"%{like_pattern}%")) - - results = session.exec(statement).all() - - if not results: - raise HTTPException(status_code=404, detail=f"No files found matching '{query}'") - - return results - - -@app.get("/files/{file_id}", response_model=FileSearchResponse) -def get_file_metadata(file_id: int, session: Session = Depends(get_session)): - """ - Get metadata for a specific file by its ID - - This returns the file location info so the CLI can initiate SCP transfer. - - Parameters: - - file_id: The database ID of the file - - Returns: - - File metadata including device IP and path for SCP retrieval - """ - statement = select(FileRecord).where(FileRecord.id == file_id) - file_record = session.exec(statement).first() - - if not file_record: - raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") - - return file_record - - -class FileMetadata(BaseModel): - """Request body for registering file metadata""" - file_name: str - absolute_path: str - device: str - device_ip: str - device_user: str - last_modified_time: datetime - size: int - file_type: str - - -@app.post("/files/register") -def register_file( - file_metadata: FileMetadata, - session: Session = Depends(get_session) -): - """ - POST endpoint: Register file metadata (NO FILE UPLOAD) - - The Pi does NOT receive the actual file. It only stores metadata about where - the file exists on the network. When a user wants to download, the CLI will - use SCP to retrieve it from the source device. 
- - If a file with the same absolute_path and device already exists, it will be - updated (only latest version is kept). - - Parameters: - - file_metadata: JSON body containing all file metadata - - Returns: - - Success message with file ID - """ - try: - # Check if file already exists (same absolute_path + device) - statement = select(FileRecord).where( - FileRecord.absolute_path == file_metadata.absolute_path, - FileRecord.device == file_metadata.device - ) - existing_file = session.exec(statement).first() - - if existing_file: - # Update existing record (keep only latest version) - existing_file.file_name = file_metadata.file_name - existing_file.device_ip = file_metadata.device_ip - existing_file.device_user = file_metadata.device_user - existing_file.last_modified_time = file_metadata.last_modified_time - existing_file.size = file_metadata.size - existing_file.file_type = file_metadata.file_type - existing_file.created_time = datetime.utcnow() # Update timestamp - - session.add(existing_file) - session.commit() - session.refresh(existing_file) - - return { - "message": "File metadata updated successfully", - "file_id": existing_file.id, - "file_name": existing_file.file_name, - "action": "updated" - } - else: - # Create new record - file_record = FileRecord( - file_name=file_metadata.file_name, - absolute_path=file_metadata.absolute_path, - device=file_metadata.device, - device_ip=file_metadata.device_ip, - device_user=file_metadata.device_user, - last_modified_time=file_metadata.last_modified_time, - size=file_metadata.size, - file_type=file_metadata.file_type - ) - - session.add(file_record) - session.commit() - session.refresh(file_record) - - return { - "message": "File metadata registered successfully", - "file_id": file_record.id, - "file_name": file_record.file_name, - "action": "created" - } - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error registering file metadata: {str(e)}") - - -@app.delete("/files/{file_id}") -def 
delete_file_metadata(file_id: int, session: Session = Depends(get_session)): - """ - DELETE endpoint: Delete file metadata record - - This only removes the metadata from the Pi's database. - The actual file remains on the source device. - - Parameters: - - file_id: The database ID of the file metadata to delete - - Returns: - - Success message - """ - statement = select(FileRecord).where(FileRecord.id == file_id) - file_record = session.exec(statement).first() - - if not file_record: - raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") - - # Delete only the database record (no actual file to delete on Pi) - session.delete(file_record) - session.commit() - - return { - "message": "File metadata deleted successfully", - "file_id": file_id, - "file_name": file_record.file_name, - "device": file_record.device - } diff --git a/backend/requirements.txt b/backend/requirements.txt index 7d8b841..676e828 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,3 +14,7 @@ typing-inspection==0.4.2 typing_extensions==4.15.0 uvicorn==0.38.0 SQLAlchemy>=2.0.36 +numpy>=1.24.2 +google-generativeai>=0.8.0 +python-dotenv>=1.0.0 +sentence-transformers>=2.2.0 diff --git a/backend/test_api.py b/backend/test_api.py new file mode 100644 index 0000000..e83f6ed --- /dev/null +++ b/backend/test_api.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Test script for the Unified Server with Colinear Query Expansion +""" + +import requests +import json +from pathlib import Path + +BASE_URL = "http://localhost:8000" + +def print_section(title): + print("\n" + "="*70) + print(f" {title}") + print("="*70 + "\n") + +def test_health_check(): + print_section("1. Health Check") + response = requests.get(f"{BASE_URL}/") + print(json.dumps(response.json(), indent=2)) + +def test_upload_file(): + print_section("2. 
Upload Test File with Embedding") + + # Create a test file + test_content = """ + Machine Learning Tutorial + + This document covers the basics of machine learning, including: + - Supervised learning algorithms + - Neural networks and deep learning + - Data preprocessing techniques + - Model evaluation metrics + """ + + test_file = Path("/tmp/ml_tutorial.txt") + test_file.write_text(test_content) + + # Upload the file + with open(test_file, 'rb') as f: + files = {'file': ('ml_tutorial.txt', f, 'text/plain')} + params = { + 'device': 'laptop', + 'device_ip': '192.168.1.100', + 'device_user': 'testuser', + 'absolute_path': '/home/testuser/docs/ml_tutorial.txt' + } + response = requests.post(f"{BASE_URL}/files/upload", files=files, params=params) + + print(f"Status: {response.status_code}") + print(json.dumps(response.json(), indent=2)) + + # Upload another file + test_content2 = """ + Python Programming Guide + + Learn Python programming from scratch: + - Variables and data types + - Functions and classes + - File I/O operations + - Error handling + """ + + test_file2 = Path("/tmp/python_guide.txt") + test_file2.write_text(test_content2) + + with open(test_file2, 'rb') as f: + files = {'file': ('python_guide.txt', f, 'text/plain')} + params = { + 'device': 'desktop', + 'device_ip': '192.168.1.101', + 'device_user': 'testuser', + 'absolute_path': '/home/testuser/docs/python_guide.txt' + } + response = requests.post(f"{BASE_URL}/files/upload", files=files, params=params) + + print(f"\nSecond file - Status: {response.status_code}") + print(json.dumps(response.json(), indent=2)) + +def test_semantic_search_without_expansion(): + print_section("3. 
Semantic Search WITHOUT Query Expansion") + + query_data = { + "query": "learning tutorial", + "top_k": 5, + "use_query_expansion": False + } + + response = requests.post(f"{BASE_URL}/search/semantic", json=query_data) + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} results:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['filename']}") + print(f" Similarity: {result['similarity_score']:.4f}") + print(f" Matched via: {result['matched_via']}") + print(f" Device: {result['device']} ({result['device_ip']})") + print() + else: + print(response.text) + +def test_semantic_search_with_expansion(): + print_section("4. Semantic Search WITH Query Expansion (Colinear)") + + query_data = { + "query": "learning tutorial", + "top_k": 5, + "use_query_expansion": True, + "expansion_count": 3 + } + + print("Query: 'learning tutorial'") + print("Expansion enabled: True") + print("Expansion count: 3") + print("\nExpected query variants:") + print(" 0. learning tutorial (original)") + print(" 1. document about learning tutorial") + print(" 2. file containing learning tutorial") + print(" 3. information regarding learning tutorial") + print("\n" + "-"*70 + "\n") + + response = requests.post(f"{BASE_URL}/search/semantic", json=query_data) + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} results:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['filename']}") + print(f" Similarity: {result['similarity_score']:.4f}") + print(f" Matched via: {result['matched_via']} ⭐") + print(f" Device: {result['device']} ({result['device_ip']})") + print(f" Path: {result['path']}") + print() + else: + print(response.text) + +def test_keyword_search(): + print_section("5. 
Traditional Keyword Search") + + response = requests.get(f"{BASE_URL}/search/keyword?query=*.txt") + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} .txt files:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['file_name']}") + print(f" Device: {result['device']}") + print() + else: + print(response.text) + +def test_stats(): + print_section("6. System Statistics") + + response = requests.get(f"{BASE_URL}/stats") + print(json.dumps(response.json(), indent=2)) + +def main(): + print("\n" + "šŸš€ " * 25) + print(" Testing Unified Server with Colinear Query Expansion") + print("šŸš€ " * 25) + + try: + test_health_check() + test_upload_file() + test_semantic_search_without_expansion() + test_semantic_search_with_expansion() + test_keyword_search() + test_stats() + + print_section("āœ… All Tests Complete!") + print("Key Observations:") + print("- Semantic search WITH expansion may return more/better results") + print("- 'matched_via' shows which query variant found the match") + print("- Similarity scores range from 0-1 (higher = better)") + print("- Query expansion helps find relevant files with different wording") + + except requests.exceptions.ConnectionError: + print("\nāŒ Error: Could not connect to server at", BASE_URL) + print("Make sure the server is running: python unified_server.py") + except Exception as e: + print(f"\nāŒ Error: {e}") + +if __name__ == "__main__": + main() diff --git a/backend/test_embedding.py b/backend/test_embedding.py new file mode 100644 index 0000000..6cdab12 --- /dev/null +++ b/backend/test_embedding.py @@ -0,0 +1,101 @@ +""" +Test script for the embedding generation function. +Demonstrates usage with both files and raw text using the local sentence-transformers model (no API key required). 
+""" + +import os +import numpy as np +from embedding import generate_embedding +from dotenv import load_dotenv +load_dotenv() + + +def test_embedding(): + """Test the embedding generation function.""" + + # Check if API key is set + if not os.getenv('GOOGLE_API_KEY'): + print("ERROR: GOOGLE_API_KEY environment variable not set!") + print("Set it with: export GOOGLE_API_KEY='your-key-here'") + print("Get your free key from: https://makersuite.google.com/app/apikey") + return + + print("=" * 60) + print("Testing Embedding Generation with Google Gemini API") + print("=" * 60) + print() + + # Test 1: Generate embedding from raw text + print("1. Testing with raw text...") + text1 = "This is a sample document about machine learning." + embedding1 = generate_embedding(text1, is_file=False) + print(f"Generated embedding: shape={embedding1.shape}, dtype={embedding1.dtype}") + print() + + # Test 2: Generate embedding from the same text (should be identical) + print("2. Testing with identical text...") + text2 = "This is a sample document about machine learning." + embedding2 = generate_embedding(text2, is_file=False) + + # Check if embeddings are the same + are_same = np.allclose(embedding1, embedding2) + print(f"Embeddings are identical: {are_same}") + print(f"Distance between embeddings: {np.linalg.norm(embedding1 - embedding2):.10f}") + print() + + # Test 3: Generate embedding from different text + print("3. Testing with different text...") + text3 = "A completely different topic about cooking recipes." + embedding3 = generate_embedding(text3, is_file=False) + + distance_similar = np.linalg.norm(embedding1 - embedding2) + distance_different = np.linalg.norm(embedding1 - embedding3) + print(f"Distance between same texts: {distance_similar:.6f}") + print(f"Distance between different texts: {distance_different:.6f}") + print() + + # Test 4: Create a test file and generate embedding from it + print("4. 
Testing with file...") + test_file = "test_content.txt" + test_content = "This is a sample document about machine learning." + + with open(test_file, 'w', encoding='utf-8') as f: + f.write(test_content) + + embedding4 = generate_embedding(test_file, is_file=True) + + # Check if file embedding matches text embedding + file_matches_text = np.allclose(embedding1, embedding4) + print(f"File embedding matches text embedding: {file_matches_text}") + print(f"Distance: {np.linalg.norm(embedding1 - embedding4):.10f}") + + # Clean up + os.remove(test_file) + print() + + # Test 5: Auto-detect mode + print("5. Testing auto-detect mode...") + + # Create test file + with open(test_file, 'w', encoding='utf-8') as f: + f.write("Auto-detect test content") + + # Should detect as file + embedding5 = generate_embedding(test_file) # Auto-detect + print(f"Auto-detected as file: shape={embedding5.shape}") + + # Should detect as text + embedding6 = generate_embedding("This is obviously text, not a file path") + print(f"Auto-detected as text: shape={embedding6.shape}") + + # Clean up + os.remove(test_file) + print() + + print("=" * 60) + print("All tests completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + test_embedding() diff --git a/backend/unified_server.py b/backend/unified_server.py new file mode 100644 index 0000000..cb11f17 --- /dev/null +++ b/backend/unified_server.py @@ -0,0 +1,572 @@ +""" +Unified File Sync & RAG Server +Combines file metadata management with semantic search using vector embeddings. +Implements Colinear Query Expansion for improved search results. 
+""" + +from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Depends, Body +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from sqlmodel import Session, select, or_ +from typing import List, Optional +from contextlib import asynccontextmanager +import os +import tempfile +from pathlib import Path +from datetime import datetime +import numpy as np +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +from embedding import generate_embedding +from vector_db import VectorDatabase, init_vector_db +from database import create_db_and_tables, get_session, engine +from models import FileRecord, FileSearchResponse + +# Global instances +vector_db: Optional[VectorDatabase] = None +EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) + + +# ==================== Lifespan Context Manager ==================== + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager for startup and shutdown events. + Replaces deprecated @app.on_event decorators. + """ + global vector_db + + # Startup + print("šŸš€ Starting Unified File Sync & RAG Server...") + create_db_and_tables() + vector_db = init_vector_db(dimension=EMBEDDING_DIM) + print("āœ“ Unified File Sync & RAG Server initialized successfully") + + yield # Server runs here + + # Shutdown + print("šŸ›‘ Shutting down server...") + + +# Initialize FastAPI app with lifespan +app = FastAPI( + title="Unified File Sync & RAG Server", + version="2.0.0", + description="File metadata management with semantic search and query expansion", + lifespan=lifespan +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global instances +vector_db: Optional[VectorDatabase] = None +EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) 
+ +# ==================== Request/Response Models ==================== + +class FileMetadata(BaseModel): + """Request body for registering file metadata""" + file_name: str + absolute_path: str + device: str + device_ip: str + device_user: str + last_modified_time: datetime + size: int + file_type: str + + +class QueryRequest(BaseModel): + """Semantic search query with query expansion support""" + query: str + top_k: int = 5 + use_query_expansion: bool = True # Enable Colinear Query Expansion by default + expansion_count: int = 3 # Number of expanded queries to generate + + +class SearchResult(BaseModel): + """Enhanced search result with similarity score""" + file_id: int + filename: str + path: str + device: str + device_ip: str + device_user: str + size: int + file_type: str + similarity_score: float # 0-1, higher is better (converted from distance) + last_modified_time: datetime + matched_via: str # "original_query" or "expanded_query_N" + + +class UploadResponse(BaseModel): + file_id: int + filename: str + message: str + action: str # "created" or "updated" + + +# ==================== CORS Middleware ==================== + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ==================== Health Check ==================== + +@app.get("/") +async def root(): + """Health check endpoint.""" + return { + "status": "running", + "service": "Unified File Sync & RAG Server", + "version": "2.0.0", + "features": ["metadata_search", "semantic_search", "query_expansion"] + } + + +# ==================== File Registration (Metadata Only) ==================== + +@app.post("/files/register", response_model=UploadResponse) +def register_file( + file_metadata: FileMetadata, + session: Session = Depends(get_session) +): + """ + Register file metadata WITHOUT uploading the actual file. + + The Pi does NOT receive the actual file. 
It only stores metadata about where + the file exists on the network. When a user wants to download, the CLI will + use SCP to retrieve it from the source device. + + If a file with the same absolute_path and device already exists, it will be updated. + """ + try: + # Check if file already exists (same absolute_path + device) + statement = select(FileRecord).where( + FileRecord.absolute_path == file_metadata.absolute_path, + FileRecord.device == file_metadata.device + ) + existing_file = session.exec(statement).first() + + if existing_file: + # Update existing record (keep only latest version) + existing_file.file_name = file_metadata.file_name + existing_file.device_ip = file_metadata.device_ip + existing_file.device_user = file_metadata.device_user + existing_file.last_modified_time = file_metadata.last_modified_time + existing_file.size = file_metadata.size + existing_file.file_type = file_metadata.file_type + existing_file.created_time = datetime.utcnow() + + session.add(existing_file) + session.commit() + session.refresh(existing_file) + + return UploadResponse( + file_id=existing_file.id, + filename=existing_file.file_name, + message="File metadata updated successfully", + action="updated" + ) + else: + # Create new record + file_record = FileRecord( + file_name=file_metadata.file_name, + absolute_path=file_metadata.absolute_path, + device=file_metadata.device, + device_ip=file_metadata.device_ip, + device_user=file_metadata.device_user, + last_modified_time=file_metadata.last_modified_time, + size=file_metadata.size, + file_type=file_metadata.file_type + ) + + session.add(file_record) + session.commit() + session.refresh(file_record) + + return UploadResponse( + file_id=file_record.id, + filename=file_record.file_name, + message="File metadata registered successfully", + action="created" + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error registering file metadata: {str(e)}") + + +# ==================== File Upload with 
Embedding (For RAG) ==================== + +@app.post("/files/upload", response_model=UploadResponse) +async def upload_file_with_embedding( + file: UploadFile = File(...), + device: str = Query(..., description="Device identifier"), + device_ip: str = Query(..., description="Device IP address"), + device_user: str = Query(..., description="Device SSH username"), + absolute_path: str = Query(..., description="Absolute path on device") +): + """ + Upload a file, generate embedding, and store in both databases. + + Use this endpoint when you want to enable semantic search on the file content. + The file content is read to generate embeddings, then discarded. + """ + tmp_path = None + try: + # Save uploaded file temporarily + suffix = Path(file.filename or "file.txt").suffix + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + content = await file.read() + tmp.write(content) + tmp_path = tmp.name + + # Generate embedding from file content + embedding = generate_embedding(tmp_path, is_file=True) + + # Store file metadata in database + session = Session(engine) + try: + # Check if file already exists + statement = select(FileRecord).where( + FileRecord.absolute_path == absolute_path, + FileRecord.device == device + ) + existing_file = session.exec(statement).first() + + if existing_file: + # Update existing record + existing_file.file_name = file.filename or "unknown" + existing_file.device_ip = device_ip + existing_file.device_user = device_user + existing_file.last_modified_time = datetime.utcnow() + existing_file.size = len(content) + existing_file.file_type = suffix or ".txt" + existing_file.created_time = datetime.utcnow() + + session.add(existing_file) + session.commit() + session.refresh(existing_file) + + # Update embedding in vector database + vector_db.delete_embedding(existing_file.id) + vector_db.insert(embedding, existing_file.id) + + return UploadResponse( + file_id=existing_file.id, + filename=file.filename or "unknown", + 
message="File uploaded and indexed successfully", + action="updated" + ) + else: + # Create new record + file_record = FileRecord( + file_name=file.filename or "unknown", + absolute_path=absolute_path, + device=device, + device_ip=device_ip, + device_user=device_user, + last_modified_time=datetime.utcnow(), + size=len(content), + file_type=suffix or ".txt" + ) + session.add(file_record) + session.commit() + session.refresh(file_record) + + # Store embedding in vector database + success = vector_db.insert(embedding, file_record.id) + + if not success: + raise HTTPException(status_code=500, detail="Failed to store embedding") + + return UploadResponse( + file_id=file_record.id, + filename=file.filename or "unknown", + message="File uploaded and indexed successfully", + action="created" + ) + finally: + session.close() + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + finally: + # Clean up temp file + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + +# ==================== Colinear Query Expansion ==================== + +def colinear_query_expansion( + original_query: str, + expansion_count: int = 3 +) -> List[str]: + """ + Generate expanded queries using colinear query expansion technique. + + This creates related queries by: + 1. Adding synonyms and related terms + 2. Reformulating the query in different ways + 3. 
Adding context-specific expansions

    Args:
        original_query: The original search query
        expansion_count: Number of expanded queries to generate

    Returns:
        List of expanded query strings (including original)
    """
    expanded_queries = [original_query]

    # Simple expansion strategies (can be enhanced with LLM in future)
    # Strategy 1: Add common document-related terms
    if expansion_count >= 1:
        expanded_queries.append(f"document about {original_query}")

    # Strategy 2: Add file content context
    if expansion_count >= 2:
        expanded_queries.append(f"file containing {original_query}")

    # Strategy 3: Add informational context
    if expansion_count >= 3:
        expanded_queries.append(f"information regarding {original_query}")

    # Strategy 4: Add data context
    if expansion_count >= 4:
        expanded_queries.append(f"data related to {original_query}")

    # The slice keeps the original query plus at most `expansion_count` variants.
    return expanded_queries[:expansion_count + 1]


# ==================== Semantic Search with Query Expansion ====================

@app.post("/search/semantic", response_model=List[SearchResult])
async def semantic_search(request: QueryRequest):
    """
    Semantic search using vector embeddings with Colinear Query Expansion.

    This endpoint:
    1. Takes your search query
    2. Optionally expands it into multiple related queries (Colinear Query Expansion)
    3. Generates embeddings for each query variant
    4. Searches the vector database
    5. Combines and ranks results by relevance

    Parameters:
    - query: Search query string
    - top_k: Maximum number of results to return
    - use_query_expansion: Enable/disable query expansion (default: True)
    - expansion_count: Number of query variants to generate (default: 3)
    """
    try:
        all_results = {}  # file_id -> (distance, matched_via); keeps best distance per file

        # Generate query variants
        if request.use_query_expansion:
            queries = colinear_query_expansion(request.query, request.expansion_count)
            print(f"šŸ” Query Expansion: {len(queries)} variants generated")
            for i, q in enumerate(queries):
                print(f" {i}: {q}")
        else:
            queries = [request.query]

        # Search with each query variant
        for idx, query_variant in enumerate(queries):
            # Generate embedding for this query variant
            # NOTE(review): generate_embedding produces a 384-dim vector
            # (all-MiniLM-L6-v2); the vector DB must be configured with the
            # same dimension or get_file raises ValueError — verify setup.
            query_embedding = generate_embedding(query_variant, is_file=False)

            # Search vector database
            results = vector_db.get_file(query_embedding, k=request.top_k)

            # Track which query matched which file
            matched_via = "original_query" if idx == 0 else f"expanded_query_{idx}"

            # Combine results (keep best distance for each file)
            for file_id, distance in results:
                if file_id not in all_results or distance < all_results[file_id][0]:
                    all_results[file_id] = (distance, matched_via)

        # Get top_k unique results sorted by distance (ascending L2 = most similar first)
        sorted_results = sorted(all_results.items(), key=lambda x: x[1][0])[:request.top_k]

        # Retrieve file metadata and build response
        search_results = []
        session = Session(engine)
        try:
            for file_id, (distance, matched_via) in sorted_results:
                file_obj = session.get(FileRecord, file_id)
                if file_obj:
                    # Convert L2 distance to similarity score (0-1, higher is better)
                    # Using negative exponential: similarity = e^(-distance)
                    similarity_score = float(np.exp(-distance))

                    search_results.append(SearchResult(
                        file_id=file_obj.id or 0,
                        filename=file_obj.file_name,
                        path=file_obj.absolute_path,
                        device=file_obj.device,
                        device_ip=file_obj.device_ip,
                        device_user=file_obj.device_user,
                        size=file_obj.size,
                        file_type=file_obj.file_type,
                        similarity_score=similarity_score,
                        last_modified_time=file_obj.last_modified_time,
                        matched_via=matched_via
                    ))

            print(f"āœ“ Found {len(search_results)} results")
            return search_results
        finally:
            # Session is closed manually because this handler manages its own
            # Session(engine) instead of the get_session dependency.
            session.close()

    except Exception as e:
        # Any failure (embedding, FAISS search, DB access) surfaces as HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))


# ==================== Keyword/Wildcard Search (Original from main.py) ====================

@app.get("/search/keyword", response_model=List[FileSearchResponse])
def keyword_search(
    query: str = Query(..., description="Search query for file name (supports * wildcard)"),
    session: Session = Depends(get_session)
):
    """
    Traditional keyword search for files by name with wildcard support.

    This is faster than semantic search but only matches filenames, not content.

    Parameters:
    - query: Search term (supports * wildcard for fuzzy matching)
      Examples: "*.txt", "report*", "*2024*", "document.pdf"

    Returns:
    - List of matching file metadata records
    """
    # Convert wildcard pattern to SQL LIKE pattern
    like_pattern = query.replace("*", "%")

    # Fuzzy search on file_name (case-insensitive)
    # NOTE(review): the surrounding %...% makes every query a substring match,
    # so "*.txt" behaves like "contains .txt" rather than "ends with .txt" —
    # confirm this is the intended semantics.
    statement = select(FileRecord).where(FileRecord.file_name.like(f"%{like_pattern}%"))

    results = session.exec(statement).all()

    if not results:
        raise HTTPException(status_code=404, detail=f"No files found matching '{query}'")

    return results


# ==================== File Management ====================

@app.get("/files", response_model=List[dict])
async def list_files(session: Session = Depends(get_session)):
    """List all files in the database."""
    try:
        files = session.exec(select(FileRecord)).all()
        # Timestamps are serialized with isoformat(); assumes both fields are
        # datetime objects on every row — TODO confirm against the model.
        return [
            {
                "id": f.id,
                "file_name": f.file_name,
                "absolute_path": f.absolute_path,
                "device": f.device,
                "device_ip": f.device_ip,
                "device_user": f.device_user,
                "last_modified_time": f.last_modified_time.isoformat(),
                "created_time": f.created_time.isoformat(),
                "size": f.size,
                "file_type": f.file_type
            }
            for f in files
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/files/{file_id}", response_model=FileSearchResponse)
async def get_file_metadata(file_id: int, session: Session = Depends(get_session)):
    """
    Get metadata for a specific file by its ID.

    Returns file location info so the CLI can initiate SCP transfer.
    """
    file_obj = session.get(FileRecord, file_id)

    if not file_obj:
        raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found")

    return file_obj


@app.delete("/files/{file_id}")
def delete_file_metadata(file_id: int, session: Session = Depends(get_session)):
    """
    Delete file metadata record and its embedding.

    This only removes the metadata from the Pi's database.
    The actual file remains on the source device.
    """
    file_record = session.get(FileRecord, file_id)

    if not file_record:
        raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found")

    # Delete embedding from vector database
    # NOTE(review): if the SQL delete below fails, the embedding is already
    # gone — the two deletes are not transactional.
    vector_db.delete_embedding(file_id)

    # Delete database record
    session.delete(file_record)
    session.commit()

    # file_record's attributes are still readable after delete+commit here.
    return {
        "message": "File metadata and embedding deleted successfully",
        "file_id": file_id,
        "file_name": file_record.file_name,
        "device": file_record.device
    }


# ==================== Statistics ====================

@app.get("/stats")
async def get_stats(session: Session = Depends(get_session)):
    """Get statistics about the system."""
    try:
        vector_stats = vector_db.get_stats()
        files = session.exec(select(FileRecord)).all()
        total_files = len(files)
        total_size = sum(f.size for f in files)

        return {
            "total_files": total_files,
            "total_vectors": vector_stats['total_vectors'],
            "total_size_bytes": total_size,
            "vector_dimension": vector_stats['dimension'],
            "index_type": vector_stats['index_type'],
            "embedding_model":
"all-MiniLM-L6-v2 (local, no API key)" + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/vector_db.py b/backend/vector_db.py new file mode 100644 index 0000000..5d74e04 --- /dev/null +++ b/backend/vector_db.py @@ -0,0 +1,305 @@ +""" +Vector database module for distributed file transfer system. +Handles vector embeddings storage and similarity search using FAISS. +Uses SQLite for metadata storage. +""" + +import faiss +import numpy as np +import sqlite3 +import pickle +import os +from typing import List, Tuple, Optional + +# Vector database setup +VECTOR_DB_PATH = 'vectors.db' +FAISS_INDEX_PATH = 'faiss.index' +EMBEDDING_DIMENSION = 768 # Default for Gemini embeddings + + +def init_vector_db(db_path: str = VECTOR_DB_PATH, dimension: int = EMBEDDING_DIMENSION): + """ + Initialize the vector database, creating the table if it doesn't exist. + + Args: + db_path: Path to the SQLite database file + dimension: Dimension of vector embeddings + + Returns: + VectorDatabase instance + """ + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Create table for vector embeddings + cursor.execute(''' + CREATE TABLE IF NOT EXISTS vector_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL UNIQUE, + embedding BLOB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # Create index on file_id for faster lookups + cursor.execute(''' + CREATE INDEX IF NOT EXISTS idx_file_id ON vector_embeddings(file_id) + ''') + + conn.commit() + conn.close() + print(f"Vector database initialized at {db_path}") + + return VectorDatabase(db_path=db_path, dimension=dimension) + + +class VectorDatabase: + """ + Vector database manager using FAISS for efficient similarity search + and SQLite for metadata storage. 
+ """ + + def __init__(self, db_path: str = VECTOR_DB_PATH, dimension: int = EMBEDDING_DIMENSION): + """ + Initialize the vector database. + + Args: + db_path: Path to SQLite database + dimension: Dimension of the vector embeddings + """ + self.db_path = db_path + self.dimension = dimension + self.index = None + self.id_to_file_id = {} # Maps FAISS index position to file_id + self.file_id_to_index = {} # Maps file_id to FAISS index position + + # Load or create FAISS index + self._load_or_create_index() + + def _load_or_create_index(self): + """Load existing FAISS index or create a new one.""" + if os.path.exists(FAISS_INDEX_PATH): + try: + self.index = faiss.read_index(FAISS_INDEX_PATH) + self._load_mappings() + print(f"Loaded FAISS index from {FAISS_INDEX_PATH} with {self.index.ntotal} vectors") + except Exception as e: + print(f"Error loading FAISS index: {e}") + self._create_new_index() + else: + self._create_new_index() + + def _create_new_index(self): + """Create a new FAISS index.""" + # Using L2 distance + self.index = faiss.IndexFlatL2(self.dimension) + print(f"Created new FAISS index with dimension {self.dimension}") + + def _load_mappings(self): + """Load file_id mappings from database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute('SELECT id, file_id FROM vector_embeddings ORDER BY id') + rows = cursor.fetchall() + + for idx, (db_id, file_id) in enumerate(rows): + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + conn.close() + + def _save_index(self): + """Save FAISS index to disk.""" + if self.index is not None: + faiss.write_index(self.index, FAISS_INDEX_PATH) + + def insert(self, vector: np.ndarray, file_id: int) -> bool: + """ + Insert a vector embedding for a file. 
+ + Args: + vector: Numpy array of shape (dimension,) + file_id: ID of the file in the main database + + Returns: + True if successful, False if file_id already exists + + Raises: + ValueError: If vector dimension doesn't match + """ + # Validate vector dimension + if vector.shape[0] != self.dimension: + raise ValueError(f"Vector dimension {vector.shape} doesn't match expected ({self.dimension},)") + + # Check if file_id already exists + if file_id in self.file_id_to_index: + print(f"Embedding for file_id {file_id} already exists.") + return False + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Serialize vector + vector_blob = pickle.dumps(vector) + + # Store in SQLite database + cursor.execute( + 'INSERT INTO vector_embeddings (file_id, embedding) VALUES (?, ?)', + (file_id, vector_blob) + ) + conn.commit() + + # Add to FAISS index + vector_reshaped = vector.reshape(1, -1).astype('float32') + self.index.add(vector_reshaped) + + # Update mappings + idx = self.index.ntotal - 1 + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + # Save index + self._save_index() + + print(f"Inserted embedding for file_id {file_id} at index {idx}") + return True + + except sqlite3.IntegrityError: + print(f"Embedding for file_id {file_id} already exists in database.") + return False + except Exception as e: + conn.rollback() + print(f"Error inserting embedding: {e}") + raise + finally: + conn.close() + + def get_file(self, vector: np.ndarray, k: int = 1) -> List[Tuple[int, float]]: + """ + Find the k most similar files for a given query vector. 
+ + Args: + vector: Query vector as numpy array of shape (dimension,) + k: Number of nearest neighbors to return + + Returns: + List of tuples (file_id, distance) sorted by similarity + + Raises: + ValueError: If vector dimension doesn't match + """ + # Validate vector dimension + if vector.shape[0] != self.dimension: + raise ValueError(f"Vector dimension {vector.shape} doesn't match expected ({self.dimension},)") + + if self.index.ntotal == 0: + print("No vectors in the index.") + return [] + + # Ensure k doesn't exceed total vectors + k = min(k, self.index.ntotal) + + # Search FAISS index + vector_reshaped = vector.reshape(1, -1).astype('float32') + distances, indices = self.index.search(vector_reshaped, k) + + # Map indices to file_ids + results = [] + for idx, distance in zip(indices[0], distances[0]): + if idx in self.id_to_file_id: + file_id = self.id_to_file_id[idx] + results.append((file_id, float(distance))) + + return results + + def delete_embedding(self, file_id: int) -> bool: + """ + Delete an embedding by file_id. + Note: FAISS doesn't support deletion, so we rebuild the index. 
+ + Args: + file_id: ID of the file to remove + + Returns: + True if successful, False if not found + """ + if file_id not in self.file_id_to_index: + print(f"No embedding found for file_id {file_id}") + return False + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Remove from database + cursor.execute('DELETE FROM vector_embeddings WHERE file_id = ?', (file_id,)) + conn.commit() + + # Rebuild FAISS index + self._rebuild_index() + + print(f"Deleted embedding for file_id {file_id}") + return True + + except Exception as e: + conn.rollback() + print(f"Error deleting embedding: {e}") + raise + finally: + conn.close() + + def _rebuild_index(self): + """Rebuild FAISS index from database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Create new index + self._create_new_index() + self.id_to_file_id.clear() + self.file_id_to_index.clear() + + # Load all embeddings + cursor.execute('SELECT id, file_id, embedding FROM vector_embeddings ORDER BY id') + rows = cursor.fetchall() + + if rows: + vectors = [] + for idx, (db_id, file_id, embedding_blob) in enumerate(rows): + vector = pickle.loads(embedding_blob) + vectors.append(vector) + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + # Add all vectors to index + vectors_array = np.array(vectors).astype('float32') + self.index.add(vectors_array) + + # Save index + self._save_index() + + print(f"Rebuilt index with {len(rows)} vectors") + finally: + conn.close() + + def get_stats(self) -> dict: + """ + Get statistics about the vector database. 
+ + Returns: + Dictionary with stats + """ + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) FROM vector_embeddings') + db_count = cursor.fetchone()[0] + conn.close() + + return { + 'total_vectors': self.index.ntotal if self.index else 0, + 'db_records': db_count, + 'dimension': self.dimension, + 'index_type': type(self.index).__name__ if self.index else None + }