diff --git a/sqlit/domains/connections/providers/teradata/adapter.py b/sqlit/domains/connections/providers/teradata/adapter.py index 29fe886..abb03fc 100644 --- a/sqlit/domains/connections/providers/teradata/adapter.py +++ b/sqlit/domains/connections/providers/teradata/adapter.py @@ -8,7 +8,6 @@ ColumnInfo, CursorBasedAdapter, IndexInfo, - SequenceInfo, TableInfo, TriggerInfo, ) @@ -49,10 +48,6 @@ def supports_cross_database_queries(self) -> bool: def supports_stored_procedures(self) -> bool: return True - @property - def supports_sequences(self) -> bool: - return True - def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig: """Apply a default database for unqualified queries.""" if not database: @@ -91,8 +86,9 @@ def connect(self, config: ConnectionConfig) -> Any: def get_databases(self, conn: Any) -> list[str]: cursor = conn.cursor() cursor.execute( + "lock row for access " "SELECT DatabaseName FROM DBC.DatabasesV " - "WHERE DatabaseKind IN ('D', 'U') " + "WHERE dbkind IN ('D', 'U') " "ORDER BY DatabaseName" ) return [row[0] for row in cursor.fetchall()] @@ -101,6 +97,7 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'T' AND DatabaseName = ? " "ORDER BY TableName", @@ -108,6 +105,7 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) else: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'T' " "ORDER BY DatabaseName, TableName" @@ -118,6 +116,7 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'V' AND DatabaseName = ? " "ORDER BY TableName", @@ -125,6 +124,7 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) else: cursor.execute( + "lock row for access " "SELECT DatabaseName, TableName FROM DBC.TablesV " "WHERE TableKind = 'V' " "ORDER BY DatabaseName, TableName" @@ -142,14 +142,13 @@ def get_columns( pk_columns: set[str] = set() try: cursor.execute( - "SELECT ic.ColumnName " - "FROM DBC.IndexConstraintsV c " - "JOIN DBC.IndexColumnsV ic " - " ON c.DatabaseName = ic.DatabaseName " - " AND c.TableName = ic.TableName " - " AND c.IndexNumber = ic.IndexNumber " - "WHERE c.ConstraintType = 'P' " - "AND c.DatabaseName = ? AND c.TableName = ?", + "lock row for access " + "select " + "COLUMNNAME " + "from DBC.INDICESV " + "where DATABASENAME = ? " + "and TABLENAME = ? " + "and INDEXTYPE = 'P' ", (schema_name, table), ) pk_columns = {row[0] for row in cursor.fetchall()} @@ -157,6 +156,7 @@ def get_columns( pk_columns = set() cursor.execute( + "lock row for access " "SELECT ColumnName, ColumnType FROM DBC.ColumnsV " "WHERE DatabaseName = ? AND TableName = ? " "ORDER BY ColumnId", @@ -171,6 +171,7 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT TableName FROM DBC.TablesV " "WHERE TableKind = 'P' AND DatabaseName = ? " "ORDER BY TableName", @@ -178,6 +179,7 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) else: cursor.execute( + "lock row for access " "SELECT TableName FROM DBC.TablesV " "WHERE TableKind = 'P' " "ORDER BY TableName" @@ -188,6 +190,7 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV " "WHERE DatabaseName = ? " "ORDER BY TableName, IndexName", @@ -195,6 +198,7 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] ) else: cursor.execute( + "lock row for access " "SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV " "ORDER BY DatabaseName, TableName, IndexName" ) @@ -207,6 +211,7 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn cursor = conn.cursor() if database: cursor.execute( + "lock row for access " "SELECT TriggerName, TableName FROM DBC.TriggersV " "WHERE DatabaseName = ? " "ORDER BY TableName, TriggerName", @@ -214,26 +219,18 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn ) else: cursor.execute( + "lock row for access " "SELECT TriggerName, TableName FROM DBC.TriggersV " "ORDER BY DatabaseName, TableName, TriggerName" ) return [TriggerInfo(name=row[0], table_name=row[1]) for row in cursor.fetchall()] - def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT SequenceName FROM DBC.SequencesV " - "WHERE DatabaseName = ? " - "ORDER BY SequenceName", - (database,), - ) - else: - cursor.execute( - "SELECT SequenceName FROM DBC.SequencesV " - "ORDER BY DatabaseName, SequenceName" - ) - return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] + def get_sequences(self, conn: Any, database: str | None = None) -> list[str]: + """Teradata does not support standalone sequences. + + Auto-increment behaviour is provided by IDENTITY columns instead. + """ + return [] def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') @@ -242,5 +239,5 @@ def quote_identifier(self, name: str) -> str: def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: schema_name = schema or database if schema_name: - return f'SELECT TOP {limit} * FROM "{schema_name}"."{table}"' - return f'SELECT TOP {limit} * FROM "{table}"' + return f'lock row for access select top {limit} * from "{schema_name}"."{table}"' + return f'lock row for access select top {limit} * from "{table}"' diff --git a/sqlit/domains/query/app/query_service.py b/sqlit/domains/query/app/query_service.py index b458a5b..19131b3 100644 --- a/sqlit/domains/query/app/query_service.py +++ b/sqlit/domains/query/app/query_service.py @@ -68,9 +68,10 @@ class KeywordQueryAnalyzer: def classify(self, query: str) -> QueryKind: """Classify query based on keyword of the last statement. - For multi-statement queries like 'BEGIN; INSERT...; SELECT * FROM t;', - we check the last statement to determine if results should be returned. - Uses the same splitting logic as multi_statement.split_statements. + Enhanced for Teradata: + - Supports SEL (Teradata abbreviation for SELECT) + - Supports HELP statements + - Handles LOCKING ... SELECT patterns (common in Teradata) """ from sqlit.domains.query.editing.comments import ( is_comment_line, @@ -87,17 +88,22 @@ def classify(self, query: str) -> QueryKind: for stmt in reversed(statements): if is_comment_only_statement(stmt): continue - # Found a statement with actual SQL - get first non-comment line + # Get first non-comment line lines = [line.strip() for line in stmt.split("\n") if line.strip()] non_comment_lines = [line for line in lines if not is_comment_line(line)] if non_comment_lines: - first_line = non_comment_lines[0].upper() - first_word = first_line.split()[0] if first_line else "" + first_line_upper = non_comment_lines[0].upper() + + # Teradata-specific patterns (word-boundary aware) + if re.search(r"\b(SELECT|WITH|SHOW|DESCRIBE|EXPLAIN|PRAGMA|SEL|HELP)\b", first_line_upper): + return QueryKind.RETURNS_ROWS + + # Fallback to original first-word check + first_word = first_line_upper.split()[0] if first_line_upper else "" return QueryKind.RETURNS_ROWS if first_word in SELECT_KEYWORDS else QueryKind.NON_QUERY return QueryKind.NON_QUERY - class DialectQueryAnalyzer: def __init__(self, dialect: Any, fallback: QueryAnalyzer | None = None) -> None: self._dialect = dialect