From 079087dd3cb3032f745368aaebc54b23be475654 Mon Sep 17 00:00:00 2001 From: Johan Berggren Date: Wed, 4 Mar 2026 13:57:17 +0100 Subject: [PATCH] Refactor DuckDB connection management and improve SQL generation reliability --- poetry.lock | 14 +- pyproject.toml | 1 + src/api/v1/files.py | 1 + src/datastores/sql/crud/authz.py | 45 +++--- src/lib/duckdb_utils.py | 244 ++++++++++++++++++------------- 5 files changed, 184 insertions(+), 121 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3c7f749..a16ab97 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2944,6 +2944,18 @@ files = [ {file = "python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58"}, ] +[[package]] +name = "pytz" +version = "2026.1.post1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a"}, + {file = "pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1"}, +] + [[package]] name = "pywin32" version = "311" @@ -3675,4 +3687,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.12" -content-hash = "96b08bcbcad6b41ba9f4010a90ecbff3c499089df6611c371dcc37730a881856" +content-hash = "ffbc3820146a9f22357ee32be0e4a1c07467ac25871c9e11036eaf18190b534d" diff --git a/pyproject.toml b/pyproject.toml index 19c96e7..bb117f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ openrelik-ai-common = "~0.5.0" openrelik-common = "^0.7.6" sse-starlette = "^2.3.6" duckdb = "^1.3.2" +pytz = "^2026.1.post1" [tool.poetry.group.dev.dependencies] pylint = "^3.1.0" diff --git a/src/api/v1/files.py b/src/api/v1/files.py index d484b3f..7916e13 100644 --- a/src/api/v1/files.py +++ b/src/api/v1/files.py @@ -574,6 +574,7 @@ def generate_query( llm_model=active_llm["config"]["model"], tables_schemas=json.dumps(tables_schemas), user_request=request.user_request, + file=file, ) return {"user_request": request.user_request, "generated_query": generated_query} diff --git a/src/datastores/sql/crud/authz.py b/src/datastores/sql/crud/authz.py index 331fd32..18c8593 100644 --- a/src/datastores/sql/crud/authz.py +++ b/src/datastores/sql/crud/authz.py @@ -21,9 +21,9 @@ from datastores.sql.models.file import File from datastores.sql.models.folder import Folder +from datastores.sql.models.group import GroupRole from datastores.sql.models.role import Role from datastores.sql.models.user import User, UserRole -from datastores.sql.models.group import GroupRole class AuthorizationError(Exception): @@ -113,9 +113,7 @@ def check_user_access( for group in user.groups: group_role = ( db.query(GroupRole) - .filter( - GroupRole.group_id == group.id, GroupRole.folder_id == folder.id - ) + .filter(GroupRole.group_id == group.id, GroupRole.folder_id == folder.id) .first() ) if group_role and group_role.role in allowed_roles: @@ -126,12 +124,11 @@ def check_user_access( return False # No access found -def require_access( - allowed_roles: list, http_exception: bool = True, error_message: str = None -): +def require_access(allowed_roles: list, http_exception: bool = True, error_message: str = None): def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): + + def _check_access(kwargs): + """Shared access check logic for both sync and async wrappers.""" db = kwargs.get("db") folder_id = kwargs.get("folder_id") file_id = kwargs.get("file_id") @@ -141,9 +138,7 @@ async def wrapper(*args, **kwargs): folder = db.get(Folder, folder_id) if not folder: raise HTTPException(status_code=404, detail="Folder not found.") - if not check_user_access( - db, current_user, allowed_roles, folder=folder - ): + if not check_user_access(db, current_user, allowed_roles, folder=folder): raise_authorization_error( http_exception, error_message or "No access to folder" ) @@ -152,18 +147,24 @@ async def wrapper(*args, **kwargs): file = db.get(File, file_id) if not file: raise HTTPException(status_code=404, detail="File not found.") - if not check_user_access( - db, current_user, allowed_roles, file=file - ): - raise_authorization_error( - http_exception, error_message or "No access to file" - ) - # Await only if func is async - if asyncio.iscoroutinefunction(func): + if not check_user_access(db, current_user, allowed_roles, file=file): + raise_authorization_error(http_exception, error_message or "No access to file") + + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + _check_access(kwargs) return await func(*args, **kwargs) - return func(*args, **kwargs) # Call directly if func is sync + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + _check_access(kwargs) + return func(*args, **kwargs) - return wrapper + return sync_wrapper return decorator diff --git a/src/lib/duckdb_utils.py b/src/lib/duckdb_utils.py index 10a80e8..d1cb076 100644 --- a/src/lib/duckdb_utils.py +++ b/src/lib/duckdb_utils.py @@ -14,11 +14,39 @@ """Utility functions for working with DuckDB databases.""" import os +from contextlib import contextmanager import duckdb from openrelik_ai_common.providers import manager +@contextmanager +def _duckdb_sqlite_connection(file_path: str): + """Context manager that yields a DuckDB connection with a SQLite file attached. + + Args: + file_path: Path to the SQLite database file. + + Yields: + duckdb.DuckDBPyConnection: A configured DuckDB connection with the SQLite file attached. + """ + SQLITE_EXTENSION_PATH = "/app/openrelik/sqlite_scanner.duckdb_extension" + db_conn = duckdb.connect() + try: + if os.path.exists(SQLITE_EXTENSION_PATH): + db_conn.execute(f"INSTALL '{SQLITE_EXTENSION_PATH}'; LOAD '{SQLITE_EXTENSION_PATH}';") + else: + db_conn.execute("INSTALL sqlite; LOAD sqlite;") + + db_conn.execute("SET enable_external_access = false;") + db_conn.execute(f"ATTACH '{file_path}' AS sqlite_db (TYPE SQLITE, READ_ONLY TRUE);") + db_conn.execute("USE sqlite_db;") + + yield db_conn + finally: + db_conn.close() + + def is_sql_file(magic_text: str) -> bool: """Check if file is a valid SQL database file based on magic text. @@ -50,60 +78,28 @@ def get_tables_schemas(file: object) -> dict: file: A SQL database file. Returns: - dict: A dictionary where keys are table names and values are dictionaries of column names and their + dict: A dictionary where keys are table names and values are dictionaries of column names and their types. """ if not is_sql_file(file.magic_text): return {} - - # Path to the local DuckDB sqlite scanner extension - duckdb_extension_path = "/app/openrelik/sqlite_scanner.duckdb_extension" - - # Try to install the DuckDB sqlite scanner locally, required to query sqlite files through DuckDB. - # If the extension is not found, the extension is downloaded automatically from the DuckDB repository. - if os.path.exists(duckdb_extension_path): - try: - db_conn = duckdb.connect() - db_extensions_query = "INSTALL '{0:s}'; LOAD '{0:s}';".format(duckdb_extension_path) - db_conn.execute(db_extensions_query) - except Exception as e: - raise RuntimeError(e) - finally: - db_conn.close() - - # Set read_only to True to prevent any modifications to the database. - db_conn = duckdb.connect(file.path, read_only=True) - - # Disable external access to prevent that the query accesses files outside the database. - db_conn.execute("SET enable_external_access = false;") - - # Query to get all tables and their schemas using DuckDB's information schema. - tables_schemas_query = """ - SELECT t.table_name, c.column_name, c.data_type - FROM information_schema.tables t - JOIN information_schema.columns c - ON t.table_schema = c.table_schema - AND t.table_name = c.table_name - WHERE t.table_type = 'BASE TABLE' - ORDER BY t.table_schema, t.table_name, c.ordinal_position; - """ - try: - results = db_conn.execute(tables_schemas_query).fetchall() - tables_schemas = {} - for table_name, column_name, data_type in results: - if table_name not in tables_schemas: - tables_schemas[table_name] = {} - tables_schemas[table_name][column_name] = data_type + with _duckdb_sqlite_connection(file.path) as db_conn: + tables = db_conn.execute( + "SELECT name FROM main.sqlite_master WHERE type = 'table' ORDER BY name;" + ).fetchall() + return { + table_name: { + col[1]: col[2] + for col in db_conn.execute(f"PRAGMA table_info('{table_name}');").fetchall() + } + for (table_name,) in tables + } except Exception as e: raise RuntimeError(e) - finally: - db_conn.close() - - return tables_schemas def run_query(file: object, sql_query: str) -> list[dict]: - """Run a SQL query on a DuckDB database file and return a list of dictionaries. + """Run a SQL query on a SQLite database file and return a list of dictionaries. Args: file: A SQL database file. @@ -114,84 +110,136 @@ def run_query(file: object, sql_query: str) -> list[dict]: """ if not is_sql_file(file.magic_text): raise RuntimeError("File is not a supported SQL format.") - - # Set read_only to True to prevent any modifications to the database. - db_conn = duckdb.connect(file.path, read_only=True) - - # Disable external access to prevent that the query accesses files outside the database. - db_conn.execute("SET enable_external_access = false;") - try: - # Execute the query and get both the data and column names - cursor = db_conn.execute(sql_query) - columns = [desc[0] for desc in cursor.description] - results = cursor.fetchall() - - # Convert bytes to string in results more efficiently - results = [ - tuple(str(item) if isinstance(item, bytes) else item for item in row) for row in results - ] - - # Manually create a list of dictionaries - list_of_dicts = [dict(zip(columns, row)) for row in results] - + with _duckdb_sqlite_connection(file.path) as db_conn: + cursor = db_conn.execute(sql_query) + columns = [desc[0] for desc in cursor.description] + results = [ + tuple(str(item) if isinstance(item, bytes) else item for item in row) + for row in cursor.fetchall() + ] + return [dict(zip(columns, row)) for row in results] except Exception as e: raise RuntimeError(e) - finally: - db_conn.close() - - return list_of_dicts def generate_sql_query( - llm_provider: str, llm_model: str, tables_schemas: str, user_request: str + llm_provider: str, + llm_model: str, + tables_schemas: str, + user_request: str, + file: object, + max_retries: int = 3, ) -> str: - """Generate a SQL query based on the database schema and user request. + """Generate a valid SQL query from a natural language request, retrying on failure. + + It uses the LLM to generate a query and then validates it by running it against the database. + If the query fails, it uses the error message to generate a new query, letting the LLM correct + its mistakes. Args: llm_provider (str): The name of the LLM provider to use. llm_model (str): The name of the LLM model to use. tables_schemas (str): The database schema information. - user_request (str): The user's request or question. + user_request (str): The natural language request or question. + file (object): The SQL database file used to validate the query. + max_retries (int): Maximum number of generation attempts. Defaults to 3. Returns: - str: The generated SQL query. + str: A validated SQL query that executed without errors. + + Raises: + RuntimeError: If a valid query cannot be generated after max_retries attempts. """ SYSTEM_INSTRUCTION = """ - Given an input question, create a syntactically correct DuckDB SQL query to - run to help find the answer. Unless the user specifies in his question a - specific number of examples they wish to obtain, always limit your query to - at most 10 results. You can order the results by a relevant column to - return the most interesting examples in the database. - - Never query for all the columns from a specific table, only ask for a the - few relevant columns given the question. - - Pay attention to use only the column names that you can see in the schema - description. Be careful to not query for columns that do not exist. Also, - pay attention to which column is in which table. - - Output format: - * **ALWAYS** use the LIMIT clause to limit the number of results to at most 10. - * Return the query in a single line as a string, without any additional text or formatting. - * Do not include any comments or explanations in the output. - * Do not use any aliases in the query. - * Do not return markdown or any other formatting. + You are an expert SQL query generator for DuckDB querying attached SQLite databases. + Your sole job is to output a single, valid DuckDB SQL query — nothing else. + + STRICT OUTPUT RULES: + - Output ONLY the raw SQL query, on a single line + - No markdown, no backticks, no code fences, no comments, no explanations + - **ALWAYS** use the LIMIT clause to limit the number of results to at most 10 + + QUERY RULES: + - Only use table and column names that appear exactly in the schema provided — do not invent or guess names + - Always qualify ambiguous column names with their table name (e.g. table_name.column_name) + - Use LIMIT 10 unless the user explicitly requests a different number of results + - Never use SELECT * — only select columns relevant to the question + - For text matching, use LOWER() on both sides and LIKE for case-insensitive search (e.g. LOWER(col) LIKE LOWER('%value%')) + - When joining tables, identify the foreign key relationships from the schema before writing the JOIN + - Always format timestamp outputs as ISO 8601 strings with UTC timezone using strftime. When converting integer timestamps, wrap the conversion in AT TIME ZONE 'UTC': + strftime( + CASE + WHEN column > 1000000000000000 THEN to_timestamp(column / 1000000.0) AT TIME ZONE 'UTC' + WHEN column > 1000000000000 THEN to_timestamp(column / 1000.0) AT TIME ZONE 'UTC' + ELSE to_timestamp(column) AT TIME ZONE 'UTC' + END, + '%Y-%m-%dT%H:%M:%SZ' + ) + - If a column stores dates as text, cast with CAST(column AS DATE) + - If the question cannot be answered from the schema, output exactly: SELECT 'insufficient schema' AS error """ PROMPT = """ - ### Tables and Schemas: + ### Database Schema + The following tables and columns are available. Use ONLY these — do not reference any others. + {tables_schemas} - ### User Input: + ### Task + Write a DuckDB SQL query that answers the following: {user_request} + + Remember: output only the raw SQL query, nothing else. + """ + + RETRY_PROMPT = """ + ### Database Schema + {tables_schemas} + + ### User Input + {user_request} + + ### Previous Attempt + Query: {previous_query} + Error: {error} + + The query above failed with the error shown. Fix the query to resolve this error. + Remember: output only the raw SQL query, nothing else. """ provider = manager.LLMManager().get_provider(llm_provider) llm = provider(model_name=llm_model, system_instructions=SYSTEM_INSTRUCTION) - query = llm.generate( - prompt=PROMPT.format(tables_schemas=tables_schemas, user_request=user_request) - ) - return query + query = None + last_error = None + + for attempt in range(1, max_retries + 1): + if attempt == 1: + query = llm.generate( + prompt=PROMPT.format( + tables_schemas=tables_schemas, + user_request=user_request, + ) + ).strip() + else: + query = llm.generate( + prompt=RETRY_PROMPT.format( + tables_schemas=tables_schemas, + user_request=user_request, + previous_query=query, + error=last_error, + ) + ).strip() + + try: + run_query(file, query) + return query + except RuntimeError as e: + last_error = str(e) + if attempt == max_retries: + raise RuntimeError( + f"Query generation failed after {max_retries} attempts. " + f"Last error: {last_error}" + )