Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ interface Dialect {
.joinToString(" ")
return sql + "\n" + limitAndOffset
}

fun escapeName(columnName: String): String = columnName
}

internal fun String.truncate(limit: Int): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ open class MysqlDialect : Dialect {
override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException()

override val supportsFetchingGeneratedKeysByName = false

override fun escapeName(columnName: String): String =
'`' + columnName.replace("`", "``") + '`'

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,8 @@ open class PostgresDialect : Dialect {
"select nextval('$sequence') as $columnName from generate_series(1, $count)"

override val supportsFetchingGeneratedKeysByName = true

override fun escapeName(columnName: String): String =
'"' + columnName.replace("\"", "\"\"") + '"'

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ open class SqliteDialect : Dialect {
override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException()

override val supportsFetchingGeneratedKeysByName = false

override fun escapeName(columnName: String): String =
'"' + columnName.replace("\"", "\"\"") + '"'

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ abstract class AbstractDao<T : Any, ID : Any>(

override val defaultColumns = table.defaultColumns

protected val columns = table.defaultColumns.join()
protected val columns = table.defaultColumns.joinNames()

private val listeners = linkedSetOf<Listener>()

private val escapedTableName = session.dialect.escapeName(table.name)

fun addListener(listener: Listener) {
listeners.add(listener)
}
Expand Down Expand Up @@ -75,12 +77,16 @@ abstract class AbstractDao<T : Any, ID : Any>(
return this.groupBy { it.first }.map { apply(it.key, it.value.map { it.second }) }
}

protected fun Iterable<Column<T, *>>.join(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.map { f(it) }.joinToString(separator)
protected fun Iterable<Column<T, *>>.joinNames(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.joinToString(separator) { session.dialect.escapeName(f(it)) }
}

protected fun Iterable<Column<T, *>>.joinStrings(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.joinToString(separator) { f(it) }
}

protected fun Iterable<Column<T, *>>.equate(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.map { "${f(it)} = :${f(it)}" }.joinToString(separator)
return this.joinToString(separator) { "${session.dialect.escapeName(f(it))} = :${f(it)}" }
}

protected fun Collection<ID>.copyToSqlArray(): java.sql.Array {
Expand All @@ -103,22 +109,22 @@ abstract class AbstractDao<T : Any, ID : Any>(
override fun findById(id: ID, columns: Set<Column<T, *>>): T? = withTransaction {
val name = "findById"
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}"
}
session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull()
}

override fun findByIdForUpdate(id: ID, columns: Set<Column<T, *>>): T? = withTransaction {
val name = "findByIdForUpdate"
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}\nfor update"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}\nfor update"
}
session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull()
}

override fun findAll(columns: Set<Column<T, *>>): List<T> = withTransaction {
val name = "findAll"
val sql = sql(name to columns) { "select ${columns.join()} \nfrom ${table.name}" }
val sql = sql(name to columns) { "select ${columns.joinNames()} \nfrom $escapedTableName" }
session.select(sql, mapOf(), options(name), table.rowMapper(columns))
}

Expand All @@ -130,7 +136,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

val exampleMap = table.objectMap(session, example, exampleColumns, nf)
val sql = sql(Triple(name, exampleColumns, columns)) {
"select ${columns.join()} \nfrom ${table.name}\nwhere ${exampleColumns.equate(" and ")}"
"select ${columns.joinNames()} \nfrom $escapedTableName\nwhere ${exampleColumns.equate(" and ")}"
}
session.select(sql, exampleMap, options(name), table.rowMapper(columns))
}
Expand Down Expand Up @@ -165,8 +171,10 @@ abstract class AbstractDao<T : Any, ID : Any>(
fun delta(): Pair<String, Map<String, Any?>> {
val differences = difference(oldMap, newMap)
val sql = sql(name to differences) {
val columns = differences.keys.map { "$it = :$it" }.joinToString(", ")
"update ${table.name}\nset $columns \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
val columns = differences.keys.joinToString(", ") { "$it = :$it" }
"update $escapedTableName \n" +
"set $columns \n" +
"where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}
val parameters = hashMapOfExpectedSize<String, Any?>(differences.size + table.idColumns.size + 1)
parameters.putAll(differences)
Expand All @@ -177,7 +185,9 @@ abstract class AbstractDao<T : Any, ID : Any>(

fun full(): Pair<String, HashMap<String, Any?>> {
val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
"update $escapedTableName \n" +
"set ${table.dataColumns.equate()} \n" +
"where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}
val parameters = hashMapOfExpectedSize<String, Any?>(newMap.size + table.idColumns.size + 1)
parameters.putAll(newMap)
Expand Down Expand Up @@ -208,7 +218,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

override fun delete(id: ID): Int = withTransaction {
val name = "delete"
val sql = sql(name) { "delete from ${table.name} where ${table.idColumns.equate(" and ")}" }
val sql = sql(name) { "delete from $escapedTableName where ${table.idColumns.equate(" and ")}" }
val count = session.update(sql, table.idMap(session, id, nf), options(name))

fireEvent { DeleteEvent(table, id, null) }
Expand All @@ -221,7 +231,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
val new = fireTransformingEvent(newValue) { PreUpdateEvent(table, id(newValue), newValue, null) }

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
}
val newMap = table.objectMap(session, new, table.allColumns)

Expand All @@ -247,7 +257,8 @@ abstract class AbstractDao<T : Any, ID : Any>(
}

val columns = if (generateKeys) table.dataColumns else table.allColumns
val sql = sql(name) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" }
val sql = sql(name) { "insert into $escapedTableName (${columns.joinNames()}) \n" +
"values (${columns.joinStrings { ":${it.name}" }})" }

val inserted = if (generateKeys) {
val list = session.batchInsert(sql, new.map { table.objectMap(session, it, columns, nf) }, options(name),
Expand Down Expand Up @@ -281,7 +292,8 @@ abstract class AbstractDao<T : Any, ID : Any>(
val generateKeys = isGeneratedKey(new, idStrategy)

val columns = if (generateKeys) table.dataColumns else table.allColumns
val sql = sql(name to columns) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" }
val sql = sql(name to columns) { "insert into $escapedTableName (${columns.joinNames()}) \n" +
"values (${columns.joinStrings { ":${it.name}" }})" }
val parameters = table.objectMap(session, new, columns, nf)

val (count, inserted) = if (generateKeys) {
Expand Down Expand Up @@ -314,7 +326,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

val values = if (session.dialect.supportsArrayBasedIn) {
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} " +
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} " +
session.dialect.arrayBasedIn("ids")
}
val array = ids.copyToSqlArray()
Expand All @@ -325,7 +337,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
}
} else {
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} in (:ids)"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} in (:ids)"
}
session.select(sql, mapOf("ids" to ids), options(name), table.rowMapper(columns))
}
Expand Down Expand Up @@ -353,7 +365,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
val updates = new.map { table.objectMap(session, it, table.allColumns) }

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
}

val counts = session.batchUpdate(sql, updates, options(name))
Expand Down Expand Up @@ -402,7 +414,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
}

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}

val counts = session.batchUpdate(sql, updates.map { it.first }, options(name))
Expand Down