Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jdbi.v3.core.Jdbi;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -62,6 +61,18 @@ public String toString() {
// Matches a whole string made of (at least) three dot-separated parts, where each part may
// optionally be wrapped in double quotes. A part may contain any character other than a
// double quote, so an unquoted part can itself contain further dots.
private static final Pattern qualifiedNameMatcher =
Pattern.compile("^\"?[^\"]+\"?\\.\"?[^\"]+\"?\\.\"?[^\"]+\"?$");

/**
* Pattern that finds unquoted, fully-qualified table names (catalog.schema.table) that directly
* follow a FROM or JOIN keyword. The whole pattern is case-insensitive.
*
* <p>Capture groups: 1 = keyword (FROM or JOIN), 2 = catalog, 3 = schema, 4 = table.
* Catalog and schema must start with a letter or underscore; the table may also start with a
* digit. After the first character, every part may contain letters, digits, underscores and
* hyphens.
*
* <p>Already-quoted names are not matched: a double quote right after the keyword is rejected by
* the negative lookahead, and a double quote at the start of the schema or table part cannot
* match because quotes are not in the allowed character set.
*/
private static final Pattern FROM_TABLE_PATTERN = Pattern.compile(
"(?i)(FROM|JOIN)\\s+(?!\")" + // FROM or JOIN keyword, not followed by quote
"([a-zA-Z_][a-zA-Z0-9_-]*)\\." + // catalog (unquoted identifier, may contain hyphens)
"([a-zA-Z_][a-zA-Z0-9_-]*)\\." + // schema (unquoted identifier, may contain hyphens)
"([a-zA-Z0-9_][a-zA-Z0-9_-]*)" // table (can start with digit, may contain hyphens)
);

private final Map<String, Set<String>> trinoSchemaCache;
private final Map<String, Set<String>> trinoCatalogCache;

Expand Down Expand Up @@ -128,9 +139,16 @@ public SQLFunction(MatchResult matchResult) {

}

//rewrites the query by replacing all instances of functionName(a_0, a_1)
//with a_argIndex
private String rewriteQuery(String query, String functionName, int argIndex) {

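/**
* Rewrites the query by replacing every call of the form functionName(a_0, a_1) with its argument a_argIndex.
*
* @param query the query text to rewrite
* @param functionName name of the two-argument function whose calls are replaced
* @param argIndex index of the argument (a_0 or a_1) that replaces each call
* @return the rewritten query
*/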
private String rewriteFunctionNameIndex(String query, String functionName, int argIndex) {
return biFunctionPattern.matcher(query)
.replaceAll(matchResult -> {
SQLFunction sf = new SQLFunction(matchResult);
Expand Down Expand Up @@ -242,8 +260,7 @@ public TableData search(
Map<String, String> extraCredentials,
DataModel dataModel
) {

String rewrittenQuery = rewriteQuery(query, "ga4gh_type", 0);
String rewrittenQuery = applyQueryRewrites(query);
TrinoDataPage response = client.query(rewrittenQuery, extraCredentials);
QueryJob queryJob = createQueryJob(response.id(), query, dataModel, response.nextUri());
return toTableData(response, queryJob, request);
Expand Down Expand Up @@ -505,10 +522,7 @@ private TablesList getTables(CatalogSchema current, CatalogSchema next, HttpServ
public TableData getTableData(String tableName, HttpServletRequest request, Map<String, String> extraCredentials) {
// Get table JSON schema from tables registry if one exists for this table (for tables from trino-public)
DataModel dataModel = getDataModelFromSupplier(tableName);
//Add quotes to tableName in the query. Table name can be of the format <catalog_name>.<datasource_name>.tableName
//So if the tableName has two dots in it, then everything after the third dot, should come within quotes.
String validTableName = getTableNameInCorrectFormat(tableName);
TableData tableData = search("SELECT * FROM " + validTableName, request, extraCredentials, dataModel);
TableData tableData = search("SELECT * FROM " + tableName, request, extraCredentials, dataModel);

// Populate the dataModel only if there is tableData
if (!tableData.getData().isEmpty()) {
Expand Down Expand Up @@ -542,10 +556,7 @@ public TableInfo getTableInfo(
log.info("Data model supplier returned null for table: '{}'. Falling back to trino query", tableName);
// since the data model was not found in the supplier, perform a more expensive query to fallback to trino and fetch a single
// row of data.
//Add quotes to tableName in the query. Table name can be of the format <catalog_name>.<datasource_name>.tableName
//So if the tableName has two dots in it, then everything after the third dot, should come within quotes.
String validTableName = getTableNameInCorrectFormat(tableName);
TableData tableData = searchAll("SELECT * FROM " + validTableName + " LIMIT 1", request, extraCredentials, dataModel);
TableData tableData = searchAll("SELECT * FROM " + tableName + " LIMIT 1", request, extraCredentials, dataModel);
log.info("Data model is empty in tables registry for table {}.", tableName);
dataModel = tableData.getDataModel();
dataModel.setId(getDataModelId(tableName, request));
Expand All @@ -554,26 +565,12 @@ public TableInfo getTableInfo(
return new TableInfo(tableName, dataModel.getDescription(), dataModel, null);
}

private String getTableNameInCorrectFormat(String tableName) {
String validTableName = tableName;
if (StringUtils.countMatches(tableName, ".") >= 2) {

// If there are two or more dots, then quote the entire part after the second dot(assuming that this will be the table name).
int secondIndex = StringUtils.ordinalIndexOf(tableName, ".", 2);

//Everything before second catalog name will be catalog(+schema)
String catalogAndSchema = tableName.substring(0, secondIndex + 1);
String table = tableName.substring(secondIndex + 1);

//If the table name doesn't starts with or ends with quotes then add quotes
if (!table.startsWith("\"") || !table.endsWith("\"")) {
table = "\"" + table + "\"";
}
validTableName = catalogAndSchema + table;
} else {
log.warn("Table name {} has less than 2 dots in it.", tableName);
}
return validTableName;
/**
 * Wraps an identifier (for example a catalog, schema, table, or column name) in double quotes
 * so it can be embedded in a SQL query. Any double quote already present in the identifier is
 * escaped by doubling it.
 *
 * @param identifier the raw identifier text
 * @return the identifier surrounded by double quotes, with embedded quotes doubled
 */
private static String quoteIdentifier(String identifier) {
    final String escaped = identifier.replace("\"", "\"\"");
    return "\"" + escaped + "\"";
}

private boolean isValidTrinoName(String tableName) {
Expand Down Expand Up @@ -1059,4 +1056,44 @@ private QueryJob getQueryJob(String id) {
.orElseThrow(() -> new InvalidQueryJobException(id));
}

/**
 * Runs every query rewrite, in order, and returns the final query text.
 * This is the single place where rewrite steps are chained; add new steps here.
 *
 * <p>Current steps: first the {@code ga4gh_type} function rewrite (argument index 0), then
 * quoting of table names.
 *
 * @param query input query
 * @return the query after all rewrites have been applied
 */
private String applyQueryRewrites(String query) {
    final String withoutTypeFunction = rewriteFunctionNameIndex(query, "ga4gh_type", 0);
    return quoteTableNamesInQuery(withoutTypeFunction);
}

/**
 * Quotes fully-qualified table names that follow a FROM or JOIN keyword.
 * Converts "FROM catalog.schema.table" to "FROM \"catalog\".\"schema\".\"table\"".
 * This prevents SQL parsing errors when table names start with numbers (e.g., "03_chris").
 *
 * <p>The keyword and the whitespace that follows it are copied through unchanged; only the
 * three name parts are rewritten. Keeping the original whitespace matters when it contains a
 * line break: collapsing it to a single space could pull the name onto a preceding line, for
 * example onto a {@code --} comment line that happens to end in "from" or "join".
 *
 * @param query input query
 * @return the query with every matched catalog.schema.table name quoted part by part
 */
private String quoteTableNamesInQuery(String query) {
    String transformedQuery = FROM_TABLE_PATTERN.matcher(query).replaceAll(match -> {
        // The lookahead in the pattern is zero-width, so the text between the end of the
        // keyword (group 1) and the start of the catalog (group 2) is exactly the whitespace.
        String separator = query.substring(match.end(1), match.start(2));
        String replacement = match.group(1) + separator
            + quoteIdentifier(match.group(2)) + "."
            + quoteIdentifier(match.group(3)) + "."
            + quoteIdentifier(match.group(4));
        // Replacement strings treat '$' and '\' specially; emit the text literally.
        return Matcher.quoteReplacement(replacement);
    });
    if (!transformedQuery.equals(query)) {
        log.debug("Quoted table names in query: {} -> {}", query, transformedQuery);
    }
    return transformedQuery;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1103,4 +1103,12 @@ public void getTablesByCatalogAndSchema_singleCatalog_lastSchema() {
assertNull(tablesList.getErrors());
}

@Test
public void quoteTableNamesInQuery_should_quoteTableNameStartingWithDigit() {
    // A table name that begins with a digit (e.g., "03_test") makes Trino fail to parse the
    // query unless it is quoted, so every part of the qualified name is expected to be quoted.
    final String expected =
        "SELECT * FROM \"bigquery_catalog\".\"some_dataset\".\"03_test_table\" WHERE id = 1";
    final String actual = ReflectionTestUtils.invokeMethod(
        dataConnectAdapter,
        "quoteTableNamesInQuery",
        "SELECT * FROM bigquery_catalog.some_dataset.03_test_table WHERE id = 1");
    assertThat(actual, equalTo(expected));
}

}