From 4ed9128ac3995475a612f096f8546df300905ea8 Mon Sep 17 00:00:00 2001 From: Marlon Etheredge Date: Thu, 5 Feb 2026 17:07:09 +0100 Subject: [PATCH] feat: adding runtime barrier --- .../kotlin/at/ac/uibk/dps/cirrina/Cirrina.kt | 2 - .../uibk/dps/cirrina/EnvironmentVariables.kt | 5 +- .../kotlin/at/ac/uibk/dps/cirrina/Runtime.kt | 14 +++++ .../uibk/dps/cirrina/di/CirrinaComponent.kt | 4 +- .../ac/uibk/dps/cirrina/di/CirrinaModule.kt | 14 +++-- .../cirrina/execution/object/EventHandler.kt | 4 ++ .../provider/EventHandlerInMemory.kt | 4 ++ .../execution/provider/EventHandlerZenoh.kt | 52 ++++++++++++++++++- .../at/ac/uibk/dps/cirrina/di/TestModule.kt | 4 ++ .../provider/EventHandlerZenohTest.kt | 36 +++++++++++++ 10 files changed, 125 insertions(+), 14 deletions(-) diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/Cirrina.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/Cirrina.kt index 8e9a038b..a5d57afa 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/Cirrina.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/Cirrina.kt @@ -26,8 +26,6 @@ class Cirrina { } companion object { - const val ETCD_CONNECTION_TIMEOUT = 1000L - init { configureLogging() } diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/EnvironmentVariables.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/EnvironmentVariables.kt index e6703eb6..906580d8 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/EnvironmentVariables.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/EnvironmentVariables.kt @@ -33,12 +33,15 @@ object EnvironmentVariables { val influxMetricOrg = EnvironmentVariable("INFLUX_METRIC_ORG", "org") val influxMetricBucket = EnvironmentVariable("INFLUX_METRIC_BUCKET", "bucket") val influxMetricToken = EnvironmentVariable("INFLUX_METRIC_TOKEN", "bzO10KmR8x") - val influxMetricStep = EnvironmentVariable("INFLUX_METRIC_STEP", 5000L) + val influxMetricStep = EnvironmentVariable("INFLUX_METRIC_STEP", 5000L, { it.toLong() }) val zipkinTraceUrl = EnvironmentVariable("ZIPKIN_TRACE_URL", null) val csmMainUri = EnvironmentVariable("CSM_MAIN_URI", "file:///app/main.pkl") + val csmBarrier = EnvironmentVariable("CSM_BARRIER", null) + val csmParties = EnvironmentVariable("CSM_PARTIES", null, { it.toInt() }) + val eventProvider = EnvironmentVariable( "EVENT_PROVIDER", diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/Runtime.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/Runtime.kt index 8d7ba20a..81caf4a6 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/Runtime.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/Runtime.kt @@ -1,6 +1,7 @@ package at.ac.uibk.dps.cirrina import at.ac.uibk.dps.cirrina.cirrina.di.CsmMain +import at.ac.uibk.dps.cirrina.cirrina.di.Identifier import at.ac.uibk.dps.cirrina.execution.`object`.Context import at.ac.uibk.dps.cirrina.execution.`object`.ContextVariable import at.ac.uibk.dps.cirrina.execution.`object`.Event @@ -34,6 +35,7 @@ private val logger = KotlinLogging.logger {} class Runtime @Inject constructor( + @Identifier private val identifier: String, private val eventHandler: EventHandler, private val stateMachineFactory: StateMachine.Factory, persistentContext: Context?, @@ -122,6 +124,18 @@ constructor( stateMachineInstances[stateMachineObjectName] fun run() = runBlocking { + val barrier = EnvironmentVariables.csmBarrier.get() + val parties = EnvironmentVariables.csmParties.get() + + if (barrier != null && parties != null) { + logger.info { "waiting for barrier '$barrier' with '$parties' parties" } + + eventHandler.register(barrier, identifier) + eventHandler.wait(barrier, parties) + + logger.info { "barrier reached" } + } + measureTime { stateMachineInstances.values.forEach { it.start() } diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaComponent.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaComponent.kt index 529c0d6a..f697be11 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaComponent.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaComponent.kt @@ -2,11 +2,11 @@ package at.ac.uibk.dps.cirrina.di import at.ac.uibk.dps.cirrina.Runtime import at.ac.uibk.dps.cirrina.cirrina.di.CirrinaModule +import at.ac.uibk.dps.cirrina.cirrina.di.Identifier import at.ac.uibk.dps.cirrina.execution.`object`.Context import at.ac.uibk.dps.cirrina.execution.`object`.EventHandler import dagger.Component import io.micrometer.core.instrument.MeterRegistry -import jakarta.inject.Named import jakarta.inject.Singleton @Singleton @@ -21,5 +21,5 @@ interface CirrinaComponent { fun runtime(): Runtime - @Named("identifier") fun identifier(): String + @Identifier fun identifier(): String } diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaModule.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaModule.kt index 41cb90fc..de2538c8 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaModule.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/di/CirrinaModule.kt @@ -39,15 +39,13 @@ import io.opentelemetry.sdk.resources.Resource import io.opentelemetry.sdk.trace.SdkTracerProvider import io.opentelemetry.sdk.trace.export.BatchSpanProcessor import io.opentelemetry.semconv.ServiceAttributes -import jakarta.inject.Named import jakarta.inject.Qualifier import jakarta.inject.Singleton import java.net.URI import java.time.Duration import java.util.UUID -import mu.KotlinLogging -private val logger = KotlinLogging.logger {} +@Qualifier @Retention(AnnotationRetention.RUNTIME) annotation class Identifier @Qualifier @Retention(AnnotationRetention.RUNTIME) annotation class CsmMain @@ -97,7 +95,7 @@ class CirrinaModule { @Provides @Singleton fun provideObservationRegistry( - @Named("identifier") identifier: String, + @Identifier identifier: String, meterRegistry: MeterRegistry, ): ObservationRegistry { val observationRegistry = ObservationRegistry.create() @@ -141,12 +139,12 @@ class CirrinaModule { } } + @Provides @Singleton @Identifier fun provideIdentifier(): String = "cirrina.${UUID.randomUUID()}" + @Provides @Singleton - @Named("identifier") - fun provideIdentifier(): String = "cirrina.${UUID.randomUUID()}" - - @Provides @CsmMain fun provideCsmMain(): URI = URI(EnvironmentVariables.csmMainUri.get()) + @CsmMain + fun provideCsmMain(): URI = URI(EnvironmentVariables.csmMainUri.get()) @Provides @Singleton diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/object/EventHandler.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/object/EventHandler.kt index fa384bcc..f2793bf4 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/object/EventHandler.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/object/EventHandler.kt @@ -20,6 +20,10 @@ abstract class EventHandler() : AutoCloseable { abstract fun unsubscribe(source: String) + abstract fun register(group: String, member: String) + + abstract fun wait(group: String, parties: Int) + fun registerHandler(handler: PropagationHandler) { handlers.add(handler) } diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerInMemory.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerInMemory.kt index 7248d1d6..cc71c912 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerInMemory.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerInMemory.kt @@ -10,5 +10,9 @@ class EventHandlerInMemory : EventHandler() { override fun unsubscribe(source: String) {} + override fun register(barrier: String, member: String) {} + + override fun wait(barrier: String, n: Int) {} + override fun close() {} } diff --git a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenoh.kt b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenoh.kt index 810e8d4b..eb993573 100644 --- a/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenoh.kt +++ b/src/main/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenoh.kt @@ -14,6 +14,8 @@ import io.zenoh.pubsub.Subscriber import io.zenoh.sample.Sample import java.io.File import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit class EventHandlerZenoh() : EventHandler() { private val session: Session @@ -26,7 +28,7 @@ class EventHandlerZenoh() : EventHandler() { Config.fromFile(File(uri)).getOrThrow() } ?: Config.default() - this.session = Zenoh.open(config).getOrThrow() + session = Zenoh.open(config).getOrThrow() } override fun send(event: Event) { @@ -66,6 +68,54 @@ class EventHandlerZenoh() : EventHandler() { activeSubscriptions.remove(selectorString)?.close() } + override fun register(group: String, member: String) { + val key = "liveness/$group/$member" + val keyExpr = KeyExpr.tryFrom(key).getOrThrow() + + session.liveliness().declareToken(keyExpr).onFailure { + error("failed to register liveness '$group/$member'") + } + } + + override fun wait(group: String, parties: Int) { + val key = "liveness/$group/**" + val keyExpr = KeyExpr.tryFrom(key).getOrThrow() + val discoveredMembers = ConcurrentHashMap.newKeySet() + val latch = CountDownLatch(parties) + + val sub = + session + .liveliness() + .declareSubscriber( + KeyExpr.tryFrom(key).getOrThrow(), + callback = { sample -> + if (discoveredMembers.add(sample.keyExpr.toString())) { + latch.countDown() + } + }, + ) + .getOrElse({ error("failed to subscribe to liveness '$group'") }) + + session + .liveliness() + .get( + keyExpr, + callback = { reply -> + reply.result.onSuccess { sample -> + if (discoveredMembers.add(sample.keyExpr.toString())) { + latch.countDown() + } + } + }, + ) + .onFailure { error("failed to get liveness '$group'") } + + if (!latch.await(30, TimeUnit.SECONDS)) { + error("timeout: ${discoveredMembers.size}/$parties members.") + } + sub.close() + } + private fun getZenohPath(event: Event): String? { return when (event.channel) { Csml.EventChannel.EXTERNAL -> event.source.let { "$it/${event.topic}" } diff --git a/src/test/kotlin/at/ac/uibk/dps/cirrina/di/TestModule.kt b/src/test/kotlin/at/ac/uibk/dps/cirrina/di/TestModule.kt index 694a66f0..d5bbb878 100644 --- a/src/test/kotlin/at/ac/uibk/dps/cirrina/di/TestModule.kt +++ b/src/test/kotlin/at/ac/uibk/dps/cirrina/di/TestModule.kt @@ -1,6 +1,7 @@ package at.ac.uibk.dps.cirrina.di import at.ac.uibk.dps.cirrina.cirrina.di.CsmMain +import at.ac.uibk.dps.cirrina.cirrina.di.Identifier import at.ac.uibk.dps.cirrina.execution.`object`.ActionCommandFactory import at.ac.uibk.dps.cirrina.execution.`object`.ActionCommandFactoryImpl import at.ac.uibk.dps.cirrina.execution.`object`.Context @@ -12,6 +13,7 @@ import io.micrometer.core.instrument.simple.SimpleMeterRegistry import io.micrometer.observation.ObservationRegistry import jakarta.inject.Singleton import java.net.URI +import java.util.UUID @Module class TestModule( @@ -28,6 +30,8 @@ class TestModule( @Provides fun provideObservationRegistry(): ObservationRegistry = ObservationRegistry.create() + @Provides @Singleton @Identifier fun provideIdentifier(): String = "cirrina.${UUID.randomUUID()}" + @Provides @CsmMain fun provideCsmMain() = mainUri @Provides diff --git a/src/test/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenohTest.kt b/src/test/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenohTest.kt index 4b66734d..df7049df 100644 --- a/src/test/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenohTest.kt +++ b/src/test/kotlin/at/ac/uibk/dps/cirrina/execution/provider/EventHandlerZenohTest.kt @@ -2,7 +2,43 @@ package at.ac.uibk.dps.cirrina.execution.provider import at.ac.uibk.dps.cirrina.execution.`object`.EventHandler import at.ac.uibk.dps.cirrina.execution.`object`.EventHandlerTest +import java.util.concurrent.atomic.AtomicInteger +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test class EventHandlerZenohTest : EventHandlerTest() { override fun createEventHandler(): EventHandler = EventHandlerZenoh() + + @Test + fun testRegisterWait() = runBlocking { + val parties = 10 + val group = "group" + val completedParties = AtomicInteger(0) + + val jobs = + (1..parties).map { i -> + launch(Dispatchers.IO) { + createEventHandler().use { handler -> + val memberName = "member-$i" + + if (i > parties / 2) delay(100) + + handler.register(group, memberName) + handler.wait(group, parties) + + completedParties.incrementAndGet() + } + } + } + + withTimeout(35000) { jobs.joinAll() } + + assertEquals(parties, completedParties.get()) + } }