diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index ee7af00c..6c1b5f26 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -13,7 +13,8 @@ use crate::session::Session; use crate::state::{AppState, BroadcastEvent}; use opentelemetry::global; use streamkit_api::{ - Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, + Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, ValidationError, + ValidationErrorType, }; use streamkit_core::control::{EngineControlMessage, NodeControlMessage}; use streamkit_core::registry::NodeDefinition; @@ -34,6 +35,76 @@ fn can_access_session(session: &Session, role_name: &str, perms: &Permissions) - session.created_by.as_ref().is_none_or(|creator| creator == role_name) } +/// Validate a single AddNode operation against permission and security rules. +/// +/// Returns `Some(error_message)` if the operation is not allowed, `None` if it passes. +/// This is the single source of truth for AddNode validation, used by `handle_add_node`, +/// `handle_validate_batch`, and `handle_apply_batch`. +fn validate_add_node_op( + kind: &str, + params: Option<&serde_json::Value>, + perms: &Permissions, + security_config: &crate::config::SecurityConfig, +) -> Option { + // Reject oneshot-only marker nodes on the dynamic control plane. + if kind == "streamkit::http_input" || kind == "streamkit::http_output" { + return Some(format!( + "Node type '{kind}' is oneshot-only and cannot be used in dynamic sessions" + )); + } + + // Check if the node type is allowed. + if !perms.is_node_allowed(kind) { + return Some(format!("Permission denied: node type '{kind}' not allowed")); + } + + // If this is a plugin node, enforce the plugin allowlist too. + if kind.starts_with("plugin::") && !perms.is_plugin_allowed(kind) { + return Some(format!("Permission denied: plugin '{kind}' not allowed")); + } + + // Security: validate file_reader paths. + if kind == "core::file_reader" { + let Some(path) = params.and_then(|p| p.get("path")).and_then(serde_json::Value::as_str) + else { + return Some( + "Invalid file_reader params: expected params.path to be a string".to_string(), + ); + }; + if let Err(e) = file_security::validate_file_path(path, security_config) { + return Some(format!("Invalid file path: {e}")); + } + } + + // Security: validate file_writer paths. + if kind == "core::file_writer" { + let Some(path) = params.and_then(|p| p.get("path")).and_then(serde_json::Value::as_str) + else { + return Some( + "Invalid file_writer params: expected params.path to be a string".to_string(), + ); + }; + if let Err(e) = file_security::validate_write_path(path, security_config) { + return Some(format!("Invalid write path: {e}")); + } + } + + // Security: validate script_path (if present) for core::script nodes. + if kind == "core::script" { + if let Some(path) = + params.and_then(|p| p.get("script_path")).and_then(serde_json::Value::as_str) + { + if !path.trim().is_empty() { + if let Err(e) = file_security::validate_file_path(path, security_config) { + return Some(format!("Invalid script_path: {e}")); + } + } + } + } + + None +} + pub async fn handle_request_payload( payload: RequestPayload, app_state: &AppState, @@ -77,8 +148,8 @@ pub async fn handle_request_payload( RequestPayload::GetPipeline { session_id } => { handle_get_pipeline(session_id, app_state, perms, role_name).await }, - RequestPayload::ValidateBatch { session_id: _, operations } => { - Some(handle_validate_batch(&operations, app_state, perms)) + RequestPayload::ValidateBatch { session_id, operations } => { + Some(handle_validate_batch(session_id, &operations, app_state, perms, role_name).await) }, RequestPayload::ApplyBatch { session_id, operations } => { handle_apply_batch(session_id, operations, app_state, perms, role_name).await @@ -429,73 +500,10 @@ async fn handle_add_node( }); } - // Reject oneshot-only marker nodes on the dynamic control plane. - if kind == "streamkit::http_input" || kind == "streamkit::http_output" { - return Some(ResponsePayload::Error { - message: format!( - "Node type '{kind}' is oneshot-only and cannot be used in dynamic sessions" - ), - }); - } - - // Check if the node type is allowed - if !perms.is_node_allowed(&kind) { - return Some(ResponsePayload::Error { - message: format!("Permission denied: node type '{kind}' not allowed"), - }); - } - - // If this is a plugin node, enforce the plugin allowlist too. - if kind.starts_with("plugin::") && !perms.is_plugin_allowed(&kind) { - return Some(ResponsePayload::Error { - message: format!("Permission denied: plugin '{kind}' not allowed"), - }); - } - - // Security: validate file_reader paths on the control plane too (not just oneshot/HTTP). - if kind == "core::file_reader" { - let Some(path) = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str) - else { - return Some(ResponsePayload::Error { - message: "Invalid file_reader params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) { - return Some(ResponsePayload::Error { message: format!("Invalid file path: {e}") }); - } - } - - // Security: validate file_writer paths on the control plane too (avoid arbitrary file writes). - if kind == "core::file_writer" { - let Some(path) = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str) - else { - return Some(ResponsePayload::Error { - message: "Invalid file_writer params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_write_path(path, &app_state.config.security) { - return Some(ResponsePayload::Error { message: format!("Invalid write path: {e}") }); - } - } - - // Security: validate script_path (if present) for core::script nodes. - if kind == "core::script" { - if let Some(path) = - params.as_ref().and_then(|p| p.get("script_path")).and_then(serde_json::Value::as_str) - { - if !path.trim().is_empty() { - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid script_path: {e}"), - }); - } - } - } + if let Some(message) = + validate_add_node_op(&kind, params.as_ref(), perms, &app_state.config.security) + { + return Some(ResponsePayload::Error { message }); } // Get session with SHORT lock hold to avoid blocking other operations @@ -1120,10 +1128,12 @@ async fn handle_get_pipeline( Some(ResponsePayload::Pipeline { pipeline: Box::new(api_pipeline) }) } -fn handle_validate_batch( +async fn handle_validate_batch( + session_id: String, operations: &[streamkit_api::BatchOperation], app_state: &AppState, perms: &Permissions, + role_name: &str, ) -> ResponsePayload { // Validate that user has permission for modify_sessions if !perms.modify_sessions { @@ -1132,67 +1142,73 @@ fn handle_validate_batch( }; } - // Basic validation: check that all referenced node types are allowed - for op in operations { - if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { - if !perms.is_node_allowed(kind) { - return ResponsePayload::Error { - message: format!("Permission denied: node type '{kind}' not allowed"), - }; - } + // Verify session exists + let session = { + let session_manager = app_state.session_manager.lock().await; + session_manager.get_session_by_name_or_id(&session_id) + }; - if kind == "core::file_reader" { - let path = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str); - let Some(path) = path else { - return ResponsePayload::Error { - message: "Invalid file_reader params: expected params.path to be a string" - .to_string(), - }; - }; - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) - { - return ResponsePayload::Error { message: format!("Invalid file path: {e}") }; - } - } + let Some(session) = session else { + return ResponsePayload::Error { message: format!("Session '{session_id}' not found") }; + }; - if kind == "core::file_writer" { - let path = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str); - let Some(path) = path else { - return ResponsePayload::Error { - message: "Invalid file_writer params: expected params.path to be a string" - .to_string(), - }; - }; - if let Err(e) = file_security::validate_write_path(path, &app_state.config.security) - { - return ResponsePayload::Error { message: format!("Invalid write path: {e}") }; - } - } + // Check ownership + if !can_access_session(&session, role_name, perms) { + return ResponsePayload::Error { + message: "Permission denied: you do not own this session".to_string(), + }; + } - if kind == "core::script" { - if let Some(path) = params - .as_ref() - .and_then(|p| p.get("script_path")) - .and_then(serde_json::Value::as_str) - { - if !path.trim().is_empty() { - if let Err(e) = - file_security::validate_file_path(path, &app_state.config.security) - { - return ResponsePayload::Error { - message: format!("Invalid script_path: {e}"), - }; - } - } + // Collect all validation errors so the caller sees every problem at once. + let mut errors: Vec = Vec::new(); + + // Pre-validate duplicate node_ids against the pipeline model, mirroring + // the same simulation that handle_apply_batch performs. + let mut live_ids: std::collections::HashSet = + session.pipeline.lock().await.nodes.keys().cloned().collect(); + for op in operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, .. } => { + if !live_ids.insert(node_id.clone()) { + errors.push(ValidationError { + error_type: ValidationErrorType::Error, + message: format!( + "Batch rejected: node '{node_id}' already exists in the pipeline" + ), + node_id: Some(node_id.clone()), + connection_id: None, + }); } + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + live_ids.remove(node_id.as_str()); + }, + _ => {}, + } + } + + // Validate all AddNode operations against permission and security rules. + for op in operations { + if let streamkit_api::BatchOperation::AddNode { node_id, kind, params, .. } = op { + if let Some(message) = + validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) + { + errors.push(ValidationError { + error_type: ValidationErrorType::Error, + message, + node_id: Some(node_id.clone()), + connection_id: None, + }); } } } - info!(operation_count = operations.len(), "Validated batch operations"); - ResponsePayload::ValidationResult { errors: Vec::new() } + info!( + operation_count = operations.len(), + error_count = errors.len(), + "Validated batch operations" + ); + ResponsePayload::ValidationResult { errors } } #[allow(clippy::significant_drop_tightening)] @@ -1256,65 +1272,13 @@ async fn handle_apply_batch( } } // Pipeline lock released after pre-validation - // Validate permissions for all operations + // Validate permissions for all operations. for op in &operations { if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { - if !perms.is_node_allowed(kind) { - return Some(ResponsePayload::Error { - message: format!("Permission denied: node type '{kind}' not allowed"), - }); - } - - if kind == "core::file_reader" { - let path = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str); - let Some(path) = path else { - return Some(ResponsePayload::Error { - message: "Invalid file_reader params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid file path: {e}"), - }); - } - } - - if kind == "core::file_writer" { - let path = - params.as_ref().and_then(|p| p.get("path")).and_then(serde_json::Value::as_str); - let Some(path) = path else { - return Some(ResponsePayload::Error { - message: "Invalid file_writer params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_write_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid write path: {e}"), - }); - } - } - - if kind == "core::script" { - if let Some(path) = params - .as_ref() - .and_then(|p| p.get("script_path")) - .and_then(serde_json::Value::as_str) - { - if !path.trim().is_empty() { - if let Err(e) = - file_security::validate_file_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid script_path: {e}"), - }); - } - } - } + if let Some(message) = + validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) + { + return Some(ResponsePayload::Error { message }); } } } diff --git a/apps/skit/tests/batch_validation_test.rs b/apps/skit/tests/batch_validation_test.rs new file mode 100644 index 00000000..fcca337c --- /dev/null +++ b/apps/skit/tests/batch_validation_test.rs @@ -0,0 +1,855 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::disallowed_macros, + clippy::uninlined_format_args +)] + +use futures_util::{SinkExt, StreamExt}; +use serde_json::json; +use std::collections::HashMap; +use std::net::SocketAddr; +use streamkit_api::{ + BatchOperation, MessageType, Request, RequestPayload, Response, ResponsePayload, +}; +use streamkit_server::Config; +use tokio::net::TcpListener; +use tokio::time::{timeout, Duration}; +use tokio_tungstenite::{ + connect_async, + tungstenite::{client::IntoClientRequest, Message as WsMessage}, +}; + +// Type aliases to reduce verbosity of the fully-expanded WebSocket stream types. +type WsStream = + tokio_tungstenite::WebSocketStream>; +type WsWriter = futures_util::stream::SplitSink; +type WsReader = futures_util::stream::SplitStream; + +/// Helper to read messages from WebSocket, skipping events until we get a response with matching correlation_id +async fn read_response(read: &mut WsReader, expected_correlation_id: &str) -> Response { + loop { + let message = timeout(Duration::from_secs(5), read.next()) + .await + .expect("Timeout waiting for response") + .expect("No message received") + .expect("Failed to read message"); + + let text = message.to_text().expect("Expected text message"); + + let value: serde_json::Value = serde_json::from_str(text).expect("Failed to parse message"); + let msg_type = value.get("type").and_then(|v| v.as_str()); + + if msg_type == Some("event") { + continue; + } + + let response: Response = serde_json::from_str(text).expect("Failed to parse response"); + + if response.correlation_id.as_deref() == Some(expected_correlation_id) { + return response; + } + } +} + +async fn start_test_server() -> Option<(SocketAddr, tokio::task::JoinHandle<()>)> { + start_test_server_with_config(Config::default()).await +} + +async fn start_test_server_with_config( + config: Config, +) -> Option<(SocketAddr, tokio::task::JoinHandle<()>)> { + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => return None, + Err(e) => panic!("Failed to bind test server listener: {e}"), + }; + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (app, _state) = streamkit_server::server::create_app(config, None); + axum::serve(listener, app.into_make_service()).await.unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + Some((addr, server_handle)) +} + +/// Helper: connect to WS, create a session, and return (write, read, session_id). +async fn setup_session(addr: SocketAddr) -> (WsWriter, WsReader, String) { + let ws_url = format!("ws://{}/api/v1/control", addr); + let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect to WebSocket"); + let (mut write, mut read) = ws_stream.split(); + + let create_request = Request { + message_type: MessageType::Request, + correlation_id: Some("setup-create".to_string()), + payload: RequestPayload::CreateSession { name: Some("batch-test".to_string()) }, + }; + + write + .send(WsMessage::Text(serde_json::to_string(&create_request).unwrap().into())) + .await + .unwrap(); + + let response = read_response(&mut read, "setup-create").await; + let session_id = match response.payload { + ResponsePayload::SessionCreated { session_id, .. } => session_id, + other => panic!("Expected SessionCreated, got: {:?}", other), + }; + + (write, read, session_id) +} + +/// Helper: send a ValidateBatch request and return the response payload. +async fn send_validate_batch( + write: &mut WsWriter, + read: &mut WsReader, + session_id: &str, + operations: Vec, + correlation_id: &str, +) -> ResponsePayload { + let request = Request { + message_type: MessageType::Request, + correlation_id: Some(correlation_id.to_string()), + payload: RequestPayload::ValidateBatch { session_id: session_id.to_string(), operations }, + }; + + write.send(WsMessage::Text(serde_json::to_string(&request).unwrap().into())).await.unwrap(); + + read_response(read, correlation_id).await.payload +} + +/// Helper: send an ApplyBatch request and return the response payload. +async fn send_apply_batch( + write: &mut WsWriter, + read: &mut WsReader, + session_id: &str, + operations: Vec, + correlation_id: &str, +) -> ResponsePayload { + let request = Request { + message_type: MessageType::Request, + correlation_id: Some(correlation_id.to_string()), + payload: RequestPayload::ApplyBatch { session_id: session_id.to_string(), operations }, + }; + + write.send(WsMessage::Text(serde_json::to_string(&request).unwrap().into())).await.unwrap(); + + read_response(read, correlation_id).await.payload +} + +/// Build a Config whose default role has an empty plugin allowlist, so +/// `plugin::*` nodes are rejected. +fn config_with_no_plugins_allowed() -> Config { + use streamkit_server::{Permissions, PermissionsConfig}; + + let mut restricted = Permissions::admin(); + restricted.allowed_plugins = Vec::new(); // deny all plugins + + let mut roles = HashMap::new(); + roles.insert("admin".to_string(), restricted); + + Config { + permissions: PermissionsConfig { roles, ..PermissionsConfig::default() }, + ..Config::default() + } +} + +/// Build a Config with a trusted role header so tests can select roles per connection. +fn config_with_role_header() -> Config { + use streamkit_server::PermissionsConfig; + + Config { + permissions: PermissionsConfig { + role_header: Some("x-role".to_string()), + ..PermissionsConfig::default() + }, + ..Config::default() + } +} + +/// Connect to the WS control endpoint with a custom role header. +async fn connect_with_role(addr: SocketAddr, role: &str) -> (WsWriter, WsReader) { + let mut request = format!("ws://{addr}/api/v1/control") + .into_client_request() + .expect("Failed to build WS request"); + request.headers_mut().insert("x-role", role.parse().unwrap()); + let (ws_stream, _) = connect_async(request).await.expect("Failed to connect to WebSocket"); + ws_stream.split() +} + +// --------------------------------------------------------------------------- +// ValidateBatch tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_validate_batch_rejects_http_input_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "http_in".to_string(), + kind: "streamkit::http_input".to_string(), + params: None, + }], + "validate-http-input", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert_eq!(errors.len(), 1, "Expected exactly one validation error"); + assert!( + errors[0].message.contains("oneshot-only"), + "Expected oneshot-only error, got: {}", + errors[0].message + ); + assert_eq!(errors[0].node_id.as_deref(), Some("http_in")); + }, + other => panic!("Expected ValidationResult for http_input, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_http_output_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "http_out".to_string(), + kind: "streamkit::http_output".to_string(), + params: None, + }], + "validate-http-output", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert_eq!(errors.len(), 1, "Expected exactly one validation error"); + assert!( + errors[0].message.contains("oneshot-only"), + "Expected oneshot-only error, got: {}", + errors[0].message + ); + assert_eq!(errors[0].node_id.as_deref(), Some("http_out")); + }, + other => panic!("Expected ValidationResult for http_output, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_disallowed_plugin() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = + start_test_server_with_config(config_with_no_plugins_allowed()).await + else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "p1".to_string(), + kind: "plugin::native::whisper".to_string(), + params: None, + }], + "validate-disallowed-plugin", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert_eq!(errors.len(), 1, "Expected exactly one validation error"); + assert!( + errors[0].message.contains("plugin") && errors[0].message.contains("not allowed"), + "Expected plugin not-allowed error, got: {}", + errors[0].message + ); + assert_eq!(errors[0].node_id.as_deref(), Some("p1")); + }, + other => panic!("Expected ValidationResult for disallowed plugin, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_allows_valid_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 2.0})), + }], + "validate-valid-node", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert!(errors.is_empty(), "Expected no validation errors for valid node"); + }, + other => panic!("Expected ValidationResult for valid node, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_nonexistent_session() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let ws_url = format!("ws://{}/api/v1/control", addr); + let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect to WebSocket"); + let (mut write, mut read) = ws_stream.split(); + + let payload = send_validate_batch( + &mut write, + &mut read, + "nonexistent-session-id", + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }], + "validate-no-session", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("not found"), + "Expected session not-found error, got: {message}" + ); + }, + other => panic!("Expected Error for nonexistent session, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_mixed_with_oneshot_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![ + BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 1.0})), + }, + BatchOperation::AddNode { + node_id: "http_in".to_string(), + kind: "streamkit::http_input".to_string(), + params: None, + }, + ], + "validate-mixed", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert_eq!( + errors.len(), + 1, + "Expected exactly one validation error for the oneshot node" + ); + assert!( + errors[0].message.contains("oneshot-only"), + "Expected oneshot-only error in mixed batch, got: {}", + errors[0].message + ); + assert_eq!(errors[0].node_id.as_deref(), Some("http_in")); + }, + other => { + panic!("Expected ValidationResult for mixed batch with oneshot node, got: {:?}", other) + }, + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_duplicate_node_id() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + // Two AddNode ops with the same node_id should trigger a duplicate error. + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![ + BatchOperation::AddNode { + node_id: "dup1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }, + BatchOperation::AddNode { + node_id: "dup1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }, + ], + "validate-dup-node-id", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert!(!errors.is_empty(), "Expected at least one error for duplicate node_id"); + assert!( + errors.iter().any(|e| e.message.contains("already exists")), + "Expected duplicate node_id error, got: {:?}", + errors.iter().map(|e| &e.message).collect::>() + ); + }, + other => panic!("Expected ValidationResult for duplicate node_id, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_reports_all_errors() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + // Two invalid nodes — both errors should be reported, not just the first. + let payload = send_validate_batch( + &mut write, + &mut read, + &session_id, + vec![ + BatchOperation::AddNode { + node_id: "http_in".to_string(), + kind: "streamkit::http_input".to_string(), + params: None, + }, + BatchOperation::AddNode { + node_id: "http_out".to_string(), + kind: "streamkit::http_output".to_string(), + params: None, + }, + ], + "validate-all-errors", + ) + .await; + + match payload { + ResponsePayload::ValidationResult { errors } => { + assert_eq!(errors.len(), 2, "Expected two validation errors, got {}", errors.len()); + assert!( + errors.iter().all(|e| e.message.contains("oneshot-only")), + "Expected both errors to be oneshot-only, got: {:?}", + errors.iter().map(|e| &e.message).collect::>() + ); + }, + other => panic!("Expected ValidationResult with two errors, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_validate_batch_rejects_cross_role_ownership() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = + start_test_server_with_config(config_with_role_header()).await + else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + // Connect as "admin" and create a session. + let (mut admin_write, mut admin_read) = connect_with_role(addr, "admin").await; + let create_request = Request { + message_type: MessageType::Request, + correlation_id: Some("admin-create".to_string()), + payload: RequestPayload::CreateSession { name: Some("admin-session".to_string()) }, + }; + admin_write + .send(WsMessage::Text(serde_json::to_string(&create_request).unwrap().into())) + .await + .unwrap(); + let response = read_response(&mut admin_read, "admin-create").await; + let session_id = match response.payload { + ResponsePayload::SessionCreated { session_id, .. } => session_id, + other => panic!("Expected SessionCreated, got: {:?}", other), + }; + + // Connect as "user" (access_all_sessions = false) and try to validate on + // the admin's session. + let (mut user_write, mut user_read) = connect_with_role(addr, "user").await; + let payload = send_validate_batch( + &mut user_write, + &mut user_read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }], + "user-validate-admin-session", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("Permission denied") || message.contains("not found"), + "Expected ownership/permission error, got: {message}" + ); + }, + other => { + panic!("Expected Error for cross-role ownership in ValidateBatch, got: {:?}", other) + }, + } +} + +// --------------------------------------------------------------------------- +// ApplyBatch tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_apply_batch_rejects_http_input_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_apply_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "http_in".to_string(), + kind: "streamkit::http_input".to_string(), + params: None, + }], + "apply-http-input", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("oneshot-only"), + "Expected oneshot-only error, got: {message}" + ); + }, + other => panic!("Expected Error for http_input in apply, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_apply_batch_rejects_http_output_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_apply_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "http_out".to_string(), + kind: "streamkit::http_output".to_string(), + params: None, + }], + "apply-http-output", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("oneshot-only"), + "Expected oneshot-only error, got: {message}" + ); + }, + other => panic!("Expected Error for http_output in apply, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_apply_batch_rejects_disallowed_plugin() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = + start_test_server_with_config(config_with_no_plugins_allowed()).await + else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_apply_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "p1".to_string(), + kind: "plugin::native::whisper".to_string(), + params: None, + }], + "apply-disallowed-plugin", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("plugin") && message.contains("not allowed"), + "Expected plugin not-allowed error, got: {message}" + ); + }, + other => panic!("Expected Error for disallowed plugin in apply, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_apply_batch_allows_valid_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_apply_batch( + &mut write, + &mut read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 2.0})), + }], + "apply-valid-node", + ) + .await; + + match payload { + ResponsePayload::BatchApplied { success, errors } => { + assert!(success, "Expected batch apply to succeed"); + assert!(errors.is_empty(), "Expected no errors from batch apply"); + }, + ResponsePayload::Error { message } => { + panic!("Unexpected error for valid node in apply: {message}"); + }, + other => panic!("Expected BatchApplied for valid node, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_apply_batch_rejects_mixed_with_oneshot_node() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let (mut write, mut read, session_id) = setup_session(addr).await; + + let payload = send_apply_batch( + &mut write, + &mut read, + &session_id, + vec![ + BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: Some(json!({"gain": 1.0})), + }, + BatchOperation::AddNode { + node_id: "http_in".to_string(), + kind: "streamkit::http_input".to_string(), + params: None, + }, + ], + "apply-mixed", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("oneshot-only"), + "Expected oneshot-only error in mixed batch, got: {message}" + ); + }, + other => { + panic!("Expected Error for mixed batch with oneshot node in apply, got: {:?}", other) + }, + } +} + +#[tokio::test] +async fn test_apply_batch_rejects_nonexistent_session() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = start_test_server().await else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + let ws_url = format!("ws://{}/api/v1/control", addr); + let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect to WebSocket"); + let (mut write, mut read) = ws_stream.split(); + + let payload = send_apply_batch( + &mut write, + &mut read, + "nonexistent-session-id", + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }], + "apply-nonexistent-session", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("not found"), + "Expected session not-found error, got: {message}" + ); + }, + other => panic!("Expected Error for nonexistent session in ApplyBatch, got: {:?}", other), + } +} + +#[tokio::test] +async fn test_apply_batch_rejects_cross_role_ownership() { + let _ = tracing_subscriber::fmt::try_init(); + + let Some((addr, _server_handle)) = + start_test_server_with_config(config_with_role_header()).await + else { + eprintln!("Skipping: local TCP bind not permitted"); + return; + }; + + // Connect as "admin" and create a session. + let (mut admin_write, mut admin_read) = connect_with_role(addr, "admin").await; + let create_request = Request { + message_type: MessageType::Request, + correlation_id: Some("admin-create".to_string()), + payload: RequestPayload::CreateSession { name: Some("admin-session".to_string()) }, + }; + admin_write + .send(WsMessage::Text(serde_json::to_string(&create_request).unwrap().into())) + .await + .unwrap(); + let response = read_response(&mut admin_read, "admin-create").await; + let session_id = match response.payload { + ResponsePayload::SessionCreated { session_id, .. } => session_id, + other => panic!("Expected SessionCreated, got: {:?}", other), + }; + + // Connect as "user" (access_all_sessions = false) and try to apply on + // the admin's session. + let (mut user_write, mut user_read) = connect_with_role(addr, "user").await; + let payload = send_apply_batch( + &mut user_write, + &mut user_read, + &session_id, + vec![BatchOperation::AddNode { + node_id: "gain1".to_string(), + kind: "audio::gain".to_string(), + params: None, + }], + "user-apply-admin-session", + ) + .await; + + match payload { + ResponsePayload::Error { message } => { + assert!( + message.contains("Permission denied") || message.contains("not found"), + "Expected ownership/permission error, got: {message}" + ); + }, + other => panic!("Expected Error for cross-role ownership in ApplyBatch, got: {:?}", other), + } +}