From 25f50bb7f1e9b872facd1c961338eed810eedc4a Mon Sep 17 00:00:00 2001 From: yuqiannemo Date: Sun, 19 Oct 2025 16:12:11 +0800 Subject: [PATCH 1/4] Add embedding function --- backend/embedding.py | 110 ++++++++++++++++++++++++++++++++++++++ backend/requirements.txt | 3 ++ backend/test_embedding.py | 101 ++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 backend/embedding.py create mode 100644 backend/test_embedding.py diff --git a/backend/embedding.py b/backend/embedding.py new file mode 100644 index 0000000..348bc68 --- /dev/null +++ b/backend/embedding.py @@ -0,0 +1,110 @@ +""" +Embedding generation module for distributed file transfer system. +Generates vector embeddings from file content or raw text using Google's Gemini API. +""" + +import numpy as np +import os +from typing import Union +from pathlib import Path + +# Gemini client (lazy loaded) +_genai = None + + +def get_genai_client(): + """ + Get or initialize the Google Generative AI client. + Requires GOOGLE_API_KEY environment variable to be set. + """ + global _genai + if _genai is None: + try: + import google.generativeai as genai + api_key = os.getenv('GOOGLE_API_KEY') + if not api_key: + raise ValueError( + "GOOGLE_API_KEY environment variable not set. " + "Set it with: export GOOGLE_API_KEY='your-key-here'" + ) + genai.configure(api_key=api_key) + _genai = genai + print("Google Generative AI client initialized successfully") + except ImportError: + raise ImportError( + "google-generativeai package is required. Install with: pip install google-generativeai" + ) + return _genai + + +def generate_embedding(input_data: Union[str, Path], is_file: bool = None) -> np.ndarray: + """ + Generate embedding vector from either a file or raw text content. + If the content is the same, the embedding will be the same. + + Args: + input_data: Either a file path (str or Path) or raw text content (str) + is_file: If True, treat input_data as file path. If False, treat as raw text. + If None, auto-detect based on whether input_data is a valid file path. + + Returns: + numpy array of shape (384,) containing the embedding vector + + Raises: + FileNotFoundError: If is_file=True but file doesn't exist + ValueError: If input_data is empty + + Examples: + >>> # From file + >>> embedding1 = generate_embedding('/path/to/file.txt', is_file=True) + >>> + >>> # From raw text + >>> embedding2 = generate_embedding('This is some text content', is_file=False) + >>> + >>> # Auto-detect + >>> embedding3 = generate_embedding('/path/to/file.txt') # Treats as file if exists + >>> embedding4 = generate_embedding('Some text') # Treats as text if not a file + """ + # Auto-detect if is_file is not specified + if is_file is None: + input_path = Path(input_data) + is_file = input_path.exists() and input_path.is_file() + + # Read content + if is_file: + file_path = Path(input_data) + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {input_data}") + + # Read file content + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + except UnicodeDecodeError: + # Try reading as binary and decode with errors='ignore' + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + print(f"Read {len(content)} characters from file: {file_path}") + else: + content = str(input_data) + + # Validate content + if not content or not content.strip(): + raise ValueError("Content is empty. Cannot generate embedding for empty content.") + + # Generate embedding using Gemini API + genai = get_genai_client() + + # Use Gemini's embedding model + result = genai.embed_content( + model="models/embedding-001", + content=content, + task_type="retrieval_document" + ) + + # Extract embedding from response + embedding = np.array(result['embedding'], dtype='float32') + + print(f"Generated embedding with shape: {embedding.shape}") + return embedding diff --git a/backend/requirements.txt b/backend/requirements.txt index 7d8b841..676e828 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,3 +14,6 @@ typing-inspection==0.4.2 typing_extensions==4.15.0 uvicorn==0.38.0 SQLAlchemy>=2.0.36 +numpy>=1.24.2 +google-generativeai>=0.8.0 +python-dotenv>=1.0.0 diff --git a/backend/test_embedding.py b/backend/test_embedding.py new file mode 100644 index 0000000..6cdab12 --- /dev/null +++ b/backend/test_embedding.py @@ -0,0 +1,101 @@ +""" +Test script for the embedding generation function. +Demonstrates usage with both files and raw text using Google's Gemini API. +""" + +import os +import numpy as np +from embedding import generate_embedding +from dotenv import load_dotenv +load_dotenv() + + +def test_embedding(): + """Test the embedding generation function.""" + + # Check if API key is set + if not os.getenv('GOOGLE_API_KEY'): + print("ERROR: GOOGLE_API_KEY environment variable not set!") + print("Set it with: export GOOGLE_API_KEY='your-key-here'") + print("Get your free key from: https://makersuite.google.com/app/apikey") + return + + print("=" * 60) + print("Testing Embedding Generation with Google Gemini API") + print("=" * 60) + print() + + # Test 1: Generate embedding from raw text + print("1. Testing with raw text...") + text1 = "This is a sample document about machine learning." + embedding1 = generate_embedding(text1, is_file=False) + print(f"Generated embedding: shape={embedding1.shape}, dtype={embedding1.dtype}") + print() + + # Test 2: Generate embedding from the same text (should be identical) + print("2. Testing with identical text...") + text2 = "This is a sample document about machine learning." + embedding2 = generate_embedding(text2, is_file=False) + + # Check if embeddings are the same + are_same = np.allclose(embedding1, embedding2) + print(f"Embeddings are identical: {are_same}") + print(f"Distance between embeddings: {np.linalg.norm(embedding1 - embedding2):.10f}") + print() + + # Test 3: Generate embedding from different text + print("3. Testing with different text...") + text3 = "A completely different topic about cooking recipes." + embedding3 = generate_embedding(text3, is_file=False) + + distance_similar = np.linalg.norm(embedding1 - embedding2) + distance_different = np.linalg.norm(embedding1 - embedding3) + print(f"Distance between same texts: {distance_similar:.6f}") + print(f"Distance between different texts: {distance_different:.6f}") + print() + + # Test 4: Create a test file and generate embedding from it + print("4. Testing with file...") + test_file = "test_content.txt" + test_content = "This is a sample document about machine learning." + + with open(test_file, 'w', encoding='utf-8') as f: + f.write(test_content) + + embedding4 = generate_embedding(test_file, is_file=True) + + # Check if file embedding matches text embedding + file_matches_text = np.allclose(embedding1, embedding4) + print(f"File embedding matches text embedding: {file_matches_text}") + print(f"Distance: {np.linalg.norm(embedding1 - embedding4):.10f}") + + # Clean up + os.remove(test_file) + print() + + # Test 5: Auto-detect mode + print("5. Testing auto-detect mode...") + + # Create test file + with open(test_file, 'w', encoding='utf-8') as f: + f.write("Auto-detect test content") + + # Should detect as file + embedding5 = generate_embedding(test_file) # Auto-detect + print(f"Auto-detected as file: shape={embedding5.shape}") + + # Should detect as text + embedding6 = generate_embedding("This is obviously text, not a file path") + print(f"Auto-detected as text: shape={embedding6.shape}") + + # Clean up + os.remove(test_file) + print() + + print("=" * 60) + print("All tests completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + test_embedding() From d310fbd4aed920dfbd91b9e9604bd5c925d0ff0f Mon Sep 17 00:00:00 2001 From: yuqiannemo Date: Sun, 19 Oct 2025 16:14:48 +0800 Subject: [PATCH 2/4] Add example env file --- backend/.env.example | 1 + 1 file changed, 1 insertion(+) create mode 100644 backend/.env.example diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..2474827 --- /dev/null +++ b/backend/.env.example @@ -0,0 +1 @@ +GOOGLE_API_KEY = your_google_api_key_here \ No newline at end of file From 1acae26aff947fa57e2717a643da2d72e0298588 Mon Sep 17 00:00:00 2001 From: yuqiannemo Date: Sun, 19 Oct 2025 19:03:47 +0800 Subject: [PATCH 3/4] Add rag server --- backend/.env.example | 3 +- backend/embedding.py | 118 +++++++++------- backend/rag_server.py | 267 ++++++++++++++++++++++++++++++++++++ backend/vector_db.py | 305 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 642 insertions(+), 51 deletions(-) create mode 100644 backend/rag_server.py create mode 100644 backend/vector_db.py diff --git a/backend/.env.example b/backend/.env.example index 2474827..1904eaf 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1 +1,2 @@ -GOOGLE_API_KEY = your_google_api_key_here \ No newline at end of file +OPENAI_API_KEY=your-openai-api-key-here +GOOGLE_API_KEY=your-google-api-key-here diff --git a/backend/embedding.py b/backend/embedding.py index 348bc68..e7b5596 100644 --- a/backend/embedding.py +++ b/backend/embedding.py @@ -1,6 +1,7 @@ """ Embedding generation module for distributed file transfer system. -Generates vector embeddings from file content or raw text using Google's Gemini API. +Generates vector embeddings from file content or raw text using sentence-transformers. +NO API KEY REQUIRED - runs locally! """ import numpy as np @@ -8,6 +9,33 @@ from typing import Union from pathlib import Path +# Sentence transformer model (lazy loaded) +_embedding_model = None + + +def get_embeddings_model(): + """ + Get or initialize the sentence transformer model. + No API key needed - runs locally! + """ + global _embedding_model + if _embedding_model is None: + try: + from sentence_transformers import SentenceTransformer + # Use a small, fast model that runs locally + _embedding_model = SentenceTransformer('all-MiniLM-L6-v2') + print("Loaded local sentence transformer model: all-MiniLM-L6-v2 (no API key needed)") + except ImportError: + raise ImportError( + "sentence-transformers package is required. Install with: pip install sentence-transformers" + ) + return _embedding_model + +import numpy as np +import os +from typing import Union +from pathlib import Path + # Gemini client (lazy loaded) _genai = None @@ -39,72 +67,62 @@ def get_genai_client(): def generate_embedding(input_data: Union[str, Path], is_file: bool = None) -> np.ndarray: """ - Generate embedding vector from either a file or raw text content. - If the content is the same, the embedding will be the same. + Generate embedding vector from file content or text string. + Uses local sentence transformer model - no API key needed! Args: - input_data: Either a file path (str or Path) or raw text content (str) - is_file: If True, treat input_data as file path. If False, treat as raw text. - If None, auto-detect based on whether input_data is a valid file path. + input_data: Either a file path (str/Path) or text content (str) + is_file: If True, treat input_data as file path. If False, treat as text. + If None, auto-detect based on input type and file existence. Returns: numpy array of shape (384,) containing the embedding vector - Raises: - FileNotFoundError: If is_file=True but file doesn't exist - ValueError: If input_data is empty - - Examples: - >>> # From file - >>> embedding1 = generate_embedding('/path/to/file.txt', is_file=True) - >>> - >>> # From raw text - >>> embedding2 = generate_embedding('This is some text content', is_file=False) - >>> - >>> # Auto-detect - >>> embedding3 = generate_embedding('/path/to/file.txt') # Treats as file if exists - >>> embedding4 = generate_embedding('Some text') # Treats as text if not a file + Example: + # From file + embedding = generate_embedding("document.txt", is_file=True) + + # From text + embedding = generate_embedding("This is my text content", is_file=False) + + # Auto-detect + embedding = generate_embedding(Path("document.txt")) """ - # Auto-detect if is_file is not specified + model = get_embeddings_model() + + # Auto-detect if is_file not specified if is_file is None: - input_path = Path(input_data) - is_file = input_path.exists() and input_path.is_file() + if isinstance(input_data, Path): + is_file = True + elif isinstance(input_data, str): + # Check if it's a valid file path + is_file = os.path.isfile(input_data) + else: + raise ValueError("input_data must be a string or Path object") - # Read content + # Read file content if needed if is_file: file_path = Path(input_data) if not file_path.exists(): - raise FileNotFoundError(f"File not found: {input_data}") + raise FileNotFoundError(f"File not found: {file_path}") - # Read file content try: + # Try reading as text with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() + text_content = f.read() except UnicodeDecodeError: - # Try reading as binary and decode with errors='ignore' - with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: - content = f.read() - - print(f"Read {len(content)} characters from file: {file_path}") + # If binary file, convert to string representation + with open(file_path, 'rb') as f: + # For binary files, use filename + size as content + text_content = f"{file_path.name} (binary file, size: {file_path.stat().st_size} bytes)" else: - content = str(input_data) - - # Validate content - if not content or not content.strip(): - raise ValueError("Content is empty. Cannot generate embedding for empty content.") - - # Generate embedding using Gemini API - genai = get_genai_client() + text_content = str(input_data) - # Use Gemini's embedding model - result = genai.embed_content( - model="models/embedding-001", - content=content, - task_type="retrieval_document" - ) + if not text_content or not text_content.strip(): + raise ValueError("Cannot generate embedding from empty content") - # Extract embedding from response - embedding = np.array(result['embedding'], dtype='float32') + # Generate embedding using local model + embedding = model.encode(text_content, convert_to_numpy=True) - print(f"Generated embedding with shape: {embedding.shape}") - return embedding + # Ensure it's a 1D numpy array of float32 + return embedding.astype(np.float32) diff --git a/backend/rag_server.py b/backend/rag_server.py new file mode 100644 index 0000000..50eb603 --- /dev/null +++ b/backend/rag_server.py @@ -0,0 +1,267 @@ +""" +RAG Server for distributed file transfer system. +Provides API endpoints for file upload, search, and retrieval using vector embeddings. +""" + +from fastapi import FastAPI, UploadFile, File, HTTPException, Query +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from typing import List, Optional +import os +import tempfile +from pathlib import Path +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +from embedding import generate_embedding +from vector_db import VectorDatabase, init_vector_db +from database import create_db_and_tables, get_session, engine +from models import FileRecord +from sqlmodel import Session, select +from datetime import datetime + +# Initialize FastAPI app +app = FastAPI(title="RAG File Server", version="1.0.0") + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global instances +vector_db: Optional[VectorDatabase] = None +EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) + +# Request/Response models +class QueryRequest(BaseModel): + query: str + top_k: int = 5 + +class SearchResult(BaseModel): + file_id: int + filename: str + path: str + device: str + alias: str + size: int + distance: float + content_preview: Optional[str] = None + +class UploadResponse(BaseModel): + file_id: int + filename: str + message: str + + +@app.on_event("startup") +async def startup_event(): + """Initialize databases on startup.""" + global vector_db + + # Initialize file database + create_db_and_tables() + + # Initialize vector database + vector_db = init_vector_db(dimension=EMBEDDING_DIM) + + print("āœ“ RAG Server initialized successfully") + + +@app.get("/") +async def root(): + """Health check endpoint.""" + return { + "status": "running", + "service": "RAG File Server", + "version": "1.0.0" + } + + +@app.post("/upload", response_model=UploadResponse) +async def upload_file( + file: UploadFile = File(...), + device: str = Query(..., description="Device identifier"), + device_ip: str = Query(..., description="Device IP address"), + device_user: str = Query(..., description="Device SSH username"), + absolute_path: str = Query(..., description="Absolute path on device") +): + """ + Upload a file, generate embedding, and store in both databases. + """ + tmp_path = None + try: + # Save uploaded file temporarily + suffix = Path(file.filename or "file.txt").suffix + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + content = await file.read() + tmp.write(content) + tmp_path = tmp.name + + # Generate embedding from file content + embedding = generate_embedding(tmp_path, is_file=True) + + # Store file metadata in database + session = Session(engine) + try: + file_record = FileRecord( + file_name=file.filename or "unknown", + absolute_path=absolute_path, + device=device, + device_ip=device_ip, + device_user=device_user, + last_modified_time=datetime.utcnow(), + size=len(content), + file_type=suffix or ".txt" + ) + session.add(file_record) + session.commit() + session.refresh(file_record) + + # Store embedding in vector database + success = vector_db.insert(embedding, file_record.id) + + if not success: + raise HTTPException(status_code=500, detail="Failed to store embedding") + + return UploadResponse( + file_id=file_record.id, + filename=file.filename or "unknown", + message="File uploaded and indexed successfully" + ) + finally: + session.close() + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + finally: + # Clean up temp file + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + +@app.post("/search", response_model=List[SearchResult]) +async def search_files(request: QueryRequest): + """ + Search for similar files using semantic search. + """ + try: + # Generate embedding for query + query_embedding = generate_embedding(request.query, is_file=False) + + # Search vector database + results = vector_db.get_file(query_embedding, k=request.top_k) + + # Retrieve file metadata + search_results = [] + session = Session(engine) + try: + for file_id, distance in results: + file_obj = session.get(FileRecord, file_id) + if file_obj: + search_results.append(SearchResult( + file_id=file_obj.id or 0, + filename=file_obj.file_name, + path=file_obj.absolute_path, + device=file_obj.device, + alias=file_obj.device_user, + size=file_obj.size, + distance=distance, + content_preview=None + )) + return search_results + finally: + session.close() + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/files", response_model=List[dict]) +async def list_files(): + """ + List all files in the database. + """ + try: + session = Session(engine) + try: + files = session.exec(select(FileRecord)).all() + return [ + { + "id": f.id, + "file_name": f.file_name, + "absolute_path": f.absolute_path, + "device": f.device, + "device_ip": f.device_ip, + "device_user": f.device_user, + "last_modified_time": f.last_modified_time.isoformat(), + "created_time": f.created_time.isoformat(), + "size": f.size, + "file_type": f.file_type + } + for f in files + ] + finally: + session.close() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/files/{file_id}") +async def get_file(file_id: int): + """ + Get file metadata by ID. + """ + session = Session(engine) + try: + file_obj = session.get(FileRecord, file_id) + if not file_obj: + raise HTTPException(status_code=404, detail="File not found") + return { + "id": file_obj.id, + "file_name": file_obj.file_name, + "absolute_path": file_obj.absolute_path, + "device": file_obj.device, + "device_ip": file_obj.device_ip, + "device_user": file_obj.device_user, + "last_modified_time": file_obj.last_modified_time.isoformat(), + "created_time": file_obj.created_time.isoformat(), + "size": file_obj.size, + "file_type": file_obj.file_type + } + finally: + session.close() + + +@app.get("/stats") +async def get_stats(): + """ + Get statistics about the RAG system. + """ + try: + vector_stats = vector_db.get_stats() + session = Session(engine) + try: + files = session.exec(select(FileRecord)).all() + total_files = len(files) + finally: + session.close() + + return { + "total_files": total_files, + "total_vectors": vector_stats['total_vectors'], + "vector_dimension": vector_stats['dimension'], + "index_type": vector_stats['index_type'] + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/vector_db.py b/backend/vector_db.py new file mode 100644 index 0000000..5d74e04 --- /dev/null +++ b/backend/vector_db.py @@ -0,0 +1,305 @@ +""" +Vector database module for distributed file transfer system. +Handles vector embeddings storage and similarity search using FAISS. +Uses SQLite for metadata storage. +""" + +import faiss +import numpy as np +import sqlite3 +import pickle +import os +from typing import List, Tuple, Optional + +# Vector database setup +VECTOR_DB_PATH = 'vectors.db' +FAISS_INDEX_PATH = 'faiss.index' +EMBEDDING_DIMENSION = 768 # Default for Gemini embeddings + + +def init_vector_db(db_path: str = VECTOR_DB_PATH, dimension: int = EMBEDDING_DIMENSION): + """ + Initialize the vector database, creating the table if it doesn't exist. + + Args: + db_path: Path to the SQLite database file + dimension: Dimension of vector embeddings + + Returns: + VectorDatabase instance + """ + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Create table for vector embeddings + cursor.execute(''' + CREATE TABLE IF NOT EXISTS vector_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL UNIQUE, + embedding BLOB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # Create index on file_id for faster lookups + cursor.execute(''' + CREATE INDEX IF NOT EXISTS idx_file_id ON vector_embeddings(file_id) + ''') + + conn.commit() + conn.close() + print(f"Vector database initialized at {db_path}") + + return VectorDatabase(db_path=db_path, dimension=dimension) + + +class VectorDatabase: + """ + Vector database manager using FAISS for efficient similarity search + and SQLite for metadata storage. + """ + + def __init__(self, db_path: str = VECTOR_DB_PATH, dimension: int = EMBEDDING_DIMENSION): + """ + Initialize the vector database. + + Args: + db_path: Path to SQLite database + dimension: Dimension of the vector embeddings + """ + self.db_path = db_path + self.dimension = dimension + self.index = None + self.id_to_file_id = {} # Maps FAISS index position to file_id + self.file_id_to_index = {} # Maps file_id to FAISS index position + + # Load or create FAISS index + self._load_or_create_index() + + def _load_or_create_index(self): + """Load existing FAISS index or create a new one.""" + if os.path.exists(FAISS_INDEX_PATH): + try: + self.index = faiss.read_index(FAISS_INDEX_PATH) + self._load_mappings() + print(f"Loaded FAISS index from {FAISS_INDEX_PATH} with {self.index.ntotal} vectors") + except Exception as e: + print(f"Error loading FAISS index: {e}") + self._create_new_index() + else: + self._create_new_index() + + def _create_new_index(self): + """Create a new FAISS index.""" + # Using L2 distance + self.index = faiss.IndexFlatL2(self.dimension) + print(f"Created new FAISS index with dimension {self.dimension}") + + def _load_mappings(self): + """Load file_id mappings from database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute('SELECT id, file_id FROM vector_embeddings ORDER BY id') + rows = cursor.fetchall() + + for idx, (db_id, file_id) in enumerate(rows): + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + conn.close() + + def _save_index(self): + """Save FAISS index to disk.""" + if self.index is not None: + faiss.write_index(self.index, FAISS_INDEX_PATH) + + def insert(self, vector: np.ndarray, file_id: int) -> bool: + """ + Insert a vector embedding for a file. + + Args: + vector: Numpy array of shape (dimension,) + file_id: ID of the file in the main database + + Returns: + True if successful, False if file_id already exists + + Raises: + ValueError: If vector dimension doesn't match + """ + # Validate vector dimension + if vector.shape[0] != self.dimension: + raise ValueError(f"Vector dimension {vector.shape} doesn't match expected ({self.dimension},)") + + # Check if file_id already exists + if file_id in self.file_id_to_index: + print(f"Embedding for file_id {file_id} already exists.") + return False + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Serialize vector + vector_blob = pickle.dumps(vector) + + # Store in SQLite database + cursor.execute( + 'INSERT INTO vector_embeddings (file_id, embedding) VALUES (?, ?)', + (file_id, vector_blob) + ) + conn.commit() + + # Add to FAISS index + vector_reshaped = vector.reshape(1, -1).astype('float32') + self.index.add(vector_reshaped) + + # Update mappings + idx = self.index.ntotal - 1 + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + # Save index + self._save_index() + + print(f"Inserted embedding for file_id {file_id} at index {idx}") + return True + + except sqlite3.IntegrityError: + print(f"Embedding for file_id {file_id} already exists in database.") + return False + except Exception as e: + conn.rollback() + print(f"Error inserting embedding: {e}") + raise + finally: + conn.close() + + def get_file(self, vector: np.ndarray, k: int = 1) -> List[Tuple[int, float]]: + """ + Find the k most similar files for a given query vector. + + Args: + vector: Query vector as numpy array of shape (dimension,) + k: Number of nearest neighbors to return + + Returns: + List of tuples (file_id, distance) sorted by similarity + + Raises: + ValueError: If vector dimension doesn't match + """ + # Validate vector dimension + if vector.shape[0] != self.dimension: + raise ValueError(f"Vector dimension {vector.shape} doesn't match expected ({self.dimension},)") + + if self.index.ntotal == 0: + print("No vectors in the index.") + return [] + + # Ensure k doesn't exceed total vectors + k = min(k, self.index.ntotal) + + # Search FAISS index + vector_reshaped = vector.reshape(1, -1).astype('float32') + distances, indices = self.index.search(vector_reshaped, k) + + # Map indices to file_ids + results = [] + for idx, distance in zip(indices[0], distances[0]): + if idx in self.id_to_file_id: + file_id = self.id_to_file_id[idx] + results.append((file_id, float(distance))) + + return results + + def delete_embedding(self, file_id: int) -> bool: + """ + Delete an embedding by file_id. + Note: FAISS doesn't support deletion, so we rebuild the index. + + Args: + file_id: ID of the file to remove + + Returns: + True if successful, False if not found + """ + if file_id not in self.file_id_to_index: + print(f"No embedding found for file_id {file_id}") + return False + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Remove from database + cursor.execute('DELETE FROM vector_embeddings WHERE file_id = ?', (file_id,)) + conn.commit() + + # Rebuild FAISS index + self._rebuild_index() + + print(f"Deleted embedding for file_id {file_id}") + return True + + except Exception as e: + conn.rollback() + print(f"Error deleting embedding: {e}") + raise + finally: + conn.close() + + def _rebuild_index(self): + """Rebuild FAISS index from database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + try: + # Create new index + self._create_new_index() + self.id_to_file_id.clear() + self.file_id_to_index.clear() + + # Load all embeddings + cursor.execute('SELECT id, file_id, embedding FROM vector_embeddings ORDER BY id') + rows = cursor.fetchall() + + if rows: + vectors = [] + for idx, (db_id, file_id, embedding_blob) in enumerate(rows): + vector = pickle.loads(embedding_blob) + vectors.append(vector) + self.id_to_file_id[idx] = file_id + self.file_id_to_index[file_id] = idx + + # Add all vectors to index + vectors_array = np.array(vectors).astype('float32') + self.index.add(vectors_array) + + # Save index + self._save_index() + + print(f"Rebuilt index with {len(rows)} vectors") + finally: + conn.close() + + def get_stats(self) -> dict: + """ + Get statistics about the vector database. + + Returns: + Dictionary with stats + """ + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) FROM vector_embeddings') + db_count = cursor.fetchone()[0] + conn.close() + + return { + 'total_vectors': self.index.ntotal if self.index else 0, + 'db_records': db_count, + 'dimension': self.dimension, + 'index_type': type(self.index).__name__ if self.index else None + } From 35294a9e6de7c0a0c4d2d5393b77e5829c498bca Mon Sep 17 00:00:00 2001 From: yuqiannemo Date: Sun, 19 Oct 2025 21:38:39 +0800 Subject: [PATCH 4/4] Refactor to unified server --- backend/main.py | 203 -------------- backend/rag_server.py | 267 ------------------ backend/test_api.py | 189 +++++++++++++ backend/unified_server.py | 572 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 761 insertions(+), 470 deletions(-) delete mode 100644 backend/main.py delete mode 100644 backend/rag_server.py create mode 100644 backend/test_api.py create mode 100644 backend/unified_server.py diff --git a/backend/main.py b/backend/main.py deleted file mode 100644 index aa63a42..0000000 --- a/backend/main.py +++ /dev/null @@ -1,203 +0,0 @@ -from fastapi import FastAPI, Depends, HTTPException, Query, Body -from sqlmodel import Session, select, or_ -from typing import List, Optional -from pathlib import Path -from datetime import datetime -from pydantic import BaseModel - -from models import FileRecord, FileSearchResponse -from database import create_db_and_tables, get_session - -app = FastAPI(title="File Sync API", version="1.0.0") - - -@app.on_event("startup") -def on_startup(): - """ - Initialize database on startup - """ - create_db_and_tables() - - -@app.get("/") -def read_root(): - """ - Health check endpoint - """ - return {"status": "File Sync API is running"} - - -@app.get("/files/search", response_model=List[FileSearchResponse]) -def search_files( - query: str = Query(..., description="Search query for file name (supports * wildcard)"), - session: Session = Depends(get_session) -): - """ - GET endpoint: Fuzzy search for files by name with wildcard support - - Returns ONLY metadata - no actual files are stored on the Pi. - Files will be retrieved via SCP from source devices when user confirms download. - - Parameters: - - query: The search term (supports * wildcard for fuzzy matching) - Examples: "*.txt", "report*", "*2024*", "document.pdf" - - Returns: - - List of matching file metadata records - """ - # Convert wildcard pattern to SQL LIKE pattern - # * becomes % in SQL LIKE syntax - like_pattern = query.replace("*", "%") - - # Fuzzy search on file_name (case-insensitive with LIKE) - statement = select(FileRecord).where(FileRecord.file_name.like(f"%{like_pattern}%")) - - results = session.exec(statement).all() - - if not results: - raise HTTPException(status_code=404, detail=f"No files found matching '{query}'") - - return results - - -@app.get("/files/{file_id}", response_model=FileSearchResponse) -def get_file_metadata(file_id: int, session: Session = Depends(get_session)): - """ - Get metadata for a specific file by its ID - - This returns the file location info so the CLI can initiate SCP transfer. - - Parameters: - - file_id: The database ID of the file - - Returns: - - File metadata including device IP and path for SCP retrieval - """ - statement = select(FileRecord).where(FileRecord.id == file_id) - file_record = session.exec(statement).first() - - if not file_record: - raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") - - return file_record - - -class FileMetadata(BaseModel): - """Request body for registering file metadata""" - file_name: str - absolute_path: str - device: str - device_ip: str - device_user: str - last_modified_time: datetime - size: int - file_type: str - - -@app.post("/files/register") -def register_file( - file_metadata: FileMetadata, - session: Session = Depends(get_session) -): - """ - POST endpoint: Register file metadata (NO FILE UPLOAD) - - The Pi does NOT receive the actual file. It only stores metadata about where - the file exists on the network. When a user wants to download, the CLI will - use SCP to retrieve it from the source device. - - If a file with the same absolute_path and device already exists, it will be - updated (only latest version is kept). - - Parameters: - - file_metadata: JSON body containing all file metadata - - Returns: - - Success message with file ID - """ - try: - # Check if file already exists (same absolute_path + device) - statement = select(FileRecord).where( - FileRecord.absolute_path == file_metadata.absolute_path, - FileRecord.device == file_metadata.device - ) - existing_file = session.exec(statement).first() - - if existing_file: - # Update existing record (keep only latest version) - existing_file.file_name = file_metadata.file_name - existing_file.device_ip = file_metadata.device_ip - existing_file.device_user = file_metadata.device_user - existing_file.last_modified_time = file_metadata.last_modified_time - existing_file.size = file_metadata.size - existing_file.file_type = file_metadata.file_type - existing_file.created_time = datetime.utcnow() # Update timestamp - - session.add(existing_file) - session.commit() - session.refresh(existing_file) - - return { - "message": "File metadata updated successfully", - "file_id": existing_file.id, - "file_name": existing_file.file_name, - "action": "updated" - } - else: - # Create new record - file_record = FileRecord( - file_name=file_metadata.file_name, - absolute_path=file_metadata.absolute_path, - device=file_metadata.device, - device_ip=file_metadata.device_ip, - device_user=file_metadata.device_user, - last_modified_time=file_metadata.last_modified_time, - size=file_metadata.size, - file_type=file_metadata.file_type - ) - - session.add(file_record) - session.commit() - session.refresh(file_record) - - return { - "message": "File metadata registered successfully", - "file_id": file_record.id, - "file_name": file_record.file_name, - "action": "created" - } - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error registering file metadata: {str(e)}") - - -@app.delete("/files/{file_id}") -def delete_file_metadata(file_id: int, session: Session = Depends(get_session)): - """ - DELETE endpoint: Delete file metadata record - - This only removes the metadata from the Pi's database. - The actual file remains on the source device. - - Parameters: - - file_id: The database ID of the file metadata to delete - - Returns: - - Success message - """ - statement = select(FileRecord).where(FileRecord.id == file_id) - file_record = session.exec(statement).first() - - if not file_record: - raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") - - # Delete only the database record (no actual file to delete on Pi) - session.delete(file_record) - session.commit() - - return { - "message": "File metadata deleted successfully", - "file_id": file_id, - "file_name": file_record.file_name, - "device": file_record.device - } diff --git a/backend/rag_server.py b/backend/rag_server.py deleted file mode 100644 index 50eb603..0000000 --- a/backend/rag_server.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -RAG Server for distributed file transfer system. -Provides API endpoints for file upload, search, and retrieval using vector embeddings. -""" - -from fastapi import FastAPI, UploadFile, File, HTTPException, Query -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from typing import List, Optional -import os -import tempfile -from pathlib import Path -from dotenv import load_dotenv - -# Load environment variables from .env file -load_dotenv() - -from embedding import generate_embedding -from vector_db import VectorDatabase, init_vector_db -from database import create_db_and_tables, get_session, engine -from models import FileRecord -from sqlmodel import Session, select -from datetime import datetime - -# Initialize FastAPI app -app = FastAPI(title="RAG File Server", version="1.0.0") - -# CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Global instances -vector_db: Optional[VectorDatabase] = None -EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) - -# Request/Response models -class QueryRequest(BaseModel): - query: str - top_k: int = 5 - -class SearchResult(BaseModel): - file_id: int - filename: str - path: str - device: str - alias: str - size: int - distance: float - content_preview: Optional[str] = None - -class UploadResponse(BaseModel): - file_id: int - filename: str - message: str - - -@app.on_event("startup") -async def startup_event(): - """Initialize databases on startup.""" - global vector_db - - # Initialize file database - create_db_and_tables() - - # Initialize vector database - vector_db = init_vector_db(dimension=EMBEDDING_DIM) - - print("āœ“ RAG Server initialized successfully") - - -@app.get("/") -async def root(): - """Health check endpoint.""" - return { - "status": "running", - "service": "RAG File Server", - "version": "1.0.0" - } - - -@app.post("/upload", response_model=UploadResponse) -async def upload_file( - file: UploadFile = File(...), - device: str = Query(..., description="Device identifier"), - device_ip: str = Query(..., description="Device IP address"), - device_user: str = Query(..., description="Device SSH username"), - absolute_path: str = Query(..., description="Absolute path on device") -): - """ - Upload a file, generate embedding, and store in both databases. - """ - tmp_path = None - try: - # Save uploaded file temporarily - suffix = Path(file.filename or "file.txt").suffix - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name - - # Generate embedding from file content - embedding = generate_embedding(tmp_path, is_file=True) - - # Store file metadata in database - session = Session(engine) - try: - file_record = FileRecord( - file_name=file.filename or "unknown", - absolute_path=absolute_path, - device=device, - device_ip=device_ip, - device_user=device_user, - last_modified_time=datetime.utcnow(), - size=len(content), - file_type=suffix or ".txt" - ) - session.add(file_record) - session.commit() - session.refresh(file_record) - - # Store embedding in vector database - success = vector_db.insert(embedding, file_record.id) - - if not success: - raise HTTPException(status_code=500, detail="Failed to store embedding") - - return UploadResponse( - file_id=file_record.id, - filename=file.filename or "unknown", - message="File uploaded and indexed successfully" - ) - finally: - session.close() - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - finally: - # Clean up temp file - if tmp_path and os.path.exists(tmp_path): - os.unlink(tmp_path) - - -@app.post("/search", response_model=List[SearchResult]) -async def search_files(request: QueryRequest): - """ - Search for similar files using semantic search. - """ - try: - # Generate embedding for query - query_embedding = generate_embedding(request.query, is_file=False) - - # Search vector database - results = vector_db.get_file(query_embedding, k=request.top_k) - - # Retrieve file metadata - search_results = [] - session = Session(engine) - try: - for file_id, distance in results: - file_obj = session.get(FileRecord, file_id) - if file_obj: - search_results.append(SearchResult( - file_id=file_obj.id or 0, - filename=file_obj.file_name, - path=file_obj.absolute_path, - device=file_obj.device, - alias=file_obj.device_user, - size=file_obj.size, - distance=distance, - content_preview=None - )) - return search_results - finally: - session.close() - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/files", response_model=List[dict]) -async def list_files(): - """ - List all files in the database. - """ - try: - session = Session(engine) - try: - files = session.exec(select(FileRecord)).all() - return [ - { - "id": f.id, - "file_name": f.file_name, - "absolute_path": f.absolute_path, - "device": f.device, - "device_ip": f.device_ip, - "device_user": f.device_user, - "last_modified_time": f.last_modified_time.isoformat(), - "created_time": f.created_time.isoformat(), - "size": f.size, - "file_type": f.file_type - } - for f in files - ] - finally: - session.close() - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/files/{file_id}") -async def get_file(file_id: int): - """ - Get file metadata by ID. - """ - session = Session(engine) - try: - file_obj = session.get(FileRecord, file_id) - if not file_obj: - raise HTTPException(status_code=404, detail="File not found") - return { - "id": file_obj.id, - "file_name": file_obj.file_name, - "absolute_path": file_obj.absolute_path, - "device": file_obj.device, - "device_ip": file_obj.device_ip, - "device_user": file_obj.device_user, - "last_modified_time": file_obj.last_modified_time.isoformat(), - "created_time": file_obj.created_time.isoformat(), - "size": file_obj.size, - "file_type": file_obj.file_type - } - finally: - session.close() - - -@app.get("/stats") -async def get_stats(): - """ - Get statistics about the RAG system. - """ - try: - vector_stats = vector_db.get_stats() - session = Session(engine) - try: - files = session.exec(select(FileRecord)).all() - total_files = len(files) - finally: - session.close() - - return { - "total_files": total_files, - "total_vectors": vector_stats['total_vectors'], - "vector_dimension": vector_stats['dimension'], - "index_type": vector_stats['index_type'] - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/test_api.py b/backend/test_api.py new file mode 100644 index 0000000..e83f6ed --- /dev/null +++ b/backend/test_api.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Test script for the Unified Server with Colinear Query Expansion +""" + +import requests +import json +from pathlib import Path + +BASE_URL = "http://localhost:8000" + +def print_section(title): + print("\n" + "="*70) + print(f" {title}") + print("="*70 + "\n") + +def test_health_check(): + print_section("1. Health Check") + response = requests.get(f"{BASE_URL}/") + print(json.dumps(response.json(), indent=2)) + +def test_upload_file(): + print_section("2. Upload Test File with Embedding") + + # Create a test file + test_content = """ + Machine Learning Tutorial + + This document covers the basics of machine learning, including: + - Supervised learning algorithms + - Neural networks and deep learning + - Data preprocessing techniques + - Model evaluation metrics + """ + + test_file = Path("/tmp/ml_tutorial.txt") + test_file.write_text(test_content) + + # Upload the file + with open(test_file, 'rb') as f: + files = {'file': ('ml_tutorial.txt', f, 'text/plain')} + params = { + 'device': 'laptop', + 'device_ip': '192.168.1.100', + 'device_user': 'testuser', + 'absolute_path': '/home/testuser/docs/ml_tutorial.txt' + } + response = requests.post(f"{BASE_URL}/files/upload", files=files, params=params) + + print(f"Status: {response.status_code}") + print(json.dumps(response.json(), indent=2)) + + # Upload another file + test_content2 = """ + Python Programming Guide + + Learn Python programming from scratch: + - Variables and data types + - Functions and classes + - File I/O operations + - Error handling + """ + + test_file2 = Path("/tmp/python_guide.txt") + test_file2.write_text(test_content2) + + with open(test_file2, 'rb') as f: + files = {'file': ('python_guide.txt', f, 'text/plain')} + params = { + 'device': 'desktop', + 'device_ip': '192.168.1.101', + 'device_user': 'testuser', + 'absolute_path': '/home/testuser/docs/python_guide.txt' + } + response = requests.post(f"{BASE_URL}/files/upload", files=files, params=params) + + print(f"\nSecond file - Status: {response.status_code}") + print(json.dumps(response.json(), indent=2)) + +def test_semantic_search_without_expansion(): + print_section("3. Semantic Search WITHOUT Query Expansion") + + query_data = { + "query": "learning tutorial", + "top_k": 5, + "use_query_expansion": False + } + + response = requests.post(f"{BASE_URL}/search/semantic", json=query_data) + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} results:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['filename']}") + print(f" Similarity: {result['similarity_score']:.4f}") + print(f" Matched via: {result['matched_via']}") + print(f" Device: {result['device']} ({result['device_ip']})") + print() + else: + print(response.text) + +def test_semantic_search_with_expansion(): + print_section("4. Semantic Search WITH Query Expansion (Colinear)") + + query_data = { + "query": "learning tutorial", + "top_k": 5, + "use_query_expansion": True, + "expansion_count": 3 + } + + print("Query: 'learning tutorial'") + print("Expansion enabled: True") + print("Expansion count: 3") + print("\nExpected query variants:") + print(" 0. learning tutorial (original)") + print(" 1. document about learning tutorial") + print(" 2. file containing learning tutorial") + print(" 3. information regarding learning tutorial") + print("\n" + "-"*70 + "\n") + + response = requests.post(f"{BASE_URL}/search/semantic", json=query_data) + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} results:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['filename']}") + print(f" Similarity: {result['similarity_score']:.4f}") + print(f" Matched via: {result['matched_via']} ⭐") + print(f" Device: {result['device']} ({result['device_ip']})") + print(f" Path: {result['path']}") + print() + else: + print(response.text) + +def test_keyword_search(): + print_section("5. Traditional Keyword Search") + + response = requests.get(f"{BASE_URL}/search/keyword?query=*.txt") + print(f"Status: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"Found {len(results)} .txt files:\n") + for i, result in enumerate(results, 1): + print(f"{i}. {result['file_name']}") + print(f" Device: {result['device']}") + print() + else: + print(response.text) + +def test_stats(): + print_section("6. System Statistics") + + response = requests.get(f"{BASE_URL}/stats") + print(json.dumps(response.json(), indent=2)) + +def main(): + print("\n" + "šŸš€ " * 25) + print(" Testing Unified Server with Colinear Query Expansion") + print("šŸš€ " * 25) + + try: + test_health_check() + test_upload_file() + test_semantic_search_without_expansion() + test_semantic_search_with_expansion() + test_keyword_search() + test_stats() + + print_section("āœ… All Tests Complete!") + print("Key Observations:") + print("- Semantic search WITH expansion may return more/better results") + print("- 'matched_via' shows which query variant found the match") + print("- Similarity scores range from 0-1 (higher = better)") + print("- Query expansion helps find relevant files with different wording") + + except requests.exceptions.ConnectionError: + print("\nāŒ Error: Could not connect to server at", BASE_URL) + print("Make sure the server is running: python unified_server.py") + except Exception as e: + print(f"\nāŒ Error: {e}") + +if __name__ == "__main__": + main() diff --git a/backend/unified_server.py b/backend/unified_server.py new file mode 100644 index 0000000..cb11f17 --- /dev/null +++ b/backend/unified_server.py @@ -0,0 +1,572 @@ +""" +Unified File Sync & RAG Server +Combines file metadata management with semantic search using vector embeddings. +Implements Colinear Query Expansion for improved search results. +""" + +from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Depends, Body +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from sqlmodel import Session, select, or_ +from typing import List, Optional +from contextlib import asynccontextmanager +import os +import tempfile +from pathlib import Path +from datetime import datetime +import numpy as np +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +from embedding import generate_embedding +from vector_db import VectorDatabase, init_vector_db +from database import create_db_and_tables, get_session, engine +from models import FileRecord, FileSearchResponse + +# Global instances +vector_db: Optional[VectorDatabase] = None +EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) + + +# ==================== Lifespan Context Manager ==================== + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager for startup and shutdown events. + Replaces deprecated @app.on_event decorators. + """ + global vector_db + + # Startup + print("šŸš€ Starting Unified File Sync & RAG Server...") + create_db_and_tables() + vector_db = init_vector_db(dimension=EMBEDDING_DIM) + print("āœ“ Unified File Sync & RAG Server initialized successfully") + + yield # Server runs here + + # Shutdown + print("šŸ›‘ Shutting down server...") + + +# Initialize FastAPI app with lifespan +app = FastAPI( + title="Unified File Sync & RAG Server", + version="2.0.0", + description="File metadata management with semantic search and query expansion", + lifespan=lifespan +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global instances +vector_db: Optional[VectorDatabase] = None +EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 dimension (local model, no API key needed!) + +# ==================== Request/Response Models ==================== + +class FileMetadata(BaseModel): + """Request body for registering file metadata""" + file_name: str + absolute_path: str + device: str + device_ip: str + device_user: str + last_modified_time: datetime + size: int + file_type: str + + +class QueryRequest(BaseModel): + """Semantic search query with query expansion support""" + query: str + top_k: int = 5 + use_query_expansion: bool = True # Enable Colinear Query Expansion by default + expansion_count: int = 3 # Number of expanded queries to generate + + +class SearchResult(BaseModel): + """Enhanced search result with similarity score""" + file_id: int + filename: str + path: str + device: str + device_ip: str + device_user: str + size: int + file_type: str + similarity_score: float # 0-1, higher is better (converted from distance) + last_modified_time: datetime + matched_via: str # "original_query" or "expanded_query_N" + + +class UploadResponse(BaseModel): + file_id: int + filename: str + message: str + action: str # "created" or "updated" + + +# ==================== CORS Middleware ==================== + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ==================== Health Check ==================== + +@app.get("/") +async def root(): + """Health check endpoint.""" + return { + "status": "running", + "service": "Unified File Sync & RAG Server", + "version": "2.0.0", + "features": ["metadata_search", "semantic_search", "query_expansion"] + } + + +# ==================== File Registration (Metadata Only) ==================== + +@app.post("/files/register", response_model=UploadResponse) +def register_file( + file_metadata: FileMetadata, + session: Session = Depends(get_session) +): + """ + Register file metadata WITHOUT uploading the actual file. + + The Pi does NOT receive the actual file. It only stores metadata about where + the file exists on the network. When a user wants to download, the CLI will + use SCP to retrieve it from the source device. + + If a file with the same absolute_path and device already exists, it will be updated. + """ + try: + # Check if file already exists (same absolute_path + device) + statement = select(FileRecord).where( + FileRecord.absolute_path == file_metadata.absolute_path, + FileRecord.device == file_metadata.device + ) + existing_file = session.exec(statement).first() + + if existing_file: + # Update existing record (keep only latest version) + existing_file.file_name = file_metadata.file_name + existing_file.device_ip = file_metadata.device_ip + existing_file.device_user = file_metadata.device_user + existing_file.last_modified_time = file_metadata.last_modified_time + existing_file.size = file_metadata.size + existing_file.file_type = file_metadata.file_type + existing_file.created_time = datetime.utcnow() + + session.add(existing_file) + session.commit() + session.refresh(existing_file) + + return UploadResponse( + file_id=existing_file.id, + filename=existing_file.file_name, + message="File metadata updated successfully", + action="updated" + ) + else: + # Create new record + file_record = FileRecord( + file_name=file_metadata.file_name, + absolute_path=file_metadata.absolute_path, + device=file_metadata.device, + device_ip=file_metadata.device_ip, + device_user=file_metadata.device_user, + last_modified_time=file_metadata.last_modified_time, + size=file_metadata.size, + file_type=file_metadata.file_type + ) + + session.add(file_record) + session.commit() + session.refresh(file_record) + + return UploadResponse( + file_id=file_record.id, + filename=file_record.file_name, + message="File metadata registered successfully", + action="created" + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error registering file metadata: {str(e)}") + + +# ==================== File Upload with Embedding (For RAG) ==================== + +@app.post("/files/upload", response_model=UploadResponse) +async def upload_file_with_embedding( + file: UploadFile = File(...), + device: str = Query(..., description="Device identifier"), + device_ip: str = Query(..., description="Device IP address"), + device_user: str = Query(..., description="Device SSH username"), + absolute_path: str = Query(..., description="Absolute path on device") +): + """ + Upload a file, generate embedding, and store in both databases. + + Use this endpoint when you want to enable semantic search on the file content. + The file content is read to generate embeddings, then discarded. + """ + tmp_path = None + try: + # Save uploaded file temporarily + suffix = Path(file.filename or "file.txt").suffix + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + content = await file.read() + tmp.write(content) + tmp_path = tmp.name + + # Generate embedding from file content + embedding = generate_embedding(tmp_path, is_file=True) + + # Store file metadata in database + session = Session(engine) + try: + # Check if file already exists + statement = select(FileRecord).where( + FileRecord.absolute_path == absolute_path, + FileRecord.device == device + ) + existing_file = session.exec(statement).first() + + if existing_file: + # Update existing record + existing_file.file_name = file.filename or "unknown" + existing_file.device_ip = device_ip + existing_file.device_user = device_user + existing_file.last_modified_time = datetime.utcnow() + existing_file.size = len(content) + existing_file.file_type = suffix or ".txt" + existing_file.created_time = datetime.utcnow() + + session.add(existing_file) + session.commit() + session.refresh(existing_file) + + # Update embedding in vector database + vector_db.delete_embedding(existing_file.id) + vector_db.insert(embedding, existing_file.id) + + return UploadResponse( + file_id=existing_file.id, + filename=file.filename or "unknown", + message="File uploaded and indexed successfully", + action="updated" + ) + else: + # Create new record + file_record = FileRecord( + file_name=file.filename or "unknown", + absolute_path=absolute_path, + device=device, + device_ip=device_ip, + device_user=device_user, + last_modified_time=datetime.utcnow(), + size=len(content), + file_type=suffix or ".txt" + ) + session.add(file_record) + session.commit() + session.refresh(file_record) + + # Store embedding in vector database + success = vector_db.insert(embedding, file_record.id) + + if not success: + raise HTTPException(status_code=500, detail="Failed to store embedding") + + return UploadResponse( + file_id=file_record.id, + filename=file.filename or "unknown", + message="File uploaded and indexed successfully", + action="created" + ) + finally: + session.close() + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + finally: + # Clean up temp file + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + +# ==================== Colinear Query Expansion ==================== + +def colinear_query_expansion( + original_query: str, + expansion_count: int = 3 +) -> List[str]: + """ + Generate expanded queries using colinear query expansion technique. + + This creates related queries by: + 1. Adding synonyms and related terms + 2. Reformulating the query in different ways + 3. Adding context-specific expansions + + Args: + original_query: The original search query + expansion_count: Number of expanded queries to generate + + Returns: + List of expanded query strings (including original) + """ + expanded_queries = [original_query] + + # Simple expansion strategies (can be enhanced with LLM in future) + # Strategy 1: Add common document-related terms + if expansion_count >= 1: + expanded_queries.append(f"document about {original_query}") + + # Strategy 2: Add file content context + if expansion_count >= 2: + expanded_queries.append(f"file containing {original_query}") + + # Strategy 3: Add informational context + if expansion_count >= 3: + expanded_queries.append(f"information regarding {original_query}") + + # Strategy 4: Add data context + if expansion_count >= 4: + expanded_queries.append(f"data related to {original_query}") + + return expanded_queries[:expansion_count + 1] + + +# ==================== Semantic Search with Query Expansion ==================== + +@app.post("/search/semantic", response_model=List[SearchResult]) +async def semantic_search(request: QueryRequest): + """ + Semantic search using vector embeddings with Colinear Query Expansion. + + This endpoint: + 1. Takes your search query + 2. Optionally expands it into multiple related queries (Colinear Query Expansion) + 3. Generates embeddings for each query variant + 4. Searches the vector database + 5. Combines and ranks results by relevance + + Parameters: + - query: Search query string + - top_k: Maximum number of results to return + - use_query_expansion: Enable/disable query expansion (default: True) + - expansion_count: Number of query variants to generate (default: 3) + """ + try: + all_results = {} # file_id -> (distance, matched_via) + + # Generate query variants + if request.use_query_expansion: + queries = colinear_query_expansion(request.query, request.expansion_count) + print(f"šŸ” Query Expansion: {len(queries)} variants generated") + for i, q in enumerate(queries): + print(f" {i}: {q}") + else: + queries = [request.query] + + # Search with each query variant + for idx, query_variant in enumerate(queries): + # Generate embedding for this query variant + query_embedding = generate_embedding(query_variant, is_file=False) + + # Search vector database + results = vector_db.get_file(query_embedding, k=request.top_k) + + # Track which query matched which file + matched_via = "original_query" if idx == 0 else f"expanded_query_{idx}" + + # Combine results (keep best distance for each file) + for file_id, distance in results: + if file_id not in all_results or distance < all_results[file_id][0]: + all_results[file_id] = (distance, matched_via) + + # Get top_k unique results sorted by distance + sorted_results = sorted(all_results.items(), key=lambda x: x[1][0])[:request.top_k] + + # Retrieve file metadata and build response + search_results = [] + session = Session(engine) + try: + for file_id, (distance, matched_via) in sorted_results: + file_obj = session.get(FileRecord, file_id) + if file_obj: + # Convert L2 distance to similarity score (0-1, higher is better) + # Using negative exponential: similarity = e^(-distance) + similarity_score = float(np.exp(-distance)) + + search_results.append(SearchResult( + file_id=file_obj.id or 0, + filename=file_obj.file_name, + path=file_obj.absolute_path, + device=file_obj.device, + device_ip=file_obj.device_ip, + device_user=file_obj.device_user, + size=file_obj.size, + file_type=file_obj.file_type, + similarity_score=similarity_score, + last_modified_time=file_obj.last_modified_time, + matched_via=matched_via + )) + + print(f"āœ“ Found {len(search_results)} results") + return search_results + finally: + session.close() + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ==================== Keyword/Wildcard Search (Original from main.py) ==================== + +@app.get("/search/keyword", response_model=List[FileSearchResponse]) +def keyword_search( + query: str = Query(..., description="Search query for file name (supports * wildcard)"), + session: Session = Depends(get_session) +): + """ + Traditional keyword search for files by name with wildcard support. + + This is faster than semantic search but only matches filenames, not content. + + Parameters: + - query: Search term (supports * wildcard for fuzzy matching) + Examples: "*.txt", "report*", "*2024*", "document.pdf" + + Returns: + - List of matching file metadata records + """ + # Convert wildcard pattern to SQL LIKE pattern + like_pattern = query.replace("*", "%") + + # Fuzzy search on file_name (case-insensitive) + statement = select(FileRecord).where(FileRecord.file_name.like(f"%{like_pattern}%")) + + results = session.exec(statement).all() + + if not results: + raise HTTPException(status_code=404, detail=f"No files found matching '{query}'") + + return results + + +# ==================== File Management ==================== + +@app.get("/files", response_model=List[dict]) +async def list_files(session: Session = Depends(get_session)): + """List all files in the database.""" + try: + files = session.exec(select(FileRecord)).all() + return [ + { + "id": f.id, + "file_name": f.file_name, + "absolute_path": f.absolute_path, + "device": f.device, + "device_ip": f.device_ip, + "device_user": f.device_user, + "last_modified_time": f.last_modified_time.isoformat(), + "created_time": f.created_time.isoformat(), + "size": f.size, + "file_type": f.file_type + } + for f in files + ] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/files/{file_id}", response_model=FileSearchResponse) +async def get_file_metadata(file_id: int, session: Session = Depends(get_session)): + """ + Get metadata for a specific file by its ID. + + Returns file location info so the CLI can initiate SCP transfer. + """ + file_obj = session.get(FileRecord, file_id) + + if not file_obj: + raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") + + return file_obj + + +@app.delete("/files/{file_id}") +def delete_file_metadata(file_id: int, session: Session = Depends(get_session)): + """ + Delete file metadata record and its embedding. + + This only removes the metadata from the Pi's database. + The actual file remains on the source device. + """ + file_record = session.get(FileRecord, file_id) + + if not file_record: + raise HTTPException(status_code=404, detail=f"File with ID {file_id} not found") + + # Delete embedding from vector database + vector_db.delete_embedding(file_id) + + # Delete database record + session.delete(file_record) + session.commit() + + return { + "message": "File metadata and embedding deleted successfully", + "file_id": file_id, + "file_name": file_record.file_name, + "device": file_record.device + } + + +# ==================== Statistics ==================== + +@app.get("/stats") +async def get_stats(session: Session = Depends(get_session)): + """Get statistics about the system.""" + try: + vector_stats = vector_db.get_stats() + files = session.exec(select(FileRecord)).all() + total_files = len(files) + total_size = sum(f.size for f in files) + + return { + "total_files": total_files, + "total_vectors": vector_stats['total_vectors'], + "total_size_bytes": total_size, + "vector_dimension": vector_stats['dimension'], + "index_type": vector_stats['index_type'], + "embedding_model": "all-MiniLM-L6-v2 (local, no API key)" + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000)