Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ SELECT isPersistentWebSocketEnabled FROM Accounts WHERE logout_reason IS NULL AN
updatePersistentWebSocketStatus:
UPDATE Accounts SET isPersistentWebSocketEnabled = :isPersistentWebSocketEnabled WHERE id = :userId;

updateAllPersistentWebSocketStatus:
UPDATE Accounts SET isPersistentWebSocketEnabled = :enabled WHERE logout_reason IS NULL;

updateSsoId:
UPDATE Accounts SET scim_external_id = :scimExternalId, subject = :subject, tenant = :tenant WHERE id = :userId;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ interface AccountsDAO {
suspend fun deleteAccount(userIDEntity: UserIDEntity)
suspend fun markAccountAsInvalid(userIDEntity: UserIDEntity, logoutReason: LogoutReason)
suspend fun updatePersistentWebSocketStatus(userIDEntity: UserIDEntity, isPersistentWebSocketEnabled: Boolean)
suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean)
suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean
suspend fun accountInfo(userIDEntity: UserIDEntity): AccountInfoEntity?
fun fullAccountInfo(userIDEntity: UserIDEntity): FullAccountEntity?
Expand Down Expand Up @@ -304,6 +305,12 @@ internal class AccountsDAOImpl internal constructor(
}
}

override suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) {
withContext(queriesContext) {
queries.updateAllPersistentWebSocketStatus(enabled)
}
}

override suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean = withContext(queriesContext) {
queries.persistentWebSocketStatus(userIDEntity).executeAsOne()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,9 @@ import com.wire.kalium.persistence.db.GlobalDatabaseBuilder
import com.wire.kalium.persistence.model.LogoutReason
import com.wire.kalium.persistence.model.ServerConfigEntity
import com.wire.kalium.persistence.model.SsoIdEntity
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -207,6 +201,79 @@ class AccountsDAOTest : GlobalDBBaseTest() {
assertEquals(null, result)
}

@Test
fun whenUpdatingPersistentWebSocketStatus_thenStatusIsUpdated() = runTest {
val account = VALID_ACCOUNT
globalDatabaseBuilder.accountsDAO.insertOrReplace(account.info.userIDEntity, account.ssoId, account.managedBy, account.serverConfigId, false)

// initial status false
val initial = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
assertEquals(false, initial)

// update to true
globalDatabaseBuilder.accountsDAO.updatePersistentWebSocketStatus(account.info.userIDEntity, true)
val updated = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
assertEquals(true, updated)
}

@Test
fun whenSettingAllAccountsPersistentWebSocketEnabled_thenAllStatusesAreUpdated() = runTest {
val a1 = VALID_ACCOUNT
val a2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user2", "domain2"), null))
val a3 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user3", "domain3"), null))

listOf(a1, a2, a3).forEach {
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
}

globalDatabaseBuilder.accountsDAO.setAllAccountsPersistentWebSocketEnabled(true)

listOf(a1, a2, a3).forEach {
val status = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(it.info.userIDEntity)
assertEquals(true, status)
}
}

@Test
fun whenGettingAllValidAccountPersistentWebSocketStatus_thenOnlyValidAccountsIncluded() = runTest {
val valid1 = VALID_ACCOUNT
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userB", "domainB"), null))
val invalid = INVALID_ACCOUNT

// insert accounts with different initial statuses
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid1.info.userIDEntity, valid1.ssoId, valid1.managedBy, valid1.serverConfigId, true)
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid2.info.userIDEntity, valid2.ssoId, valid2.managedBy, valid2.serverConfigId, false)
globalDatabaseBuilder.accountsDAO.insertOrReplace(invalid.info.userIDEntity, invalid.ssoId, invalid.managedBy, invalid.serverConfigId, true)
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)

val list = globalDatabaseBuilder.accountsDAO.getAllValidAccountPersistentWebSocketStatus().first()
// Should contain only the two valid accounts in any order
val ids = list.map { it.userIDEntity }.toSet()
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), ids)
val map = list.associateBy({ it.userIDEntity }, { it.isPersistentWebSocketEnabled })
assertEquals(true, map[valid1.info.userIDEntity])
assertEquals(false, map[valid2.info.userIDEntity])
}

@Test
fun whenRequestingValidAccountWithServerConfigId_thenReturnMapForValidAccounts() = runTest {
val valid1 = VALID_ACCOUNT
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userC", "domainC"), null))
val invalid = INVALID_ACCOUNT

listOf(valid1, valid2, invalid).forEach {
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
}
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)

val map = globalDatabaseBuilder.accountsDAO.validAccountWithServerConfigId()
// only valid1 and valid2 should be present
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), map.keys)
map.values.forEach { serverConfig ->
assertEquals(SERVER_CONFIG, serverConfig)
}
}

private companion object {

val VALID_ACCOUNT = FullAccountEntity(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package com.wire.kalium.logic

import com.wire.kalium.common.error.StorageFailure
import com.wire.kalium.common.functional.Either
import com.wire.kalium.logic.configuration.notification.NotificationTokenDataSource
import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository
import com.wire.kalium.logic.configuration.server.CustomServerConfigDataSource
Expand Down Expand Up @@ -122,6 +124,9 @@ public class GlobalKaliumScope internal constructor(
public val observePersistentWebSocketConnectionStatus: ObservePersistentWebSocketConnectionStatusUseCase
get() = ObservePersistentWebSocketConnectionStatusUseCaseImpl(sessionRepository)

public suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit> =
sessionRepository.setAllPersistentWebSocketEnabled(enabled)

private val notificationTokenRepository: NotificationTokenRepository
get() = NotificationTokenDataSource(globalPreferences.tokenStorage)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ internal interface SessionRepository {
suspend fun deleteSession(userId: UserId): Either<StorageFailure, Unit>
suspend fun ssoId(userId: UserId): Either<StorageFailure, SsoIdEntity?>
suspend fun updatePersistentWebSocketStatus(userId: UserId, isPersistentWebSocketEnabled: Boolean): Either<StorageFailure, Unit>
suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit>
suspend fun updateSsoIdAndScimInfo(userId: UserId, ssoId: SsoId?, managedBy: ManagedByDTO?): Either<StorageFailure, Unit>
suspend fun isFederated(userId: UserId): Either<StorageFailure, Boolean>
suspend fun getAllValidAccountPersistentWebSocketStatus(): Either<StorageFailure, Flow<List<PersistentWebSocketStatus>>>
Expand Down Expand Up @@ -198,6 +199,9 @@ internal class SessionDataSource internal constructor(
accountsDAO.updatePersistentWebSocketStatus(userId.toDao(), isPersistentWebSocketEnabled)
}

override suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit> =
wrapStorageRequest { accountsDAO.setAllAccountsPersistentWebSocketEnabled(enabled) }

override suspend fun updateSsoIdAndScimInfo(
userId: UserId,
ssoId: SsoId?,
Expand Down
Loading