|
5 | 5 | package io.airbyte.integrations.destination.snowflake.client |
6 | 6 |
|
7 | 7 | import edu.umd.cs.findbugs.annotations.SuppressFBWarnings |
| 8 | +import io.airbyte.cdk.ConfigErrorException |
8 | 9 | import io.airbyte.cdk.load.command.DestinationStream |
9 | 10 | import io.airbyte.cdk.load.component.TableOperationsClient |
10 | 11 | import io.airbyte.cdk.load.component.TableSchemaEvolutionClient |
@@ -64,39 +65,43 @@ class SnowflakeAirbyteClient( |
64 | 65 | } |
65 | 66 |
|
66 | 67 | override suspend fun createNamespace(namespace: String) { |
67 | | - // Check if the schema exists first |
68 | | - val schemaExistsResult = |
69 | | - dataSource.connection.use { connection -> |
70 | | - val databaseName = snowflakeConfiguration.database.toSnowflakeCompatibleName() |
71 | | - val statement = |
72 | | - connection.prepareStatement( |
73 | | - """ |
74 | | - SELECT COUNT(*) > 0 AS SCHEMA_EXISTS |
75 | | - FROM "$databaseName".INFORMATION_SCHEMA.SCHEMATA |
76 | | - WHERE SCHEMA_NAME = ? |
77 | | - """.andLog() |
78 | | - ) |
79 | | - |
80 | | - // When querying information_schema, snowflake needs the "true" schema name, |
81 | | - // so we unescape it here. |
82 | | - val unescapedNamespace = namespace.replace("\"\"", "\"") |
83 | | - statement.setString(1, unescapedNamespace) |
84 | | - |
85 | | - statement.use { |
86 | | - val resultSet = it.executeQuery() |
87 | | - resultSet.use { rs -> |
88 | | - if (rs.next()) { |
89 | | - rs.getBoolean("SCHEMA_EXISTS") |
90 | | - } else { |
91 | | - false |
| 68 | + try { |
| 69 | + // Check if the schema exists first |
| 70 | + val schemaExistsResult = |
| 71 | + dataSource.connection.use { connection -> |
| 72 | + val databaseName = snowflakeConfiguration.database.toSnowflakeCompatibleName() |
| 73 | + val statement = |
| 74 | + connection.prepareStatement( |
| 75 | + """ |
| 76 | + SELECT COUNT(*) > 0 AS SCHEMA_EXISTS |
| 77 | + FROM "$databaseName".INFORMATION_SCHEMA.SCHEMATA |
| 78 | + WHERE SCHEMA_NAME = ? |
| 79 | + """.andLog() |
| 80 | + ) |
| 81 | + |
| 82 | + // When querying information_schema, snowflake needs the "true" schema name, |
| 83 | + // so we unescape it here. |
| 84 | + val unescapedNamespace = namespace.replace("\"\"", "\"") |
| 85 | + statement.setString(1, unescapedNamespace) |
| 86 | + |
| 87 | + statement.use { |
| 88 | + val resultSet = it.executeQuery() |
| 89 | + resultSet.use { rs -> |
| 90 | + if (rs.next()) { |
| 91 | + rs.getBoolean("SCHEMA_EXISTS") |
| 92 | + } else { |
| 93 | + false |
| 94 | + } |
92 | 95 | } |
93 | 96 | } |
94 | 97 | } |
95 | | - } |
96 | 98 |
|
97 | | - if (!schemaExistsResult) { |
98 | | - // Create the schema only if it doesn't exist |
99 | | - execute(sqlGenerator.createNamespace(namespace)) |
| 99 | + if (!schemaExistsResult) { |
| 100 | + // Create the schema only if it doesn't exist |
| 101 | + execute(sqlGenerator.createNamespace(namespace)) |
| 102 | + } |
| 103 | + } catch (e: SnowflakeSQLException) { |
| 104 | + handleSnowflakePermissionError(e) |
100 | 105 | } |
101 | 106 | } |
102 | 107 |
|
@@ -194,28 +199,35 @@ class SnowflakeAirbyteClient( |
194 | 199 | } |
195 | 200 |
|
196 | 201 | internal fun getColumnsFromDb(tableName: TableName): Set<ColumnDefinition> { |
197 | | - val sql = |
198 | | - sqlGenerator.describeTable(schemaName = tableName.namespace, tableName = tableName.name) |
199 | | - dataSource.connection.use { connection -> |
200 | | - val statement = connection.createStatement() |
201 | | - return statement.use { |
202 | | - val rs: ResultSet = it.executeQuery(sql) |
203 | | - val columnsInDb: MutableSet<ColumnDefinition> = mutableSetOf() |
204 | | - |
205 | | - while (rs.next()) { |
206 | | - val columnName = escapeJsonIdentifier(rs.getString("name")) |
207 | | - |
208 | | - // Filter out airbyte columns |
209 | | - if (airbyteColumnNames.contains(columnName)) { |
210 | | - continue |
| 202 | + try { |
| 203 | + val sql = |
| 204 | + sqlGenerator.describeTable( |
| 205 | + schemaName = tableName.namespace, |
| 206 | + tableName = tableName.name |
| 207 | + ) |
| 208 | + dataSource.connection.use { connection -> |
| 209 | + val statement = connection.createStatement() |
| 210 | + return statement.use { |
| 211 | + val rs: ResultSet = it.executeQuery(sql) |
| 212 | + val columnsInDb: MutableSet<ColumnDefinition> = mutableSetOf() |
| 213 | + |
| 214 | + while (rs.next()) { |
| 215 | + val columnName = escapeJsonIdentifier(rs.getString("name")) |
| 216 | + |
| 217 | + // Filter out airbyte columns |
| 218 | + if (airbyteColumnNames.contains(columnName)) { |
| 219 | + continue |
| 220 | + } |
| 221 | + val dataType = rs.getString("type").takeWhile { char -> char != '(' } |
| 222 | + |
| 223 | + columnsInDb.add(ColumnDefinition(columnName, dataType, false)) |
211 | 224 | } |
212 | | - val dataType = rs.getString("type").takeWhile { char -> char != '(' } |
213 | 225 |
|
214 | | - columnsInDb.add(ColumnDefinition(columnName, dataType, false)) |
| 226 | + columnsInDb |
215 | 227 | } |
216 | | - |
217 | | - columnsInDb |
218 | 228 | } |
| 229 | + } catch (e: SnowflakeSQLException) { |
| 230 | + handleSnowflakePermissionError(e) |
219 | 231 | } |
220 | 232 | } |
221 | 233 |
|
@@ -302,32 +314,60 @@ class SnowflakeAirbyteClient( |
302 | 314 | } |
303 | 315 |
|
304 | 316 | fun describeTable(tableName: TableName): LinkedHashMap<String, String> = |
305 | | - dataSource.connection.use { connection -> |
306 | | - val statement = connection.createStatement() |
307 | | - return statement.use { |
308 | | - val resultSet = it.executeQuery(sqlGenerator.showColumns(tableName)) |
309 | | - val columns = linkedMapOf<String, String>() |
310 | | - while (resultSet.next()) { |
311 | | - val columnName = resultSet.getString(DESCRIBE_TABLE_COLUMN_NAME_FIELD) |
312 | | - // this is... incredibly annoying. The resultset will give us a string like |
313 | | - // `{"type":"VARIANT","nullable":true}`. |
314 | | - // So we need to parse that JSON, and then fetch the actual thing we care about. |
315 | | - // Also, some of the type names aren't the ones we're familiar with (e.g. |
316 | | - // `FIXED` for numeric columns), |
317 | | - // so the output here is not particularly ergonomic. |
318 | | - val columnType = |
319 | | - resultSet |
320 | | - .getString(DESCRIBE_TABLE_COLUMN_TYPE_FIELD) |
321 | | - .deserializeToNode()["type"] |
322 | | - .asText() |
323 | | - columns[columnName] = columnType |
| 317 | + try { |
| 318 | + dataSource.connection.use { connection -> |
| 319 | + val statement = connection.createStatement() |
| 320 | + return statement.use { |
| 321 | + val resultSet = it.executeQuery(sqlGenerator.showColumns(tableName)) |
| 322 | + val columns = linkedMapOf<String, String>() |
| 323 | + while (resultSet.next()) { |
| 324 | + val columnName = resultSet.getString(DESCRIBE_TABLE_COLUMN_NAME_FIELD) |
| 325 | + // this is... incredibly annoying. The resultset will give us a string like |
| 326 | + // `{"type":"VARIANT","nullable":true}`. |
| 327 | + // So we need to parse that JSON, and then fetch the actual thing we care |
| 328 | + // about. |
| 329 | + // Also, some of the type names aren't the ones we're familiar with (e.g. |
| 330 | + // `FIXED` for numeric columns), |
| 331 | + // so the output here is not particularly ergonomic. |
| 332 | + val columnType = |
| 333 | + resultSet |
| 334 | + .getString(DESCRIBE_TABLE_COLUMN_TYPE_FIELD) |
| 335 | + .deserializeToNode()["type"] |
| 336 | + .asText() |
| 337 | + columns[columnName] = columnType |
| 338 | + } |
| 339 | + columns |
324 | 340 | } |
325 | | - columns |
326 | 341 | } |
| 342 | + } catch (e: SnowflakeSQLException) { |
| 343 | + handleSnowflakePermissionError(e) |
327 | 344 | } |
328 | 345 |
|
329 | 346 | internal fun execute(query: String) = |
330 | | - dataSource.connection.use { connection -> |
331 | | - connection.createStatement().use { it.executeQuery(query) } |
| 347 | + try { |
| 348 | + dataSource.connection.use { connection -> |
| 349 | + connection.createStatement().use { it.executeQuery(query) } |
| 350 | + } |
| 351 | + } catch (e: SnowflakeSQLException) { |
| 352 | + handleSnowflakePermissionError(e) |
332 | 353 | } |
| 354 | + |
| 355 | + /** |
| 356 | + * Checks if a SnowflakeSQLException is related to permissions and wraps it as a |
| 357 | + * ConfigErrorException. Otherwise, rethrows the original exception. |
| 358 | + */ |
| 359 | + private fun handleSnowflakePermissionError(e: SnowflakeSQLException): Nothing { |
| 360 | + val errorMessage = e.message?.lowercase() ?: "" |
| 361 | + |
| 362 | + // Check for known permission-related error patterns |
| 363 | + when { |
| 364 | + errorMessage.contains("current role has no privileges on it") -> { |
| 365 | + throw ConfigErrorException(e.message ?: "Permission error", e) |
| 366 | + } |
| 367 | + else -> { |
| 368 | + // Not a known permission error, rethrow as-is |
| 369 | + throw e |
| 370 | + } |
| 371 | + } |
| 372 | + } |
333 | 373 | } |
0 commit comments