diff --git a/database/src/main/scala/no/ndla/database/TableMigration.scala b/database/src/main/scala/no/ndla/database/TableMigration.scala
index 908afed4e0..1ad1c27ae4 100644
--- a/database/src/main/scala/no/ndla/database/TableMigration.scala
+++ b/database/src/main/scala/no/ndla/database/TableMigration.scala
@@ -19,33 +19,27 @@ abstract class TableMigration[ROW_DATA] extends BaseJavaMigration {
 
   def updateRow(rowData: ROW_DATA)(implicit session: DBSession): Int
 
   lazy val tableNameSQL: SQLSyntax = SQLSyntax.createUnsafely(tableName)
 
-  private def countAllRows(implicit session: DBSession): Option[Long] = {
-    sql"select count(*) from $tableNameSQL where $whereClause".map(rs => rs.long("count")).single()
-  }
-
-  private def allRows(offset: Long)(implicit session: DBSession): Seq[ROW_DATA] = {
-    sql"select * from $tableNameSQL where $whereClause order by id limit $chunkSize offset $offset"
-      .map(rs => extractRowData(rs))
-      .list()
-  }
-
   override def migrate(context: Context): Unit = DB(context.getConnection)
     .autoClose(false)
     .withinTx { session => migrateRows(using session) }
 
-  protected def migrateRows(implicit session: DBSession): Unit = {
-    val count = countAllRows.get
-    var numPagesLeft = (count / chunkSize) + 1
-    var offset = 0L
-
-    while (numPagesLeft > 0) {
-      allRows(offset * chunkSize).map { rowData =>
-        updateRow(rowData)
-      }: Unit
-      numPagesLeft -= 1
-      offset += 1
+  protected def migrateRows(implicit session: DBSession): Unit = Iterator
+    .unfold(0L) { lastId =>
+      getRowChunk(lastId) match {
+        case Nil => None
+        case chunk => Some((chunk, chunk.last._1))
+      }
+    }
+    .takeWhile(_.nonEmpty)
+    .foreach { chunk =>
+      chunk.foreach((_, rowData) => updateRow(rowData))
     }
+
+  private def getRowChunk(lastId: Long)(implicit session: DBSession): Seq[(Long, ROW_DATA)] = {
+    sql"select * from $tableNameSQL where $whereClause and id > $lastId order by id limit $chunkSize"
+      .map(rs => (rs.long("id"), extractRowData(rs)))
+      .list()
+  }
 }
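The hunk above replaces offset pagination with keyset (seek) pagination. This matters for migrations in particular: with `limit/offset`, any `updateRow` that makes a row stop matching `whereClause` shifts the remaining rows backwards, so later pages silently skip rows; paging on `id > lastId` instead makes progress independent of what the updates do to the predicate, and it avoids re-scanning all skipped rows on every page. A minimal self-contained sketch of the same pattern, against a hypothetical `items(id bigint primary key, data text)` table (the table and column names are illustrative, not part of this change):

```scala
import scalikejdbc.*

// Fetch one chunk strictly after the last id seen; the primary-key index
// makes this a range seek rather than an offset scan.
def fetchChunk(lastId: Long, chunkSize: Int)(implicit session: DBSession): Seq[(Long, String)] =
  sql"select id, data from items where id > $lastId order by id limit $chunkSize"
    .map(rs => (rs.long("id"), rs.string("data")))
    .list()

// Stream all rows chunk by chunk, threading the last-seen id as the cursor.
def forEachRow(chunkSize: Int)(f: (Long, String) => Unit)(implicit session: DBSession): Unit =
  Iterator
    .unfold(0L) { lastId =>
      fetchChunk(lastId, chunkSize) match {
        case Nil   => None                         // no rows left: stop unfolding
        case chunk => Some((chunk, chunk.last._1)) // emit chunk, advance cursor to its last id
      }
    }
    .foreach(_.foreach(f.tupled))
```

One assumption baked into both the sketch and the migration itself: `id` must be unique and totally ordered (here it is the primary key), otherwise `id > lastId` could skip or repeat rows.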
diff --git a/database/src/test/scala/no/ndla/database/TableMigrationTest.scala b/database/src/test/scala/no/ndla/database/TableMigrationTest.scala
new file mode 100644
index 0000000000..b64f0db0b3
--- /dev/null
+++ b/database/src/test/scala/no/ndla/database/TableMigrationTest.scala
@@ -0,0 +1,114 @@
+/*
+ * Part of NDLA database
+ * Copyright (C) 2026 NDLA
+ *
+ * See LICENSE
+ *
+ */
+
+package no.ndla.database
+
+import no.ndla.scalatestsuite.{DatabaseIntegrationSuite, UnitTestSuite}
+import org.flywaydb.core.Flyway
+import scalikejdbc.*
+
+class TableMigrationTest extends DatabaseIntegrationSuite, UnitTestSuite, TestEnvironment {
+  val dataSource: DataSource = testDataSource.get
+  val schema: String = "testschema"
+  val schemaSql = SQLSyntax.createUnsafely(schema)
+  val tableName: String = "test"
+  val tableNameSql = SQLSyntax.createUnsafely(tableName)
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    dataSource.connectToDatabase()
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+
+    DB.autoCommit { implicit session =>
+      sql"""
+           drop schema if exists $schemaSql cascade;
+           create schema $schemaSql;
+           create table $tableNameSql (id int primary key, data text);""".execute()
+    }
+  }
+
+  private def insertIdsFromRange(range: Range): Unit = {
+    DB.autoCommit { implicit session =>
+      val sqlInsertParts = range.map(id => sqls"insert into $tableNameSql (id, data) values ($id, ${"row" + id})")
+      val joinedSqlInsert = SQLSyntax.join(sqlInsertParts, sqls";")
+      sql"$joinedSqlInsert".execute()
+    }
+  }
+
+  private def runMigration[A](migration: TableMigration[A]): Unit = {
+    val flyway = Flyway
+      .configure()
+      .javaMigrations(migration)
+      .dataSource(dataSource)
+      .schemas(schema)
+      .baselineVersion("00")
+      .baselineOnMigrate(true)
+      .load()
+
+    flyway.migrate()
+  }
+
+  test("that all rows are updated with no where clause") {
+    insertIdsFromRange(1 to 50)
+
+    class V01__Foo extends TableMigration[Long] {
+      override val tableName: String = TableMigrationTest.this.tableName
+      override lazy val whereClause: SQLSyntax = sqls"true"
+      override val chunkSize: Int = 10
+
+      override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")
+
+      override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
+        sql"update $tableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
+      }
+    }
+
+    runMigration(V01__Foo())
+
+    DB.readOnly { implicit session =>
+      val updatedRowsCount = sql"select count(*) from $tableNameSql where data like 'updated_row%'"
+        .map(_.int(1))
+        .single()
+        .get
+      updatedRowsCount should be(50)
+    }
+  }
+
+  test("that keyset pagination works correctly") {
+    val step = 3
+    insertIdsFromRange(100 to 1 by -step)
+    val maxIdToUpdate = 50
+    val expectedUpdateCount = (maxIdToUpdate / step) + 1
+
+    class V01__Foo extends TableMigration[Long] {
+      override val tableName: String = TableMigrationTest.this.tableName
+      override lazy val whereClause: SQLSyntax = sqls"id < $maxIdToUpdate"
+      override val chunkSize: Int = 10
+
+      override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")
+
+      override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
+        sql"update $tableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
+      }
+    }
+
+    runMigration(V01__Foo())
+
+    DB.readOnly { implicit session =>
+      val updatedIds = sql"select id from $tableNameSql where data like 'updated_row%' order by id"
+        .map(_.int("id"))
+        .list()
+      all(updatedIds) should be < maxIdToUpdate
+      updatedIds.length should be(expectedUpdateCount)
+    }
+  }
+}
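For reference, a concrete migration built on this base class follows the same shape as the test doubles above. A sketch under assumed names (`V42__TrimDocument`, the `articledata` table, and the `document` column are hypothetical, not taken from this change):

```scala
import scalikejdbc.*

// Hypothetical usage sketch: everything except the TableMigration members
// being overridden is illustrative.
class V42__TrimDocument extends TableMigration[(Long, String)] {
  override val tableName: String           = "articledata"
  override lazy val whereClause: SQLSyntax = sqls"document is not null"
  override val chunkSize: Int              = 1000

  // Pull out the primary key and the column we want to rewrite.
  override def extractRowData(rs: WrappedResultSet): (Long, String) =
    (rs.long("id"), rs.string("document"))

  // Rewrite one row; runs inside the transaction opened by migrate().
  override def updateRow(rowData: (Long, String))(implicit session: DBSession): Int = {
    val (id, document) = rowData
    val converted      = document.trim // stand-in for a real document conversion
    sql"update articledata set document = $converted where id = $id".update()
  }
}
```

As in the tests above, such a migration can be handed to Flyway explicitly via `javaMigrations`, or discovered however the surrounding project already registers its Java-based migrations. Note that `updateRow` returns the update count, but the base class's `foreach` discards it.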