diff --git a/src/main/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapter.java b/src/main/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapter.java
index 0253081..b01c2d7 100644
--- a/src/main/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapter.java
+++ b/src/main/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapter.java
@@ -19,7 +19,6 @@
 import lombok.Getter;
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
 import org.jdbi.v3.core.Jdbi;
 import org.jetbrains.annotations.NotNull;
 import org.jetbrains.annotations.Nullable;
@@ -62,6 +61,21 @@ public String toString() {
 
     private static final Pattern qualifiedNameMatcher = Pattern.compile("^\"?[^\"]+\"?\\.\"?[^\"]+\"?\\.\"?[^\"]+\"?$");
 
+    /**
+     * Pattern used to quote fully-qualified table names (catalog.schema.table) in FROM and JOIN clauses.
+     * String literals and quoted identifiers are matched first (without capture groups) so that their
+     * content is skipped; only unquoted names populate groups 1-4 (keyword, catalog, schema, table).
+     * Names that continue past the table part (another dot, '$', ...) are not matched at all.
+     */
+    private static final Pattern FROM_TABLE_PATTERN = Pattern.compile(
+        "'[^']*'|\"[^\"]*\"|" + // string literal or quoted identifier: consumed, never rewritten
+        "(?i)\\b(FROM|JOIN)\\s+(?!\")" + // whole-word FROM or JOIN keyword, not followed by quote
+        "([a-zA-Z_][a-zA-Z0-9_-]*)\\." + // catalog (unquoted identifier, may contain hyphens)
+        "([a-zA-Z_][a-zA-Z0-9_-]*)\\." + // schema (unquoted identifier, may contain hyphens)
+        "([a-zA-Z0-9_][a-zA-Z0-9_-]*)" + // table (can start with digit, may contain hyphens)
+        "(?![a-zA-Z0-9_.$\"-])" // the table part must end the dotted name
+    );
+
     private final Map> trinoSchemaCache;
     private final Map> trinoCatalogCache;
 
@@ -128,9 +142,16 @@ public SQLFunction(MatchResult matchResult) {
     }
 
 
-    //rewrites the query by replacing all instances of functionName(a_0, a_1)
-    //with a_argIndex
-    private String rewriteQuery(String query, String functionName, int argIndex) {
+
+    /**
+     * Rewrites the query by replacing all instances of {@code functionName(a_0, a_1)} with {@code a_argIndex}.
+     *
+     * @param query        the SQL query to rewrite
+     * @param functionName name of the function whose calls are replaced
+     * @param argIndex     zero-based index of the argument that replaces the call
+     * @return the rewritten query
+     */
+    private String rewriteFunctionNameIndex(String query, String functionName, int argIndex) {
         return biFunctionPattern.matcher(query)
             .replaceAll(matchResult -> {
                 SQLFunction sf = new SQLFunction(matchResult);
@@ -242,8 +263,7 @@ public TableData search(
         Map extraCredentials,
         DataModel dataModel
     ) {
-
-        String rewrittenQuery = rewriteQuery(query, "ga4gh_type", 0);
+        String rewrittenQuery = applyQueryRewrites(query);
         TrinoDataPage response = client.query(rewrittenQuery, extraCredentials);
         QueryJob queryJob = createQueryJob(response.id(), query, dataModel, response.nextUri());
         return toTableData(response, queryJob, request);
@@ -505,10 +525,8 @@ private TablesList getTables(CatalogSchema current, CatalogSchema next, HttpServ
     public TableData getTableData(String tableName, HttpServletRequest request, Map extraCredentials) {
         // Get table JSON schema from tables registry if one exists for this table (for tables from trino-public)
         DataModel dataModel = getDataModelFromSupplier(tableName);
-        //Add quotes to tableName in the query. Table name can be of the format ..tableName
-        //So if the tableName has two dots in it, then everything after the third dot, should come within quotes.
-        String validTableName = getTableNameInCorrectFormat(tableName);
-        TableData tableData = search("SELECT * FROM " + validTableName, request, extraCredentials, dataModel);
+        // Quote every part of the name explicitly so digits, dots or spaces in it cannot break (or alter) the query.
+        TableData tableData = search("SELECT * FROM " + quoteQualifiedTableName(tableName), request, extraCredentials, dataModel);
 
         // Populate the dataModel only if there is tableData
         if (!tableData.getData().isEmpty()) {
@@ -542,10 +560,7 @@ public TableInfo getTableInfo(
             log.info("Data model supplier returned null for table: '{}'. Falling back to trino query", tableName);
             // since the data model was not found in the supplier, perform a more expensive query to fallback to trino and fetch a single
             // row of data.
-            //Add quotes to tableName in the query. Table name can be of the format ..tableName
-            //So if the tableName has two dots in it, then everything after the third dot, should come within quotes.
-            String validTableName = getTableNameInCorrectFormat(tableName);
-            TableData tableData = searchAll("SELECT * FROM " + validTableName + " LIMIT 1", request, extraCredentials, dataModel);
+            TableData tableData = searchAll("SELECT * FROM " + quoteQualifiedTableName(tableName) + " LIMIT 1", request, extraCredentials, dataModel);
             log.info("Data model is empty in tables registry for table {}.", tableName);
             dataModel = tableData.getDataModel();
             dataModel.setId(getDataModelId(tableName, request));
@@ -554,26 +569,40 @@ public TableInfo getTableInfo(
         return new TableInfo(tableName, dataModel.getDescription(), dataModel, null);
     }
 
-    private String getTableNameInCorrectFormat(String tableName) {
-        String validTableName = tableName;
-        if (StringUtils.countMatches(tableName, ".") >= 2) {
-
-            // If there are two or more dots, then quote the entire part after the second dot(assuming that this will be the table name).
-            int secondIndex = StringUtils.ordinalIndexOf(tableName, ".", 2);
-
-            //Everything before second catalog name will be catalog(+schema)
-            String catalogAndSchema = tableName.substring(0, secondIndex + 1);
-            String table = tableName.substring(secondIndex + 1);
-
-            //If the table name doesn't starts with or ends with quotes then add quotes
-            if (!table.startsWith("\"") || !table.endsWith("\"")) {
-                table = "\"" + table + "\"";
-            }
-            validTableName = catalogAndSchema + table;
-        } else {
-            log.warn("Table name {} has less than 2 dots in it.", tableName);
-        }
-        return validTableName;
+    /**
+     * Quotes each part of a {@code catalog.schema.table} name so it can be embedded in a SQL query.
+     * Everything after the second dot is treated as the table name, so table names that contain
+     * dots or spaces, or that start with a digit, stay addressable. Names with fewer than two dots
+     * are returned unchanged (a warning is logged).
+     */
+    private String quoteQualifiedTableName(String tableName) {
+        int firstDot = tableName.indexOf('.');
+        int secondDot = firstDot < 0 ? -1 : tableName.indexOf('.', firstDot + 1);
+        if (secondDot < 0) {
+            log.warn("Table name {} has less than 2 dots in it.", tableName);
+            return tableName;
+        }
+        return requoteIdentifier(tableName.substring(0, firstDot)) + "."
+            + requoteIdentifier(tableName.substring(firstDot + 1, secondDot)) + "."
+            + requoteIdentifier(tableName.substring(secondDot + 1));
+    }
+
+    /**
+     * Quotes the identifier; when it is already wrapped in double quotes the wrapping is removed
+     * first, so the result is always exactly one well-formed quoted identifier.
+     */
+    private static String requoteIdentifier(String identifier) {
+        boolean wrapped = identifier.length() >= 2 && identifier.startsWith("\"") && identifier.endsWith("\"");
+        String raw = wrapped ? identifier.substring(1, identifier.length() - 1).replace("\"\"", "\"") : identifier;
+        return quoteIdentifier(raw);
+    }
+
+    /**
+     * Quotes the given string so it can be used in a SQL query as an identifier
+     * (for example, a catalog, schema, table, or column name).
+     */
+    private static String quoteIdentifier(String identifier) {
+        return "\"" + identifier.replace("\"", "\"\"") + "\"";
     }
 
     private boolean isValidTrinoName(String tableName) {
@@ -1059,4 +1088,51 @@ private QueryJob getQueryJob(String id) {
             .orElseThrow(() -> new InvalidQueryJobException(id));
     }
 
+    /**
+     * Applies all query rewrites in sequence.
+     * Add new rewrite steps here to keep the transformation pipeline in one place.
+     * @param query input query
+     * @return the query after all rewrite steps
+     */
+    private String applyQueryRewrites(String query) {
+        String result = query;
+        result = rewriteFunctionNameIndex(result, "ga4gh_type", 0);
+        result = quoteTableNamesInQuery(result);
+        return result;
+    }
+
+    /**
+     * Quotes fully-qualified table names in FROM and JOIN clauses.
+     * Converts "FROM catalog.schema.table" to "FROM \"catalog\".\"schema\".\"table\"".
+     * This prevents SQL parsing errors when table names start with numbers (e.g., "03_chris").
+     * String literals and already-quoted identifiers are left untouched.
+     * @param query input query
+     * @return the query with matching table names quoted
+     */
+    private String quoteTableNamesInQuery(String query) {
+        Matcher matcher = FROM_TABLE_PATTERN.matcher(query);
+        StringBuilder result = new StringBuilder();
+        while (matcher.find()) {
+            String keyword = matcher.group(1);
+            if (keyword == null) {
+                // Matched a string literal or quoted identifier; the next append copies it through unchanged.
+                continue;
+            }
+            String catalog = matcher.group(2);
+            String schema = matcher.group(3);
+            String table = matcher.group(4);
+            String replacement = keyword + " "
+                + quoteIdentifier(catalog) + "."
+                + quoteIdentifier(schema) + "."
+                + quoteIdentifier(table);
+            matcher.appendReplacement(result, Matcher.quoteReplacement(replacement));
+        }
+        matcher.appendTail(result);
+        String transformedQuery = result.toString();
+        if (!transformedQuery.equals(query)) {
+            log.debug("Quoted table names in query: {} -> {}", query, transformedQuery);
+        }
+        return transformedQuery;
+    }
+
 }
diff --git a/src/test/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapterTest.java b/src/test/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapterTest.java
index ab0fc9b..0deccac 100644
--- a/src/test/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapterTest.java
+++ b/src/test/java/com/dnastack/ga4gh/dataconnect/adapter/trino/TrinoDataConnectAdapterTest.java
@@ -1103,4 +1103,32 @@ public void getTablesByCatalogAndSchema_singleCatalog_lastSchema() {
         assertNull(tablesList.getErrors());
     }
 
+    @Test
+    public void quoteTableNamesInQuery_should_quoteTableNameStartingWithDigit() {
+        // Table names starting with a digit (e.g., "03_test") cause Trino parsing errors if unquoted
+        String query = "SELECT * FROM bigquery_catalog.some_dataset.03_test_table WHERE id = 1";
+        String result = ReflectionTestUtils.invokeMethod(dataConnectAdapter, "quoteTableNamesInQuery", query);
+        assertThat(result, equalTo("SELECT * FROM \"bigquery_catalog\".\"some_dataset\".\"03_test_table\" WHERE id = 1"));
+    }
+
+    @Test
+    public void quoteTableNamesInQuery_should_notModifyStringLiterals() {
+        String query = "SELECT * FROM cat.sch.tbl WHERE note = 'copied FROM a.b.c'";
+        String result = ReflectionTestUtils.invokeMethod(dataConnectAdapter, "quoteTableNamesInQuery", query);
+        assertThat(result, equalTo("SELECT * FROM \"cat\".\"sch\".\"tbl\" WHERE note = 'copied FROM a.b.c'"));
+    }
+
+    @Test
+    public void quoteTableNamesInQuery_should_leaveFourPartNamesUntouched() {
+        String query = "SELECT * FROM a.b.c.d";
+        String result = ReflectionTestUtils.invokeMethod(dataConnectAdapter, "quoteTableNamesInQuery", query);
+        assertThat(result, equalTo(query));
+    }
+
+    @Test
+    public void quoteQualifiedTableName_should_quoteEveryPart() {
+        String result = ReflectionTestUtils.invokeMethod(dataConnectAdapter, "quoteQualifiedTableName", "cat.sch.03 my.table");
+        assertThat(result, equalTo("\"cat\".\"sch\".\"03 my.table\""));
+    }
+
 }