From 30091ea067f16317f3234371fb51c4c8f998913b Mon Sep 17 00:00:00 2001 From: Mohamad Jaara Date: Tue, 3 Feb 2026 12:40:43 +0100 Subject: [PATCH] feat: implement bulk update for persistent WebSocket status via MDM --- .../com/wire/kalium/persistence/Accounts.sq | 3 + .../persistence/daokaliumdb/AccountsDAO.kt | 7 ++ .../persistence/globalDB/AccountsDAOTest.kt | 81 +++++++++++++++++-- .../wire/kalium/logic/GlobalKaliumScope.kt | 5 ++ .../logic/data/session/SessionRepository.kt | 4 + 5 files changed, 93 insertions(+), 7 deletions(-) diff --git a/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq b/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq index 40f52b18632..0f75d4c9c3e 100644 --- a/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq +++ b/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq @@ -57,6 +57,9 @@ SELECT isPersistentWebSocketEnabled FROM Accounts WHERE logout_reason IS NULL AN updatePersistentWebSocketStatus: UPDATE Accounts SET isPersistentWebSocketEnabled = :isPersistentWebSocketEnabled WHERE id = :userId; +updateAllPersistentWebSocketStatus: +UPDATE Accounts SET isPersistentWebSocketEnabled = :enabled WHERE logout_reason IS NULL; + updateSsoId: UPDATE Accounts SET scim_external_id = :scimExternalId, subject = :subject, tenant = :tenant WHERE id = :userId; diff --git a/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt b/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt index f02a2a2058d..44371de022b 100644 --- a/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt +++ b/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt @@ -174,6 +174,7 @@ interface AccountsDAO { suspend fun deleteAccount(userIDEntity: UserIDEntity) suspend fun markAccountAsInvalid(userIDEntity: UserIDEntity, logoutReason: LogoutReason) suspend fun updatePersistentWebSocketStatus(userIDEntity: UserIDEntity, isPersistentWebSocketEnabled: Boolean) + suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean suspend fun accountInfo(userIDEntity: UserIDEntity): AccountInfoEntity? fun fullAccountInfo(userIDEntity: UserIDEntity): FullAccountEntity? @@ -304,6 +305,12 @@ internal class AccountsDAOImpl internal constructor( } } + override suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) { + withContext(queriesContext) { + queries.updateAllPersistentWebSocketStatus(enabled) + } + } + override suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean = withContext(queriesContext) { queries.persistentWebSocketStatus(userIDEntity).executeAsOne() } diff --git a/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt b/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt index f3af6f1e3cc..bc191ebe9f4 100644 --- a/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt +++ b/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt @@ -29,15 +29,9 @@ import com.wire.kalium.persistence.db.GlobalDatabaseBuilder import com.wire.kalium.persistence.model.LogoutReason import com.wire.kalium.persistence.model.ServerConfigEntity import com.wire.kalium.persistence.model.SsoIdEntity -import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.first import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.StandardTestDispatcher -import kotlinx.coroutines.test.TestCoroutineScheduler -import kotlinx.coroutines.test.TestDispatcher -import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.test.setMain -import kotlin.test.AfterTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals @@ -207,6 +201,79 @@ class AccountsDAOTest : GlobalDBBaseTest() { assertEquals(null, result) } + @Test + fun whenUpdatingPersistentWebSocketStatus_thenStatusIsUpdated() = runTest { + val account = VALID_ACCOUNT + globalDatabaseBuilder.accountsDAO.insertOrReplace(account.info.userIDEntity, account.ssoId, account.managedBy, account.serverConfigId, false) + + // initial status false + val initial = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity) + assertEquals(false, initial) + + // update to true + globalDatabaseBuilder.accountsDAO.updatePersistentWebSocketStatus(account.info.userIDEntity, true) + val updated = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity) + assertEquals(true, updated) + } + + @Test + fun whenSettingAllAccountsPersistentWebSocketEnabled_thenAllStatusesAreUpdated() = runTest { + val a1 = VALID_ACCOUNT + val a2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user2", "domain2"), null)) + val a3 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user3", "domain3"), null)) + + listOf(a1, a2, a3).forEach { + globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false) + } + + globalDatabaseBuilder.accountsDAO.setAllAccountsPersistentWebSocketEnabled(true) + + listOf(a1, a2, a3).forEach { + val status = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(it.info.userIDEntity) + assertEquals(true, status) + } + } + + @Test + fun whenGettingAllValidAccountPersistentWebSocketStatus_thenOnlyValidAccountsIncluded() = runTest { + val valid1 = VALID_ACCOUNT + val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userB", "domainB"), null)) + val invalid = INVALID_ACCOUNT + + // insert accounts with different initial statuses + globalDatabaseBuilder.accountsDAO.insertOrReplace(valid1.info.userIDEntity, valid1.ssoId, valid1.managedBy, valid1.serverConfigId, true) + globalDatabaseBuilder.accountsDAO.insertOrReplace(valid2.info.userIDEntity, valid2.ssoId, valid2.managedBy, valid2.serverConfigId, false) + globalDatabaseBuilder.accountsDAO.insertOrReplace(invalid.info.userIDEntity, invalid.ssoId, invalid.managedBy, invalid.serverConfigId, true) + globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!) + + val list = globalDatabaseBuilder.accountsDAO.getAllValidAccountPersistentWebSocketStatus().first() + // Should contain only the two valid accounts in any order + val ids = list.map { it.userIDEntity }.toSet() + assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), ids) + val map = list.associateBy({ it.userIDEntity }, { it.isPersistentWebSocketEnabled }) + assertEquals(true, map[valid1.info.userIDEntity]) + assertEquals(false, map[valid2.info.userIDEntity]) + } + + @Test + fun whenRequestingValidAccountWithServerConfigId_thenReturnMapForValidAccounts() = runTest { + val valid1 = VALID_ACCOUNT + val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userC", "domainC"), null)) + val invalid = INVALID_ACCOUNT + + listOf(valid1, valid2, invalid).forEach { + globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false) + } + globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!) + + val map = globalDatabaseBuilder.accountsDAO.validAccountWithServerConfigId() + // only valid1 and valid2 should be present + assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), map.keys) + map.values.forEach { serverConfig -> + assertEquals(SERVER_CONFIG, serverConfig) + } + } + private companion object { val VALID_ACCOUNT = FullAccountEntity( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt index 7b778f14083..b7917a727e2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt @@ -18,6 +18,8 @@ package com.wire.kalium.logic +import com.wire.kalium.common.error.StorageFailure +import com.wire.kalium.common.functional.Either import com.wire.kalium.logic.configuration.notification.NotificationTokenDataSource import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository import com.wire.kalium.logic.configuration.server.CustomServerConfigDataSource @@ -122,6 +124,9 @@ public class GlobalKaliumScope internal constructor( public val observePersistentWebSocketConnectionStatus: ObservePersistentWebSocketConnectionStatusUseCase get() = ObservePersistentWebSocketConnectionStatusUseCaseImpl(sessionRepository) + public suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either = + sessionRepository.setAllPersistentWebSocketEnabled(enabled) + private val notificationTokenRepository: NotificationTokenRepository get() = NotificationTokenDataSource(globalPreferences.tokenStorage) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt index f8c523e0239..8967c5bc49d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt @@ -77,6 +77,7 @@ internal interface SessionRepository { suspend fun deleteSession(userId: UserId): Either suspend fun ssoId(userId: UserId): Either suspend fun updatePersistentWebSocketStatus(userId: UserId, isPersistentWebSocketEnabled: Boolean): Either + suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either suspend fun updateSsoIdAndScimInfo(userId: UserId, ssoId: SsoId?, managedBy: ManagedByDTO?): Either suspend fun isFederated(userId: UserId): Either suspend fun getAllValidAccountPersistentWebSocketStatus(): Either>> @@ -198,6 +199,9 @@ internal class SessionDataSource internal constructor( accountsDAO.updatePersistentWebSocketStatus(userId.toDao(), isPersistentWebSocketEnabled) } + override suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either = + wrapStorageRequest { accountsDAO.setAllAccountsPersistentWebSocketEnabled(enabled) } + override suspend fun updateSsoIdAndScimInfo( userId: UserId, ssoId: SsoId?,