Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions database/src/main/scala/no/ndla/database/TableMigration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,27 @@ abstract class TableMigration[ROW_DATA] extends BaseJavaMigration {
def updateRow(rowData: ROW_DATA)(implicit session: DBSession): Int
lazy val tableNameSQL: SQLSyntax = SQLSyntax.createUnsafely(tableName)

private def countAllRows(implicit session: DBSession): Option[Long] = {
sql"select count(*) from $tableNameSQL where $whereClause".map(rs => rs.long("count")).single()
}

private def allRows(offset: Long)(implicit session: DBSession): Seq[ROW_DATA] = {
sql"select * from $tableNameSQL where $whereClause order by id limit $chunkSize offset $offset"
.map(rs => extractRowData(rs))
.list()
}

override def migrate(context: Context): Unit = DB(context.getConnection)
.autoClose(false)
.withinTx { session =>
migrateRows(using session)
}

protected def migrateRows(implicit session: DBSession): Unit = {
val count = countAllRows.get
var numPagesLeft = (count / chunkSize) + 1
var offset = 0L

while (numPagesLeft > 0) {
allRows(offset * chunkSize).map { rowData =>
updateRow(rowData)
}: Unit
numPagesLeft -= 1
offset += 1
protected def migrateRows(implicit session: DBSession): Unit = Iterator
.unfold(0L) { lastId =>
getRowChunk(lastId) match {
case Nil => None
case chunk => Some((chunk, chunk.last._1))
}
}
.takeWhile(_.nonEmpty)
.foreach { chunk =>
chunk.foreach((_, rowData) => updateRow(rowData))
}

private def getRowChunk(lastId: Long)(implicit session: DBSession): Seq[(Long, ROW_DATA)] = {
sql"select * from $tableNameSQL where $whereClause and id > $lastId order by id limit $chunkSize"
.map(rs => (rs.long("id"), extractRowData(rs)))
.list()
}
}
114 changes: 114 additions & 0 deletions database/src/test/scala/no/ndla/database/TableMigrationTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Part of NDLA database
* Copyright (C) 2026 NDLA
*
* See LICENSE
*
*/

package no.ndla.database

import no.ndla.scalatestsuite.{DatabaseIntegrationSuite, UnitTestSuite}
import org.flywaydb.core.Flyway
import scalikejdbc.*

class TableMigrationTest extends DatabaseIntegrationSuite, UnitTestSuite, TestEnvironment {
val dataSource: DataSource = testDataSource.get
val schema: String = "testschema"
val schemaSql = SQLSyntax.createUnsafely(schema)
val tableName: String = "test"
val tableNameSql = SQLSyntax.createUnsafely(tableName)

override def beforeAll(): Unit = {
super.beforeAll()

dataSource.connectToDatabase()
}

override def beforeEach(): Unit = {
super.beforeEach()

DB.autoCommit { implicit session =>
sql"""
drop schema if exists $schemaSql cascade;
create schema $schemaSql;
create table $tableNameSql (id int primary key, data text);""".execute()
}
}

private def insertIdsFromRange(range: Range): Unit = {
DB.autoCommit { implicit session =>
val sqlInsertParts = range.map(id => sqls"insert into $tableNameSql (id, data) values ($id, ${"row" + id})")
val joinedSqlInsert = SQLSyntax.join(sqlInsertParts, sqls";")
sql"$joinedSqlInsert".execute()
}
}

private def runMigration[A](migration: TableMigration[A]): Unit = {
val flyway = Flyway
.configure()
.javaMigrations(migration)
.dataSource(dataSource)
.schemas(schema)
.baselineVersion("00")
.baselineOnMigrate(true)
.load()

flyway.migrate()
}

test("that all rows are updated with no where clause") {
insertIdsFromRange(1 to 50)

class V01__Foo extends TableMigration[Long] {
override val tableName: String = TableMigrationTest.this.tableName
override lazy val whereClause: SQLSyntax = sqls"true"
override val chunkSize: Int = 10

override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")

override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
sql"update $tableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
}
}

runMigration(V01__Foo())

DB.readOnly { implicit session =>
val updatedRowsCount = sql"select count(*) from $tableNameSql where data like 'updated_row%'"
.map(_.int(1))
.single()
.get
updatedRowsCount should be(50)
}
}

test("that keyset pagination works correctly") {
val step = 3
insertIdsFromRange(100 to 1 by -step)
val maxIdToUpdate = 50
val expectedUpdateCount = (maxIdToUpdate / step) + 1

class V01__Foo extends TableMigration[Long] {
override val tableName: String = TableMigrationTest.this.tableName
override lazy val whereClause: SQLSyntax = sqls"id < $maxIdToUpdate"
override val chunkSize: Int = 10

override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")

override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
sql"update $tableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
}
}

runMigration(V01__Foo())

DB.readOnly { implicit session =>
val updatedIds = sql"select id from $tableNameSql where data like 'updated_row%' order by id"
.map(_.int("id"))
.list()
all(updatedIds) should be < maxIdToUpdate
updatedIds.length should be(expectedUpdateCount)
}
}
}
Loading