diff --git a/Cargo.lock b/Cargo.lock index 9ba9822b..a0b10d8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4410,9 +4410,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" dependencies = [ "chacha20", "getrandom 0.4.2", @@ -5636,7 +5636,7 @@ dependencies = [ "clap", "futures", "futures-util", - "rand 0.10.0", + "rand 0.10.1", "reqwest 0.13.2", "rustyline", "serde", @@ -6180,7 +6180,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.2", + "getrandom 0.3.4", "once_cell", "rustix 1.1.4", "windows-sys 0.61.2", diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index 7da72e5b..ee7af00c 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -519,6 +519,11 @@ async fn handle_add_node( { let mut pipeline = session.pipeline.lock().await; + if pipeline.nodes.contains_key(&node_id) { + return Some(ResponsePayload::Error { + message: format!("Node '{node_id}' already exists in the pipeline"), + }); + } pipeline.nodes.insert( node_id.clone(), streamkit_api::Node { kind: kind.clone(), params: params.clone(), state: None }, @@ -1224,6 +1229,33 @@ async fn handle_apply_batch( }); } + // Pre-validate duplicate node_ids against the pipeline model. + // Simulate the batch's Add/Remove sequence so that Remove→Add for + // the same ID within the batch is allowed, but duplicate Adds + // (without intervening Remove) are rejected before any mutation. + { + let pipeline = session.pipeline.lock().await; + let mut live_ids: std::collections::HashSet<&str> = + pipeline.nodes.keys().map(String::as_str).collect(); + for op in &operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, .. } => { + if !live_ids.insert(node_id.as_str()) { + return Some(ResponsePayload::Error { + message: format!( + "Batch rejected: node '{node_id}' already exists in the pipeline" + ), + }); + } + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + live_ids.remove(node_id.as_str()); + }, + _ => {}, + } + } + } // Pipeline lock released after pre-validation + // Validate permissions for all operations for op in &operations { if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { diff --git a/apps/skit/tests/session_lifecycle_test.rs b/apps/skit/tests/session_lifecycle_test.rs index 37cb572d..fbc5d9a3 100644 --- a/apps/skit/tests/session_lifecycle_test.rs +++ b/apps/skit/tests/session_lifecycle_test.rs @@ -458,19 +458,15 @@ async fn test_session_destroy_shuts_down_pipeline() { println!("✅ Session created: {}", session_id); - // Add a source node (silence generator) + // Add a source node (audio::gain is a registered core node with in/out pins) let add_source_request = Request { message_type: MessageType::Request, correlation_id: Some("add-source".to_string()), payload: RequestPayload::AddNode { session_id: session_id.clone(), node_id: "source".to_string(), - kind: "silence".to_string(), - params: Some(json!({ - "duration_ms": 10000, // 10 seconds - "sample_rate": 48000, - "channels": 2 - })), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 1.0})), }, }; @@ -488,15 +484,15 @@ async fn test_session_destroy_shuts_down_pipeline() { println!("✅ Added source node"); - // Add a gain node + // Add a second gain node let add_gain_request = Request { message_type: MessageType::Request, correlation_id: Some("add-gain".to_string()), payload: RequestPayload::AddNode { session_id: session_id.clone(), node_id: "gain".to_string(), - kind: "gain".to_string(), - params: Some(json!({"gain": 1.0})), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 0.5})), }, }; @@ -562,18 +558,21 @@ async fn test_session_destroy_shuts_down_pipeline() { ResponsePayload::Pipeline { pipeline } => { assert_eq!(pipeline.nodes.len(), 2); - // Check that nodes are in Running state (not Failed or Stopped) + // Both nodes use "audio::gain" which is a registered built-in + // node type. With async creation they may still be in + // Creating state, or have progressed to Initializing/Ready. for (node_id, node) in &pipeline.nodes { if let Some(state) = &node.state { println!("Node '{}' state: {:?}", node_id, state); assert!( matches!( state, - streamkit_core::NodeState::Initializing + streamkit_core::NodeState::Creating + | streamkit_core::NodeState::Initializing | streamkit_core::NodeState::Ready | streamkit_core::NodeState::Running ), - "Node '{}' should be initializing/ready/running, got: {:?}", + "Node '{}' should be creating/initializing/ready/running, got: {:?}", node_id, state ); diff --git a/crates/core/src/state.rs b/crates/core/src/state.rs index ebab9a6a..d69279ba 100644 --- a/crates/core/src/state.rs +++ b/crates/core/src/state.rs @@ -12,6 +12,8 @@ //! Nodes transition through these states during their lifecycle: //! //! ```text +//! Creating +//! ↓ //! Initializing //! ↓ //! Ready ──────────┐ @@ -89,6 +91,8 @@ impl From for StopReason { /// Nodes transition through these states during their lifecycle: /// /// ```text +/// Creating +/// ↓ /// Initializing /// ↓ /// Ready ──────────┐ @@ -105,6 +109,8 @@ impl From for StopReason { /// ``` /// /// ### Valid Transitions: +/// - `Creating` → `Initializing` (node factory completed successfully) +/// - `Creating` → `Failed` (node factory returned an error) /// - `Initializing` → `Ready` (source nodes) or `Running` (processing nodes) /// - `Ready` → `Running` (when pipeline is ready) /// - `Running` → `Recovering` (temporary issues, will retry) @@ -120,6 +126,11 @@ impl From for StopReason { #[derive(Debug, Clone, Serialize, Deserialize, TS)] #[ts(export)] pub enum NodeState { + /// Node is being created by the factory (e.g., loading ONNX models via FFI). + /// This state is set immediately when `AddNode` is received, before the + /// (potentially slow) constructor runs in a background task. + Creating, + /// Node is starting up and performing initialization. /// Examples: Opening connections, loading resources, validating configuration. Initializing, diff --git a/crates/engine/src/dynamic_actor.rs b/crates/engine/src/dynamic_actor.rs index c0777001..97f076e0 100644 --- a/crates/engine/src/dynamic_actor.rs +++ b/crates/engine/src/dynamic_actor.rs @@ -50,6 +50,34 @@ struct NodeChannels { view_data: mpsc::Sender, } +/// Result of a background node creation task, sent back to the actor loop. +pub struct NodeCreatedEvent { + node_id: String, + kind: String, + creation_id: u64, + result: Result, StreamKitError>, +} + +/// A connection request deferred because one or both endpoints are still in +/// `Creating` state and not yet present in `live_nodes`. +#[derive(Debug)] +pub struct PendingConnection { + from_node: String, + from_pin: String, + to_node: String, + to_pin: String, + mode: crate::dynamic_messages::ConnectionMode, +} + +/// A `TuneNode` message deferred because the target node is still in +/// `Creating` state. Replayed once the node finishes initialization and +/// enters `live_nodes`. +#[derive(Debug)] +pub struct PendingTune { + node_id: String, + message: NodeControlMessage, +} + /// The state for the long-running, dynamic engine actor (Control Plane). pub struct DynamicEngine { pub(super) registry: Arc>, @@ -124,10 +152,28 @@ pub struct DynamicEngine { pub(super) node_packets_errored_counter: opentelemetry::metrics::Counter, // Node state metric (1=running, 0=not running) pub(super) node_state_gauge: opentelemetry::metrics::Gauge, + /// Sender half of the internal channel for background node creation results. + /// Cloned into each spawned creation task. + pub(super) node_created_tx: mpsc::Sender, + /// Receiver half — polled in the actor `select!` loop. + pub(super) node_created_rx: mpsc::Receiver, + /// Connections deferred because one or both endpoints are still `Creating`. + pub(super) pending_connections: Vec, + /// TuneNode messages deferred because the target node is still `Creating`. + pub(super) pending_tunes: Vec, + /// Monotonically increasing counter used to tag each spawned creation task. + /// Lets `handle_node_created` distinguish stale results (from a previous + /// Remove → re-Add cycle) from the current active creation. + pub(super) next_creation_id: u64, + /// Maps node_id → creation_id for nodes currently in `Creating` state. + /// When `NodeCreated` arrives, its `creation_id` must match the active + /// entry; otherwise the result is stale and discarded. + pub(super) active_creations: HashMap, } impl DynamicEngine { const fn node_state_name(state: &NodeState) -> &'static str { match state { + NodeState::Creating => "creating", NodeState::Initializing => "initializing", NodeState::Ready => "ready", NodeState::Running => "running", @@ -156,10 +202,13 @@ impl DynamicEngine { loop { tokio::select! { Some(control_msg) = self.control_rx.recv() => { - if !self.handle_engine_control(control_msg, &channels).await { + if !self.handle_engine_control(control_msg).await { break; // Shutdown requested } }, + Some(created) = self.node_created_rx.recv() => { + self.handle_node_created(created, &channels).await; + }, Some(query_msg) = self.query_rx.recv() => { self.handle_query(query_msg).await; }, @@ -603,7 +652,10 @@ impl DynamicEngine { } // 3. Initialize State and Stats - self.node_states.insert(node_id.to_string(), NodeState::Initializing); + // Use broadcast_state_update so the gauge transition (e.g. + // Creating → Initializing) zeroes the previous gauge and sets + // the new one atomically — no window where no gauge reads 1. + self.broadcast_state_update(node_id, NodeState::Initializing); self.node_stats.insert(node_id.to_string(), NodeStats::default()); // 4. Setup pin management channel. @@ -1472,62 +1524,374 @@ impl DynamicEngine { self.nodes_active_gauge.record(self.live_nodes.len() as u64, &[]); } + /// Handles a completed background node creation. + /// + /// On success: initializes the node, then flushes any pending connections + /// whose endpoints are now both realized. + /// On failure: transitions the node to `Failed`, drains pending connections + /// referencing the failed node. + async fn handle_node_created(&mut self, event: NodeCreatedEvent, channels: &NodeChannels) { + let NodeCreatedEvent { node_id, kind, creation_id, result } = event; + + // Check whether this creation result is still the active one. + // A mismatch means either: + // - RemoveNode was called while Creating (entry removed), or + // - Remove → re-Add happened and a newer creation superseded this one. + // In both cases, discard the stale result. + match self.active_creations.get(&node_id) { + Some(&active_id) if active_id == creation_id => { + // This is the current active creation — remove the tracking + // entry and proceed with initialization. + self.active_creations.remove(&node_id); + }, + _ => { + tracing::info!( + node = %node_id, + creation_id, + "Discarding stale/cancelled creation result" + ); + return; + }, + } + + match result { + Ok(node) => { + tracing::info!(node = %node_id, kind = %kind, "Node created successfully, initializing"); + + // initialize_node calls broadcast_state_update(Initializing) + // which reads Creating as the previous state and zeroes its + // gauge before setting Initializing to 1 — no gap. + if let Err(e) = self.initialize_node(node, &node_id, &kind, channels).await { + tracing::error!( + node_id = %node_id, + kind = %kind, + error = %e, + "Failed to initialize node after async creation" + ); + + // Broadcast Failed (reads prev state before inserting). + self.broadcast_state_update( + &node_id, + NodeState::Failed { reason: e.to_string() }, + ); + + // Clean up node_kinds (mirrors RemoveNode-while-Creating). + self.node_kinds.remove(&node_id); + + // Drain pending connections and tunes referencing this node. + self.pending_connections + .retain(|pc| pc.from_node != node_id && pc.to_node != node_id); + self.pending_tunes.retain(|pt| pt.node_id != node_id); + return; + } + + // Flush pending connections where both endpoints are now realized. + self.flush_pending_connections().await; + + // Replay any TuneNode messages that arrived while Creating. + self.flush_pending_tunes(&node_id).await; + }, + Err(e) => { + tracing::error!( + node_id = %node_id, + kind = %kind, + error = %e, + "Background node creation failed" + ); + + // Broadcast Failed (reads prev state before inserting). + self.broadcast_state_update(&node_id, NodeState::Failed { reason: e.to_string() }); + + // Clean up node_kinds (mirrors RemoveNode-while-Creating). + self.node_kinds.remove(&node_id); + + // Drain pending connections and tunes referencing this node. + self.pending_connections + .retain(|pc| pc.from_node != node_id && pc.to_node != node_id); + self.pending_tunes.retain(|pt| pt.node_id != node_id); + }, + } + } + + /// Zero-out the gauge for a specific state (one-hot pattern helper). + fn zero_state_gauge(&self, node_id: &str, state: &NodeState) { + let state_name = Self::node_state_name(state); + self.node_state_gauge.record( + 0, + &[KeyValue::new("node_id", node_id.to_owned()), KeyValue::new("state", state_name)], + ); + } + + /// Broadcast a state update to all subscribers (used when the actor itself + /// synthesizes a state transition, e.g. `Creating → Failed`). + /// + /// Reads the previous state from `node_states` **before** inserting the + /// new one, so the one-hot gauge zeroing is correct. + fn broadcast_state_update(&mut self, node_id: &str, new_state: NodeState) { + let state_name = Self::node_state_name(&new_state); + self.node_state_transitions_counter.add( + 1, + &[KeyValue::new("node_id", node_id.to_owned()), KeyValue::new("state", state_name)], + ); + + // Zero-out the previous state's gauge series (one-hot pattern), + // mirroring the logic in `handle_state_update`. + if let Some(prev_state) = self.node_states.get(node_id) { + let prev_state_name = Self::node_state_name(prev_state); + if prev_state_name != state_name { + self.node_state_gauge.record( + 0, + &[ + KeyValue::new("node_id", node_id.to_owned()), + KeyValue::new("state", prev_state_name), + ], + ); + } + } + + // Insert the new state AFTER reading the previous one. + self.node_states.insert(node_id.to_owned(), new_state.clone()); + + self.node_state_gauge.record( + 1, + &[KeyValue::new("node_id", node_id.to_owned()), KeyValue::new("state", state_name)], + ); + + let update = NodeStateUpdate::new(node_id.to_owned(), new_state); + self.state_subscribers.retain(|subscriber| match subscriber.try_send(update.clone()) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Full(_)) => { + let subscriber = subscriber.clone(); + let update = update.clone(); + tokio::spawn(async move { + let _ = subscriber.send(update).await; + }); + true + }, + Err(mpsc::error::TrySendError::Closed(_)) => false, + }); + } + + /// Execute any pending connections whose both endpoints are now realized + /// (i.e., present in `live_nodes`). + async fn flush_pending_connections(&mut self) { + // Drain the vec, keeping connections that still have unrealized endpoints. + let pending = std::mem::take(&mut self.pending_connections); + let mut still_pending = Vec::new(); + + for pc in pending { + let from_realized = self.live_nodes.contains_key(&pc.from_node); + let to_realized = self.live_nodes.contains_key(&pc.to_node); + + if from_realized && to_realized { + tracing::info!( + "Replaying deferred connection {}.{} -> {}.{}", + pc.from_node, + pc.from_pin, + pc.to_node, + pc.to_pin + ); + self.connect_nodes(pc.from_node, pc.from_pin, pc.to_node, pc.to_pin, pc.mode).await; + self.check_and_activate_pipeline(); + } else { + still_pending.push(pc); + } + } + + self.pending_connections = still_pending; + } + + /// Replay any deferred `TuneNode` messages for a node that has just been + /// initialized and is now present in `live_nodes`. + async fn flush_pending_tunes(&mut self, node_id: &str) { + let (for_node, rest): (Vec<_>, Vec<_>) = std::mem::take(&mut self.pending_tunes) + .into_iter() + .partition(|pt| pt.node_id == node_id); + + self.pending_tunes = rest; + + for pt in for_node { + if let Some(node) = self.live_nodes.get(&pt.node_id) { + tracing::info!( + node_id = %pt.node_id, + "Replaying deferred TuneNode message" + ); + if node.control_tx.send(pt.message).await.is_err() { + tracing::warn!( + "Could not replay TuneNode for '{}': node may have shut down", + pt.node_id + ); + } + } + } + } + + /// Returns `true` if the node is in `Creating` state (not yet in `live_nodes`). + fn is_node_creating(&self, node_id: &str) -> bool { + matches!(self.node_states.get(node_id), Some(NodeState::Creating)) + } + /// Handles a single control message sent to the engine. /// Returns true if the engine should continue running, false if it should shut down. #[allow(clippy::cognitive_complexity)] - async fn handle_engine_control( - &mut self, - msg: EngineControlMessage, - channels: &NodeChannels, - ) -> bool { + async fn handle_engine_control(&mut self, msg: EngineControlMessage) -> bool { match msg { EngineControlMessage::AddNode { node_id, kind, params } => { self.engine_operations_counter.add(1, &[KeyValue::new("operation", "add_node")]); - tracing::info!(name = %node_id, kind = %kind, "Adding node to graph"); - let node_result = { - let registry = match self.registry.read() { - Ok(guard) => guard, - Err(err) => { - tracing::error!(error = %err, "Registry lock poisoned while adding node"); - return true; - }, - }; - registry.create_node(&kind, params.as_ref()) - }; + tracing::info!(name = %node_id, kind = %kind, "Adding node to graph (async)"); - match node_result { - Ok(node) => { - self.node_kinds.insert(node_id.clone(), kind.clone()); - if let Err(e) = self.initialize_node(node, &node_id, &kind, channels).await - { - tracing::error!( - node_id = %node_id, - kind = %kind, - error = %e, - "Failed to initialize node" - ); - } - }, - Err(e) => tracing::error!("Failed to create node '{}': {}", node_id, e), + // Reject duplicate node IDs — the node already exists in + // node_states (either Creating or fully initialized). + if self.node_states.contains_key(&node_id) { + tracing::error!( + node_id = %node_id, + kind = %kind, + "Cannot add node: a node with this ID already exists" + ); + return true; } + + // Assign a unique creation ID so handle_node_created can + // distinguish stale results from a previous Remove → re-Add + // cycle. + let creation_id = self.next_creation_id; + self.next_creation_id += 1; + self.active_creations.insert(node_id.clone(), creation_id); + + // Record kind immediately so the actor loop continues + // processing the next message without blocking. + self.node_kinds.insert(node_id.clone(), kind.clone()); + + // Insert Creating state and broadcast to subscribers. + // broadcast_state_update handles gauge + node_states insert. + self.broadcast_state_update(&node_id, NodeState::Creating); + + // Spawn background creation: `create_node` may invoke FFI + // that blocks for 10-20+ seconds (ONNX model loading). + let registry = Arc::clone(&self.registry); + let tx = self.node_created_tx.clone(); + let spawn_node_id = node_id; + let spawn_kind = kind.clone(); + tokio::spawn(async move { + let result = tokio::task::spawn_blocking(move || { + let guard = match registry.read() { + Ok(g) => g, + Err(err) => { + return Err(StreamKitError::Runtime(format!( + "Registry lock poisoned: {err}" + ))); + }, + }; + guard.create_node(&spawn_kind, params.as_ref()) + }) + .await; + + let result = match result { + Ok(inner) => inner, + Err(join_err) => Err(StreamKitError::Runtime(format!( + "Node creation task panicked: {join_err}" + ))), + }; + + let _ = tx + .send(NodeCreatedEvent { + node_id: spawn_node_id, + kind, + creation_id, + result, + }) + .await; + }); }, EngineControlMessage::RemoveNode { node_id } => { self.engine_operations_counter.add(1, &[KeyValue::new("operation", "remove_node")]); tracing::info!(name = %node_id, "Removing node from graph"); - // Delegate shutdown to helper function - self.shutdown_node(&node_id).await; + + if self.is_node_creating(&node_id) { + // Node is still being created in the background. + // Remove the active_creations entry so that when the + // background task completes, handle_node_created finds + // no matching entry and discards the result. + tracing::info!( + node_id = %node_id, + "Node is still Creating — cancelling" + ); + self.active_creations.remove(&node_id); + // Zero the gauge before removing state (mirrors shutdown_node). + self.zero_state_gauge(&node_id, &NodeState::Creating); + self.node_states.remove(&node_id); + self.node_kinds.remove(&node_id); + // Drain pending connections and tunes referencing this node. + self.pending_connections + .retain(|pc| pc.from_node != node_id && pc.to_node != node_id); + self.pending_tunes.retain(|pt| pt.node_id != node_id); + } else { + // Normal shutdown for a fully initialized node. + self.shutdown_node(&node_id).await; + } }, EngineControlMessage::Connect { from_node, from_pin, to_node, to_pin, mode } => { self.engine_operations_counter.add(1, &[KeyValue::new("operation", "connect")]); - // Delegate connection logic - self.connect_nodes(from_node, from_pin, to_node, to_pin, mode).await; - // Check if pipeline is ready to activate after connection is established - self.check_and_activate_pipeline(); + // Both endpoints must at least exist in node_states + // (Creating or fully initialized). If either is completely + // unknown, the connection would be deferred forever. + let from_exists = self.node_states.contains_key(&from_node); + let to_exists = self.node_states.contains_key(&to_node); + if !from_exists || !to_exists { + tracing::error!( + "Cannot connect {}.{} -> {}.{}: endpoint(s) not found \ + (from_exists={}, to_exists={})", + from_node, + from_pin, + to_node, + to_pin, + from_exists, + to_exists + ); + return true; + } + + // If either endpoint is still Creating, defer the connection. + let from_creating = self.is_node_creating(&from_node); + let to_creating = self.is_node_creating(&to_node); + + if from_creating || to_creating { + tracing::info!( + "Deferring connection {}.{} -> {}.{} (from_creating={}, to_creating={})", + from_node, + from_pin, + to_node, + to_pin, + from_creating, + to_creating + ); + self.pending_connections.push(PendingConnection { + from_node, + from_pin, + to_node, + to_pin, + mode, + }); + } else { + // Both endpoints are realized — connect immediately. + self.connect_nodes(from_node, from_pin, to_node, to_pin, mode).await; + self.check_and_activate_pipeline(); + } }, EngineControlMessage::Disconnect { from_node, from_pin, to_node, to_pin } => { self.engine_operations_counter.add(1, &[KeyValue::new("operation", "disconnect")]); - // Delegate disconnection logic + + // Also remove any matching deferred connection so it isn't + // replayed later by `flush_pending_connections`. + self.pending_connections.retain(|pc| { + !(pc.from_node == from_node + && pc.from_pin == from_pin + && pc.to_node == to_node + && pc.to_pin == to_pin) + }); + + // Delegate disconnection logic for realized connections. self.disconnect_nodes(from_node, from_pin, to_node, to_pin).await; }, EngineControlMessage::TuneNode { node_id, message } => { @@ -1538,6 +1902,9 @@ impl DynamicEngine { node_id ); } + } else if self.is_node_creating(&node_id) { + tracing::info!("Deferring TuneNode for '{}': still in Creating state", node_id); + self.pending_tunes.push(PendingTune { node_id, message }); } else { tracing::warn!("Could not tune non-existent node '{}'", node_id); } @@ -1545,6 +1912,13 @@ impl DynamicEngine { EngineControlMessage::Shutdown => { tracing::info!("Received shutdown signal, stopping all nodes"); + // Step 0: Clean up nodes still in Creating state. + // Clear all active_creations so any background results + // that arrive after shutdown are discarded. + self.active_creations.clear(); + self.pending_connections.clear(); + self.pending_tunes.clear(); + // Step 1: Close all input channels so nodes blocked on recv() will exit // This ensures nodes that don't check control_rx will still shut down self.node_inputs.clear(); diff --git a/crates/engine/src/lib.rs b/crates/engine/src/lib.rs index 83fba63e..f975c421 100644 --- a/crates/engine/src/lib.rs +++ b/crates/engine/src/lib.rs @@ -176,6 +176,9 @@ impl Engine { "Starting Dynamic Engine actor" ); + // Internal channel for background node creation results. + let (nc_tx, nc_rx) = mpsc::channel(64); + let meter = global::meter("skit_engine"); let dynamic_engine = DynamicEngine { registry: Arc::clone(&self.registry), @@ -236,6 +239,12 @@ impl Engine { .u64_gauge("node.state") .with_description("Node state (1=running, 0=stopped/failed)") .build(), + node_created_tx: nc_tx, + node_created_rx: nc_rx, + pending_connections: Vec::new(), + pending_tunes: Vec::new(), + next_creation_id: 0, + active_creations: std::collections::HashMap::new(), }; let engine_task = tokio::spawn(dynamic_engine.run()); diff --git a/crates/engine/src/tests/async_node_creation.rs b/crates/engine/src/tests/async_node_creation.rs new file mode 100644 index 00000000..1e3a5411 --- /dev/null +++ b/crates/engine/src/tests/async_node_creation.rs @@ -0,0 +1,1054 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Tests for async (non-blocking) node creation in the dynamic engine. +//! +//! Validates that `AddNode` no longer blocks the actor loop: node constructors +//! run inside `spawn_blocking`, connections are deferred while endpoints are +//! `Creating`, and edge cases (cancellation, failure, shutdown) are handled. + +use super::super::*; +use std::sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; +use streamkit_core::control::EngineControlMessage; +use streamkit_core::state::NodeState; +use streamkit_core::{NodeRegistry, ProcessorNode, StreamKitError}; + +// --------------------------------------------------------------------------- +// Test node implementations +// --------------------------------------------------------------------------- + +/// A node whose constructor sleeps for a configurable duration, simulating +/// heavy FFI work (e.g., ONNX model loading). Uses `std::thread::sleep` +/// because the constructor runs inside `spawn_blocking`. +struct SlowTestNode; + +impl SlowTestNode { + fn factory( + delay: Duration, + created: Arc, + ) -> impl Fn(Option<&serde_json::Value>) -> Result, StreamKitError> + + Send + + Sync + + 'static { + move |_params| { + std::thread::sleep(delay); + created.store(true, Ordering::SeqCst); + Ok(Box::new(Self) as Box) + } + } +} + +#[streamkit_core::async_trait] +impl ProcessorNode for SlowTestNode { + fn input_pins(&self) -> Vec { + vec![streamkit_core::InputPin { + name: "in".to_string(), + accepts_types: vec![streamkit_core::types::PacketType::Any], + cardinality: streamkit_core::PinCardinality::One, + }] + } + + fn output_pins(&self) -> Vec { + vec![streamkit_core::OutputPin { + name: "out".to_string(), + produces_type: streamkit_core::types::PacketType::Binary, + cardinality: streamkit_core::PinCardinality::Broadcast, + }] + } + + async fn run( + self: Box, + mut context: streamkit_core::NodeContext, + ) -> Result<(), StreamKitError> { + loop { + match context.control_rx.recv().await { + Some(streamkit_core::control::NodeControlMessage::Shutdown) | None => return Ok(()), + Some( + streamkit_core::control::NodeControlMessage::Start + | streamkit_core::control::NodeControlMessage::UpdateParams(_), + ) => {}, + } + } + } +} + +/// A slow node that records every `UpdateParams` message it receives. +/// Used to verify that TuneNode messages sent while the node is Creating +/// are queued and replayed after initialization. +struct TuneTrackingSlowNode { + tune_count: Arc, +} + +impl TuneTrackingSlowNode { + fn factory( + delay: Duration, + created: Arc, + tune_count: Arc, + ) -> impl Fn(Option<&serde_json::Value>) -> Result, StreamKitError> + + Send + + Sync + + 'static { + move |_params| { + std::thread::sleep(delay); + created.store(true, Ordering::SeqCst); + Ok(Box::new(Self { tune_count: tune_count.clone() }) as Box) + } + } +} + +#[streamkit_core::async_trait] +impl ProcessorNode for TuneTrackingSlowNode { + fn input_pins(&self) -> Vec { + vec![streamkit_core::InputPin { + name: "in".to_string(), + accepts_types: vec![streamkit_core::types::PacketType::Any], + cardinality: streamkit_core::PinCardinality::One, + }] + } + + fn output_pins(&self) -> Vec { + vec![streamkit_core::OutputPin { + name: "out".to_string(), + produces_type: streamkit_core::types::PacketType::Binary, + cardinality: streamkit_core::PinCardinality::Broadcast, + }] + } + + async fn run( + self: Box, + mut context: streamkit_core::NodeContext, + ) -> Result<(), StreamKitError> { + loop { + match context.control_rx.recv().await { + Some(streamkit_core::control::NodeControlMessage::Shutdown) | None => return Ok(()), + Some(streamkit_core::control::NodeControlMessage::UpdateParams(_)) => { + self.tune_count.fetch_add(1, Ordering::SeqCst); + }, + Some(streamkit_core::control::NodeControlMessage::Start) => {}, + } + } + } +} + +/// A simple source node (no inputs) that stays alive until shutdown. +struct SimpleSourceNode; + +#[streamkit_core::async_trait] +impl ProcessorNode for SimpleSourceNode { + fn input_pins(&self) -> Vec { + Vec::new() + } + + fn output_pins(&self) -> Vec { + vec![streamkit_core::OutputPin { + name: "out".to_string(), + produces_type: streamkit_core::types::PacketType::Binary, + cardinality: streamkit_core::PinCardinality::Broadcast, + }] + } + + async fn run( + self: Box, + mut context: streamkit_core::NodeContext, + ) -> Result<(), StreamKitError> { + loop { + match context.control_rx.recv().await { + Some(streamkit_core::control::NodeControlMessage::Shutdown) | None => return Ok(()), + Some( + streamkit_core::control::NodeControlMessage::Start + | streamkit_core::control::NodeControlMessage::UpdateParams(_), + ) => {}, + } + } + } +} + +/// A node whose constructor always fails with an error. +struct FailingConstructorNode; + +impl FailingConstructorNode { + fn factory( + ) -> impl Fn(Option<&serde_json::Value>) -> Result, StreamKitError> + + Send + + Sync + + 'static { + |_params| { + std::thread::sleep(Duration::from_millis(100)); + Err(StreamKitError::Runtime("Model loading failed: out of memory".to_string())) + } + } +} + +/// A fast node whose constructor records the creation count (for concurrency tests). +struct FastTestNode; + +impl FastTestNode { + fn factory( + counter: Arc, + ) -> impl Fn(Option<&serde_json::Value>) -> Result, StreamKitError> + + Send + + Sync + + 'static { + move |_params| { + counter.fetch_add(1, Ordering::SeqCst); + Ok(Box::new(Self) as Box) + } + } +} + +#[streamkit_core::async_trait] +impl ProcessorNode for FastTestNode { + fn input_pins(&self) -> Vec { + vec![streamkit_core::InputPin { + name: "in".to_string(), + accepts_types: vec![streamkit_core::types::PacketType::Any], + cardinality: streamkit_core::PinCardinality::One, + }] + } + + fn output_pins(&self) -> Vec { + vec![streamkit_core::OutputPin { + name: "out".to_string(), + produces_type: streamkit_core::types::PacketType::Binary, + cardinality: streamkit_core::PinCardinality::Broadcast, + }] + } + + async fn run( + self: Box, + mut context: streamkit_core::NodeContext, + ) -> Result<(), StreamKitError> { + loop { + match context.control_rx.recv().await { + Some(streamkit_core::control::NodeControlMessage::Shutdown) | None => return Ok(()), + Some( + streamkit_core::control::NodeControlMessage::Start + | streamkit_core::control::NodeControlMessage::UpdateParams(_), + ) => {}, + } + } + } +} + +// --------------------------------------------------------------------------- +// Helper: build an engine with a pre-populated registry +// --------------------------------------------------------------------------- + +fn build_engine(registry: NodeRegistry) -> (Engine, DynamicEngineHandle) { + let engine = Engine { + registry: Arc::new(std::sync::RwLock::new(registry)), + audio_pool: Arc::new(streamkit_core::AudioFramePool::audio_default()), + video_pool: Arc::new(streamkit_core::VideoFramePool::video_default()), + }; + let handle = engine.start_dynamic_actor(DynamicEngineConfig::default()); + (engine, handle) +} + +/// Poll `handle.get_node_states()` until a predicate holds, with timeout. +async fn wait_for_states(handle: &DynamicEngineHandle, timeout_dur: Duration, pred: F) -> bool +where + F: Fn(&std::collections::HashMap) -> bool, +{ + let deadline = Instant::now() + timeout_dur; + while Instant::now() < deadline { + if let Ok(states) = handle.get_node_states().await { + if pred(&states) { + return true; + } + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + false +} + +// --------------------------------------------------------------------------- +// Test 1: Basic async creation — actor can process other AddNode messages +// while a slow node is being created. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_basic_async_creation() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_secs(1), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let fast_counter = Arc::new(AtomicU32::new(0)); + registry.register_dynamic( + "test::fast", + FastTestNode::factory(fast_counter.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add slow node first, then fast node immediately after. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "slow".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("send AddNode slow"); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "fast".to_string(), + kind: "test::fast".to_string(), + params: None, + }) + .await + .expect("send AddNode fast"); + + // The fast node should become available (past Creating) well before the + // slow node finishes its 1-second sleep. + let fast_ready = wait_for_states(&handle, Duration::from_secs(3), |states| { + states.get("fast").is_some_and(|s| !matches!(s, NodeState::Creating)) + }) + .await; + assert!(fast_ready, "fast node should leave Creating before slow node finishes"); + + // At this point the slow node should still be Creating (or just finishing). + // Wait for it to also complete. + let slow_ready = wait_for_states(&handle, Duration::from_secs(5), |states| { + states.get("slow").is_some_and(|s| !matches!(s, NodeState::Creating)) + }) + .await; + assert!(slow_ready, "slow node should eventually leave Creating"); + assert!(slow_created.load(Ordering::SeqCst), "slow constructor should have run"); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 2: Deferred connections — Connect sent before slow node finishes +// is replayed after creation completes. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_deferred_connections() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::source", + |_params| Ok(Box::new(SimpleSourceNode) as Box), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_millis(500), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add both nodes. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "src".to_string(), + kind: "test::source".to_string(), + params: None, + }) + .await + .expect("add source"); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "slow".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add slow"); + + // Connect immediately — slow node is still Creating. + handle + .send_control(EngineControlMessage::Connect { + from_node: "src".to_string(), + from_pin: "out".to_string(), + to_node: "slow".to_string(), + to_pin: "in".to_string(), + mode: streamkit_core::control::ConnectionMode::Reliable, + }) + .await + .expect("connect"); + + // After slow node finishes, both should be initialized and the deferred + // connection should have been replayed. + let both_ready = wait_for_states(&handle, Duration::from_secs(5), |states| { + let src_ok = states.get("src").is_some_and(|s| { + matches!(s, NodeState::Ready | NodeState::Running | NodeState::Initializing) + }); + let slow_ok = states.get("slow").is_some_and(|s| { + matches!(s, NodeState::Ready | NodeState::Running | NodeState::Initializing) + }); + src_ok && slow_ok + }) + .await; + assert!(both_ready, "both nodes should be initialized after deferred connection replay"); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 3: Multiple slow nodes — they should be created concurrently +// (total time ≈ max, not sum). +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_multiple_slow_nodes_concurrent() { + let created_a = Arc::new(AtomicBool::new(false)); + let created_b = Arc::new(AtomicBool::new(false)); + let created_c = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow_a", + SlowTestNode::factory(Duration::from_millis(500), created_a.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::slow_b", + SlowTestNode::factory(Duration::from_millis(500), created_b.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::slow_c", + SlowTestNode::factory(Duration::from_millis(500), created_c.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + let start = Instant::now(); + + for (id, kind) in [("a", "test::slow_a"), ("b", "test::slow_b"), ("c", "test::slow_c")] { + handle + .send_control(EngineControlMessage::AddNode { + node_id: id.to_string(), + kind: kind.to_string(), + params: None, + }) + .await + .expect("add node"); + } + + // Wait for all three to leave Creating. + let all_done = wait_for_states(&handle, Duration::from_secs(5), |states| { + ["a", "b", "c"] + .iter() + .all(|id| states.get(*id).is_some_and(|s| !matches!(s, NodeState::Creating))) + }) + .await; + assert!(all_done, "all three slow nodes should finish creation"); + + let elapsed = start.elapsed(); + // If sequential, ~1.5s; if concurrent, ~0.5s + overhead. + assert!( + elapsed < Duration::from_millis(1200), + "3 x 500ms nodes should complete in ~500ms (concurrent), but took {elapsed:?}", + ); + + assert!(created_a.load(Ordering::SeqCst)); + assert!(created_b.load(Ordering::SeqCst)); + assert!(created_c.load(Ordering::SeqCst)); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 4: Creation failure — node transitions to Failed, pending connections +// referencing it are drained. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_creation_failure() { + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::source", + |_params| Ok(Box::new(SimpleSourceNode) as Box), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::failing", + FailingConstructorNode::factory(), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "src".to_string(), + kind: "test::source".to_string(), + params: None, + }) + .await + .expect("add source"); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "bad".to_string(), + kind: "test::failing".to_string(), + params: None, + }) + .await + .expect("add failing"); + + // Queue a connection to the failing node. + handle + .send_control(EngineControlMessage::Connect { + from_node: "src".to_string(), + from_pin: "out".to_string(), + to_node: "bad".to_string(), + to_pin: "in".to_string(), + mode: streamkit_core::control::ConnectionMode::Reliable, + }) + .await + .expect("connect"); + + // The failing node should transition to Failed. + let failed = wait_for_states(&handle, Duration::from_secs(3), |states| { + matches!(states.get("bad"), Some(NodeState::Failed { .. })) + }) + .await; + assert!(failed, "failing node should transition to Failed"); + + // Source node should still be fine. + let states = handle.get_node_states().await.expect("get states"); + assert!( + states.get("src").is_some_and(|s| !matches!(s, NodeState::Failed { .. })), + "source node should be unaffected" + ); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 5: RemoveNode while Creating — background result is discarded. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_remove_node_while_creating() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_secs(1), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add slow node. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "doomed".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add doomed"); + + // Give the actor a moment to process AddNode and set Creating state. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Remove it while still Creating. + handle + .send_control(EngineControlMessage::RemoveNode { node_id: "doomed".to_string() }) + .await + .expect("remove doomed"); + + // Wait for the background creation to complete (1s), then verify + // the node was NOT added to the engine. + tokio::time::sleep(Duration::from_millis(1500)).await; + + let states = handle.get_node_states().await.expect("get states"); + assert!( + !states.contains_key("doomed"), + "removed-while-Creating node should not appear in states" + ); + + // The constructor did run (it was already spawned), but the result + // should have been discarded. + assert!( + slow_created.load(Ordering::SeqCst), + "constructor runs to completion even if cancelled" + ); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 6: Pipeline activation timing — source nodes do NOT activate until +// all nodes leave Creating state. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_pipeline_activation_timing() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::source", + |_params| Ok(Box::new(SimpleSourceNode) as Box), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_millis(800), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Subscribe to state updates to observe activation. + let mut state_rx = handle.subscribe_state().await.expect("subscribe"); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "src".to_string(), + kind: "test::source".to_string(), + params: None, + }) + .await + .expect("add source"); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "proc".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add slow processor"); + + handle + .send_control(EngineControlMessage::Connect { + from_node: "src".to_string(), + from_pin: "out".to_string(), + to_node: "proc".to_string(), + to_pin: "in".to_string(), + mode: streamkit_core::control::ConnectionMode::Reliable, + }) + .await + .expect("connect"); + + // Drain state updates; verify source doesn't go to Running before slow node + // leaves Creating. + let mut slow_left_creating = false; + let mut src_ran_before_slow_ready = false; + + let drain_deadline = Instant::now() + Duration::from_secs(5); + while Instant::now() < drain_deadline { + match tokio::time::timeout(Duration::from_millis(100), state_rx.recv()).await { + Ok(Some(update)) => { + if update.node_id == "proc" && !matches!(update.state, NodeState::Creating) { + slow_left_creating = true; + } + if update.node_id == "src" && matches!(update.state, NodeState::Running) { + if !slow_left_creating { + src_ran_before_slow_ready = true; + } + break; + } + }, + _ => { + if slow_left_creating { + break; + } + }, + } + } + + assert!( + !src_ran_before_slow_ready, + "source should NOT start Running before slow node leaves Creating" + ); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 7: Duplicate AddNode — second call is rejected. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_duplicate_add_node() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_millis(500), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add node. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "dup".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add first"); + + // Give actor time to process. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Try adding the same node_id again. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "dup".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add duplicate"); + + // Wait for the original to finish. + let done = wait_for_states(&handle, Duration::from_secs(3), |states| { + states.get("dup").is_some_and(|s| !matches!(s, NodeState::Creating)) + }) + .await; + assert!(done, "original node should finish creating"); + + // The engine should still be responsive (no double-init crash). + let states = handle.get_node_states().await.expect("get states"); + assert!(states.contains_key("dup"), "node should exist"); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 8: RemoveNode then re-AddNode with same ID — new node is created +// correctly and old creation result is discarded. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_remove_then_readd_same_id() { + let created_v1 = Arc::new(AtomicBool::new(false)); + let created_v2 = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow_v1", + SlowTestNode::factory(Duration::from_secs(1), created_v1.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::fast_v2", + SlowTestNode::factory(Duration::from_millis(50), created_v2.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add slow v1. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "node".to_string(), + kind: "test::slow_v1".to_string(), + params: None, + }) + .await + .expect("add v1"); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // Remove while Creating. + handle + .send_control(EngineControlMessage::RemoveNode { node_id: "node".to_string() }) + .await + .expect("remove"); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // Re-add with a different (fast) kind. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "node".to_string(), + kind: "test::fast_v2".to_string(), + params: None, + }) + .await + .expect("add v2"); + + // Wait for v2 to finish. + let v2_done = wait_for_states(&handle, Duration::from_secs(3), |states| { + states.get("node").is_some_and(|s| !matches!(s, NodeState::Creating)) + }) + .await; + assert!(v2_done, "v2 node should finish creating"); + + // v2 should have been created. + assert!(created_v2.load(Ordering::SeqCst), "v2 constructor should have run"); + + // Wait for v1 background task to also complete (it was already spawned). + tokio::time::sleep(Duration::from_millis(1200)).await; + assert!(created_v1.load(Ordering::SeqCst), "v1 constructor runs to completion"); + + // The node should be the v2 version — verify it's not in Creating or + // Failed state (it should be fully initialized). + let states = handle.get_node_states().await.expect("get states"); + assert!(states.contains_key("node"), "node should exist"); + let state = states.get("node").expect("node state"); + assert!( + !matches!(state, NodeState::Creating | NodeState::Failed { .. }), + "v2 node should be initialized, not Creating/Failed, got: {state:?}" + ); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 9: Shutdown while Creating — clean shutdown, no panics. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_shutdown_while_creating() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_secs(2), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + handle + .send_control(EngineControlMessage::AddNode { + node_id: "slow".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add slow"); + + // Give actor time to set Creating state. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Shutdown while slow node is still Creating. + let result = handle.shutdown_and_wait().await; + assert!(result.is_ok(), "shutdown should complete cleanly: {result:?}"); +} + +// --------------------------------------------------------------------------- +// Test 10: Connect with one realized, one creating — connection is deferred +// and replayed correctly. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_connect_one_realized_one_creating() { + let slow_created = Arc::new(AtomicBool::new(false)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::source", + |_params| Ok(Box::new(SimpleSourceNode) as Box), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + registry.register_dynamic( + "test::slow", + SlowTestNode::factory(Duration::from_millis(500), slow_created.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add source (fast) — it will be realized quickly. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "source".to_string(), + kind: "test::source".to_string(), + params: None, + }) + .await + .expect("add source"); + + // Wait for source to leave Creating. + let source_ready = wait_for_states(&handle, Duration::from_secs(2), |states| { + states.get("source").is_some_and(|s| !matches!(s, NodeState::Creating)) + }) + .await; + assert!(source_ready, "source should be realized quickly"); + + // Now add slow node. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "slow".to_string(), + kind: "test::slow".to_string(), + params: None, + }) + .await + .expect("add slow"); + + // Connect while source is realized but slow is still Creating. + handle + .send_control(EngineControlMessage::Connect { + from_node: "source".to_string(), + from_pin: "out".to_string(), + to_node: "slow".to_string(), + to_pin: "in".to_string(), + mode: streamkit_core::control::ConnectionMode::Reliable, + }) + .await + .expect("connect"); + + // Wait for slow node to finish and verify both are initialized. + let both_done = wait_for_states(&handle, Duration::from_secs(5), |states| { + let source_ok = states.get("source").is_some_and(|s| { + matches!(s, NodeState::Ready | NodeState::Running | NodeState::Initializing) + }); + let slow_ok = states.get("slow").is_some_and(|s| { + matches!(s, NodeState::Ready | NodeState::Running | NodeState::Initializing) + }); + source_ok && slow_ok + }) + .await; + assert!(both_done, "both nodes should be initialized after deferred connection is replayed"); + + handle.shutdown_and_wait().await.expect("shutdown"); +} + +// --------------------------------------------------------------------------- +// Test 11: TuneNode messages sent while a node is Creating are queued and +// replayed after initialization completes. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[allow(clippy::expect_used)] +async fn test_tune_node_queued_while_creating() { + let created = Arc::new(AtomicBool::new(false)); + let tune_count = Arc::new(AtomicU32::new(0)); + + let mut registry = NodeRegistry::new(); + registry.register_dynamic( + "test::tune_tracking_slow", + TuneTrackingSlowNode::factory(Duration::from_secs(1), created.clone(), tune_count.clone()), + serde_json::json!({}), + vec!["test".to_string()], + false, + ); + + let (_engine, handle) = build_engine(registry); + + // Add the slow node. + handle + .send_control(EngineControlMessage::AddNode { + node_id: "tracked".to_string(), + kind: "test::tune_tracking_slow".to_string(), + params: None, + }) + .await + .expect("add tracked"); + + // Verify it's still Creating (constructor sleeps 1s). + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(!created.load(Ordering::SeqCst), "node should still be creating"); + + // Send two TuneNode messages while the node is Creating. + handle + .send_control(EngineControlMessage::TuneNode { + node_id: "tracked".to_string(), + message: streamkit_core::control::NodeControlMessage::UpdateParams( + serde_json::json!({"gain": 0.5}), + ), + }) + .await + .expect("tune 1"); + + handle + .send_control(EngineControlMessage::TuneNode { + node_id: "tracked".to_string(), + message: streamkit_core::control::NodeControlMessage::UpdateParams( + serde_json::json!({"gain": 0.8}), + ), + }) + .await + .expect("tune 2"); + + // Wait for the node to finish creation and initialization. + let initialized = wait_for_states(&handle, Duration::from_secs(5), |states| { + states.get("tracked").is_some_and(|s| { + matches!(s, NodeState::Ready | NodeState::Running | NodeState::Initializing) + }) + }) + .await; + assert!(initialized, "node should be initialized"); + + // Give a moment for the queued tunes to be replayed and processed. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Both UpdateParams messages should have been delivered. + assert_eq!( + tune_count.load(Ordering::SeqCst), + 2, + "node should have received both queued TuneNode messages" + ); + + handle.shutdown_and_wait().await.expect("shutdown"); +} diff --git a/crates/engine/src/tests/connection_types.rs b/crates/engine/src/tests/connection_types.rs index 45f62e94..ddb962bf 100644 --- a/crates/engine/src/tests/connection_types.rs +++ b/crates/engine/src/tests/connection_types.rs @@ -21,6 +21,8 @@ fn create_test_engine() -> DynamicEngine { drop(control_tx); drop(query_tx); + let (node_created_tx, node_created_rx) = mpsc::channel(32); + let meter = opentelemetry::global::meter("test"); DynamicEngine { registry: std::sync::Arc::new(std::sync::RwLock::new(NodeRegistry::new())), @@ -57,6 +59,12 @@ fn create_test_engine() -> DynamicEngine { node_state_gauge: meter.u64_gauge("test.state").build(), runtime_schemas: HashMap::new(), runtime_schema_subscribers: Vec::new(), + node_created_tx, + node_created_rx, + pending_connections: Vec::new(), + pending_tunes: Vec::new(), + next_creation_id: 0, + active_creations: std::collections::HashMap::new(), } } diff --git a/crates/engine/src/tests/mod.rs b/crates/engine/src/tests/mod.rs index 0a9b9bf5..c35ece7f 100644 --- a/crates/engine/src/tests/mod.rs +++ b/crates/engine/src/tests/mod.rs @@ -4,6 +4,8 @@ //! Unit tests for the engine crate. +#[cfg(feature = "dynamic")] +mod async_node_creation; #[cfg(feature = "dynamic")] mod connection_types; #[cfg(feature = "dynamic")] diff --git a/crates/engine/src/tests/pipeline_activation.rs b/crates/engine/src/tests/pipeline_activation.rs index 7792d7e1..c7e78e3a 100644 --- a/crates/engine/src/tests/pipeline_activation.rs +++ b/crates/engine/src/tests/pipeline_activation.rs @@ -22,6 +22,8 @@ fn create_test_engine() -> DynamicEngine { drop(control_tx); drop(query_tx); + let (nc_tx, nc_rx) = mpsc::channel(32); + let meter = opentelemetry::global::meter("test"); DynamicEngine { registry: std::sync::Arc::new(std::sync::RwLock::new(NodeRegistry::new())), @@ -58,6 +60,12 @@ fn create_test_engine() -> DynamicEngine { node_state_gauge: meter.u64_gauge("test.state").build(), runtime_schemas: HashMap::new(), runtime_schema_subscribers: Vec::new(), + node_created_tx: nc_tx, + node_created_rx: nc_rx, + pending_connections: Vec::new(), + pending_tunes: Vec::new(), + next_creation_id: 0, + active_creations: std::collections::HashMap::new(), } } @@ -270,6 +278,19 @@ async fn test_activation_blocked_by_failed_node() { ); } +/// Source node should NOT receive Start while any node is still Creating. +#[tokio::test] +async fn test_activation_blocked_by_creating_node() { + let mut engine = create_test_engine(); + let mut source_rx = add_source_node(&mut engine, "source", NodeState::Ready); + add_processor_node(&mut engine, "slow_node", NodeState::Creating); + + engine.check_and_activate_pipeline(); + + let msg = source_rx.try_recv(); + assert!(msg.is_err(), "source node should NOT receive Start while a node is still Creating"); +} + /// Source node should NOT receive Start when a downstream node has Stopped. #[tokio::test] async fn test_activation_blocked_by_stopped_node() { diff --git a/ui/src/components/NodeStateIndicator.tsx b/ui/src/components/NodeStateIndicator.tsx index 97d6c80f..725b2d1a 100644 --- a/ui/src/components/NodeStateIndicator.tsx +++ b/ui/src/components/NodeStateIndicator.tsx @@ -152,6 +152,7 @@ function renderSlowPinsSummary( function getStateColor(state: NodeState): string { if (typeof state === 'string') { switch (state) { + case 'Creating': case 'Initializing': return 'var(--sk-status-initializing)'; case 'Running': @@ -201,6 +202,8 @@ function getStateLabel(state: NodeState): string { function getStateDescription(state: NodeState): string { if (typeof state === 'string') { switch (state) { + case 'Creating': + return 'Node is being created (loading resources)'; case 'Initializing': return 'Node is starting up and performing initialization'; case 'Running': diff --git a/ui/src/types/generated/api-types.ts b/ui/src/types/generated/api-types.ts index dba2147f..61aed408 100644 --- a/ui/src/types/generated/api-types.ts +++ b/ui/src/types/generated/api-types.ts @@ -101,7 +101,7 @@ bidirectional: boolean, }; export type StopReason = "completed" | "input_closed" | "output_closed" | "shutdown" | "no_inputs" | "unknown"; -export type NodeState = "Initializing" | "Ready" | "Running" | { "Recovering": { reason: string, details: JsonValue, } } | { "Degraded": { reason: string, details: JsonValue, } } | { "Failed": { reason: string, } } | { "Stopped": { reason: StopReason, } }; +export type NodeState = "Creating" | "Initializing" | "Ready" | "Running" | { "Recovering": { reason: string, details: JsonValue, } } | { "Degraded": { reason: string, details: JsonValue, } } | { "Failed": { reason: string, } } | { "Stopped": { reason: StopReason, } }; export type NodeStats = { /** diff --git a/ui/src/utils/sessionStatus.ts b/ui/src/utils/sessionStatus.ts index cd5c41d4..3548edaf 100644 --- a/ui/src/utils/sessionStatus.ts +++ b/ui/src/utils/sessionStatus.ts @@ -54,8 +54,8 @@ export function computeSessionStatus(nodeStates: Record): Ses return 'recovering'; } - // Check for initializing - if (states.some((state) => state === 'Initializing')) { + // Check for creating or initializing + if (states.some((state) => state === 'Creating' || state === 'Initializing')) { return 'initializing'; }