diff --git a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/Dialect.kt b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/Dialect.kt index 273e89f..c7c4d41 100644 --- a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/Dialect.kt +++ b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/Dialect.kt @@ -55,6 +55,8 @@ interface Dialect { .joinToString(" ") return sql + "\n" + limitAndOffset } + + fun escapeName(columnName: String): String = columnName } internal fun String.truncate(limit: Int): String { diff --git a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/MysqlDialect.kt b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/MysqlDialect.kt index 54f68d5..b79a1a2 100644 --- a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/MysqlDialect.kt +++ b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/MysqlDialect.kt @@ -47,4 +47,8 @@ open class MysqlDialect : Dialect { override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException() override val supportsFetchingGeneratedKeysByName = false + + override fun escapeName(columnName: String): String = + '`' + columnName.replace("`", "``") + '`' + } diff --git a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/PostgresDialect.kt b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/PostgresDialect.kt index cf911c8..0f28201 100644 --- a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/PostgresDialect.kt +++ b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/PostgresDialect.kt @@ -49,4 +49,8 @@ open class PostgresDialect : Dialect { "select nextval('$sequence') as $columnName from generate_series(1, $count)" override val supportsFetchingGeneratedKeysByName = true + + override fun escapeName(columnName: String): String = + '"' + columnName.replace("\"", "\"\"") + '"' + } diff --git a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/SqliteDialect.kt b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/SqliteDialect.kt index 5b788d5..742ef5d 100644 --- a/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/SqliteDialect.kt +++ b/core/src/main/kotlin/com/github/andrewoma/kwery/core/dialect/SqliteDialect.kt @@ -47,4 +47,8 @@ open class SqliteDialect : Dialect { override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException() override val supportsFetchingGeneratedKeysByName = false + + override fun escapeName(columnName: String): String = + '"' + columnName.replace("\"", "\"\"") + '"' + } diff --git a/mapper/src/main/kotlin/com/github/andrewoma/kwery/mapper/AbstractDao.kt b/mapper/src/main/kotlin/com/github/andrewoma/kwery/mapper/AbstractDao.kt index 4d63339..b4dce41 100644 --- a/mapper/src/main/kotlin/com/github/andrewoma/kwery/mapper/AbstractDao.kt +++ b/mapper/src/main/kotlin/com/github/andrewoma/kwery/mapper/AbstractDao.kt @@ -43,10 +43,12 @@ abstract class AbstractDao( override val defaultColumns = table.defaultColumns - protected val columns = table.defaultColumns.join() + protected val columns = table.defaultColumns.joinNames() private val listeners = linkedSetOf() + private val escapedTableName = session.dialect.escapeName(table.name) + fun addListener(listener: Listener) { listeners.add(listener) } @@ -75,12 +77,16 @@ abstract class AbstractDao( return this.groupBy { it.first }.map { apply(it.key, it.value.map { it.second }) } } - protected fun Iterable>.join(separator: String = ", ", f: (Column) -> String = nf): String { - return this.map { f(it) }.joinToString(separator) + protected fun Iterable>.joinNames(separator: String = ", ", f: (Column) -> String = nf): String { + return this.joinToString(separator) { session.dialect.escapeName(f(it)) } + } + + protected fun Iterable>.joinStrings(separator: String = ", ", f: (Column) -> String = nf): String { + return this.joinToString(separator) { f(it) } } protected fun Iterable>.equate(separator: String = ", ", f: (Column) -> String = nf): String { - return this.map { "${f(it)} = :${f(it)}" }.joinToString(separator) + return this.joinToString(separator) { "${session.dialect.escapeName(f(it))} = :${f(it)}" } } protected fun Collection.copyToSqlArray(): java.sql.Array { @@ -103,7 +109,7 @@ abstract class AbstractDao( override fun findById(id: ID, columns: Set>): T? = withTransaction { val name = "findById" val sql = sql(name to columns) { - "select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}" + "select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}" } session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull() } @@ -111,14 +117,14 @@ abstract class AbstractDao( override fun findByIdForUpdate(id: ID, columns: Set>): T? = withTransaction { val name = "findByIdForUpdate" val sql = sql(name to columns) { - "select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}\nfor update" + "select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}\nfor update" } session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull() } override fun findAll(columns: Set>): List = withTransaction { val name = "findAll" - val sql = sql(name to columns) { "select ${columns.join()} \nfrom ${table.name}" } + val sql = sql(name to columns) { "select ${columns.joinNames()} \nfrom $escapedTableName" } session.select(sql, mapOf(), options(name), table.rowMapper(columns)) } @@ -130,7 +136,7 @@ abstract class AbstractDao( val exampleMap = table.objectMap(session, example, exampleColumns, nf) val sql = sql(Triple(name, exampleColumns, columns)) { - "select ${columns.join()} \nfrom ${table.name}\nwhere ${exampleColumns.equate(" and ")}" + "select ${columns.joinNames()} \nfrom $escapedTableName\nwhere ${exampleColumns.equate(" and ")}" } session.select(sql, exampleMap, options(name), table.rowMapper(columns)) } @@ -165,8 +171,10 @@ abstract class AbstractDao( fun delta(): Pair> { val differences = difference(oldMap, newMap) val sql = sql(name to differences) { - val columns = differences.keys.map { "$it = :$it" }.joinToString(", ") - "update ${table.name}\nset $columns \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" + val columns = differences.keys.joinToString(", ") { "$it = :$it" } + "update $escapedTableName \n" + + "set $columns \n" + + "where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" } val parameters = hashMapOfExpectedSize(differences.size + table.idColumns.size + 1) parameters.putAll(differences) @@ -177,7 +185,9 @@ abstract class AbstractDao( fun full(): Pair> { val sql = sql(name) { - "update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" + "update $escapedTableName \n" + + "set ${table.dataColumns.equate()} \n" + + "where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" } val parameters = hashMapOfExpectedSize(newMap.size + table.idColumns.size + 1) parameters.putAll(newMap) @@ -208,7 +218,7 @@ abstract class AbstractDao( override fun delete(id: ID): Int = withTransaction { val name = "delete" - val sql = sql(name) { "delete from ${table.name} where ${table.idColumns.equate(" and ")}" } + val sql = sql(name) { "delete from $escapedTableName where ${table.idColumns.equate(" and ")}" } val count = session.update(sql, table.idMap(session, id, nf), options(name)) fireEvent { DeleteEvent(table, id, null) } @@ -221,7 +231,7 @@ abstract class AbstractDao( val new = fireTransformingEvent(newValue) { PreUpdateEvent(table, id(newValue), newValue, null) } val sql = sql(name) { - "update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}" + "update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}" } val newMap = table.objectMap(session, new, table.allColumns) @@ -247,7 +257,8 @@ abstract class AbstractDao( } val columns = if (generateKeys) table.dataColumns else table.allColumns - val sql = sql(name) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" } + val sql = sql(name) { "insert into $escapedTableName (${columns.joinNames()}) \n" + + "values (${columns.joinStrings { ":${it.name}" }})" } val inserted = if (generateKeys) { val list = session.batchInsert(sql, new.map { table.objectMap(session, it, columns, nf) }, options(name), @@ -281,7 +292,8 @@ abstract class AbstractDao( val generateKeys = isGeneratedKey(new, idStrategy) val columns = if (generateKeys) table.dataColumns else table.allColumns - val sql = sql(name to columns) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" } + val sql = sql(name to columns) { "insert into $escapedTableName (${columns.joinNames()}) \n" + + "values (${columns.joinStrings { ":${it.name}" }})" } val parameters = table.objectMap(session, new, columns, nf) val (count, inserted) = if (generateKeys) { @@ -314,7 +326,7 @@ abstract class AbstractDao( val values = if (session.dialect.supportsArrayBasedIn) { val sql = sql(name to columns) { - "select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} " + + "select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} " + session.dialect.arrayBasedIn("ids") } val array = ids.copyToSqlArray() @@ -325,7 +337,7 @@ abstract class AbstractDao( } } else { val sql = sql(name to columns) { - "select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} in (:ids)" + "select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} in (:ids)" } session.select(sql, mapOf("ids" to ids), options(name), table.rowMapper(columns)) } @@ -353,7 +365,7 @@ abstract class AbstractDao( val updates = new.map { table.objectMap(session, it, table.allColumns) } val sql = sql(name) { - "update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}" + "update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}" } val counts = session.batchUpdate(sql, updates, options(name)) @@ -402,7 +414,7 @@ abstract class AbstractDao( } val sql = sql(name) { - "update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" + "update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam" } val counts = session.batchUpdate(sql, updates.map { it.first }, options(name))