diff --git a/Cargo.lock b/Cargo.lock index 52eb86b6..78e246af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -131,7 +131,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -142,7 +142,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -1600,7 +1600,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -2603,7 +2603,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3221,7 +3221,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4780,7 +4780,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4858,7 +4858,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -5321,7 +5321,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -5477,6 +5477,8 @@ dependencies = [ "resvg", "rquickjs", "rubato", + "rustls", + "rustls-platform-verifier", "schemars 1.2.1", "serde", 
"serde-saphyr", @@ -5488,6 +5490,7 @@ dependencies = [ "tempfile", "tiny-skia", "tokio", + "tokio-rustls", "tokio-util 0.7.18", "tower", "tracing", @@ -5930,7 +5933,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -7714,7 +7717,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] diff --git a/crates/nodes/Cargo.toml b/crates/nodes/Cargo.toml index 3dcfc7de..e0c277f8 100644 --- a/crates/nodes/Cargo.toml +++ b/crates/nodes/Cargo.toml @@ -87,6 +87,11 @@ ts-rs = { version = "12.0.1", optional = true } # H.264 codec via Cisco's OpenH264 (optional, behind `openh264` feature) openh264 = { version = "0.9", optional = true } +# RTMP publishing (optional, behind `rtmp` feature) +tokio-rustls = { version = "0.26", optional = true } +rustls = { version = "0.23", optional = true, default-features = false, features = ["std"] } +rustls-platform-verifier = { version = "0.6", optional = true } + # AV1 codec (optional, behind `av1` feature) rav1e = { version = "0.8", optional = true, default-features = false, features = ["threading", "asm"] } rav1d = { version = "1.1", optional = true, default-features = false, features = ["bitdepth_8", "bitdepth_16", "asm"] } @@ -117,6 +122,7 @@ default = [ "video", "mp4", "openh264", + "rtmp", ] # Individual features for each node. 
@@ -148,6 +154,7 @@ symphonia = ["dep:symphonia", "dep:schemars"] vp9 = ["dep:env-libvpx-sys", "dep:schemars"] av1 = ["dep:rav1e", "dep:rav1d", "dep:schemars"] openh264 = ["dep:openh264", "dep:schemars", "dep:serde_json"] +rtmp = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-platform-verifier", "dep:schemars", "dep:serde_json"] svt_av1 = ["dep:schemars", "dep:serde_json", "dep:pkg-config", "dep:cc"] # svt_av1_static downloads + builds SVT-AV1 at compile time (no system install). # Not in `default` to keep dev builds fast; enabled explicitly in Dockerfiles and diff --git a/crates/nodes/src/transport/mod.rs b/crates/nodes/src/transport/mod.rs index 3be9d2d6..d5d8393e 100644 --- a/crates/nodes/src/transport/mod.rs +++ b/crates/nodes/src/transport/mod.rs @@ -14,6 +14,12 @@ pub mod http; #[cfg(feature = "http")] pub mod http_mse; +#[cfg(feature = "rtmp")] +mod rtmp_client; + +#[cfg(feature = "rtmp")] +pub mod rtmp; + /// Registers all available transport nodes with the engine's registry. pub fn register_transport_nodes(registry: &mut NodeRegistry) { // Call the registration function from each submodule. @@ -24,4 +30,7 @@ pub fn register_transport_nodes(registry: &mut NodeRegistry) { #[cfg(feature = "http")] http_mse::register_http_mse_nodes(registry); + + #[cfg(feature = "rtmp")] + rtmp::register_rtmp_nodes(registry); } diff --git a/crates/nodes/src/transport/rtmp.rs b/crates/nodes/src/transport/rtmp.rs new file mode 100644 index 00000000..c5e417ab --- /dev/null +++ b/crates/nodes/src/transport/rtmp.rs @@ -0,0 +1,1611 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! RTMP publisher (sink) node. +//! +//! Uses an internal sans-I/O RTMP client (`rtmp_client`) to publish encoded +//! H.264 video and AAC audio to an arbitrary RTMP or RTMPS endpoint +//! (e.g. YouTube Live, Twitch). +//! +//! The node manages the TCP (or TLS) socket itself, feeding bytes between +//! 
tokio I/O and the client's `feed_recv_buf()` / `send_buf()` interface. + +use super::rtmp_client::{ + AudioFormat as RtmpAudioFormat, AudioFrame as RtmpAudioFrame, AvcPacketType, AvcSequenceHeader, + RtmpConnectionState, RtmpPublishClientConnection, RtmpTimestamp, RtmpTimestampDelta, RtmpUrl, + VideoCodec as RtmpVideoCodec, VideoFrame as RtmpVideoFrame, VideoFrameType, +}; +use async_trait::async_trait; +use opentelemetry::KeyValue; +use schemars::schema_for; +use schemars::JsonSchema; +use serde::Deserialize; +use streamkit_core::stats::NodeStatsTracker; +use streamkit_core::types::{ + AudioCodec, EncodedAudioFormat, EncodedVideoFormat, Packet, PacketType, VideoCodec, +}; +use streamkit_core::{ + config_helpers, registry::StaticPins, state_helpers, InputPin, NodeContext, NodeRegistry, + OutputPin, PinCardinality, ProcessorNode, StreamKitError, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +/// Configuration for the RTMP publisher node. +#[derive(Debug, Clone, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct RtmpPublishConfig { + /// RTMP server URL. + /// + /// Supports `rtmp://` and `rtmps://` (TLS) schemes. + /// Can include the stream key in the path, or use the separate + /// `stream_key` / `stream_key_env` fields. + /// + /// Examples: + /// - `rtmp://a.rtmp.youtube.com/live2` (key via `stream_key` or `stream_key_env`) + /// - `rtmp://a.rtmp.youtube.com/live2/xxxx-xxxx-xxxx-xxxx` (key inline) + /// - `rtmps://live.twitch.tv/app/live_xxxx` + pub url: String, + + /// Stream key appended to the URL path. + /// + /// Optional — if omitted, the URL is used as-is (for URLs that + /// already include the key). Ignored when `stream_key_env` is set. 
+ #[serde(default)] + pub stream_key: Option, + + /// Environment variable name containing the stream key. + /// + /// Read at node startup. Takes precedence over `stream_key`. + /// The name is fully user-controlled, so multiple RTMP output nodes + /// can each reference different variables. + /// + /// Example: `"SKIT_RTMP_STREAM_KEY"` → reads `$SKIT_RTMP_STREAM_KEY`. + #[serde(default)] + pub stream_key_env: Option, + + /// Audio sample rate in Hz for the AAC sequence header. + /// + /// Must match the sample rate produced by the upstream AAC encoder. + /// Common values: 48000, 44100, 32000. + /// Defaults to 48000. + #[serde(default = "default_sample_rate")] + pub sample_rate: u32, + + /// Number of audio channels for the AAC sequence header. + /// + /// Must match the channel count produced by the upstream AAC encoder. + /// 1 = mono, 2 = stereo. + /// Defaults to 2 (stereo). + #[serde(default = "default_channels")] + pub channels: u8, +} + +const fn default_sample_rate() -> u32 { + 48_000 +} + +const fn default_channels() -> u8 { + 2 +} + +// --------------------------------------------------------------------------- +// Node +// --------------------------------------------------------------------------- + +/// RTMP publisher sink node. +/// +/// Accepts encoded H.264 video and AAC audio on separate input pins and +/// publishes them to an RTMP endpoint using the FLV/RTMP wire format. 
+pub struct RtmpPublishNode { + config: RtmpPublishConfig, +} + +impl RtmpPublishNode { + pub const fn new(config: RtmpPublishConfig) -> Self { + Self { config } + } +} + +// --------------------------------------------------------------------------- +// ProcessorNode implementation +// --------------------------------------------------------------------------- + +#[async_trait] +impl ProcessorNode for RtmpPublishNode { + fn input_pins(&self) -> Vec { + vec![ + InputPin { + name: "video".to_string(), + accepts_types: vec![PacketType::EncodedVideo(EncodedVideoFormat { + codec: VideoCodec::H264, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + })], + cardinality: PinCardinality::One, + }, + InputPin { + name: "audio".to_string(), + accepts_types: vec![PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Aac, + codec_private: None, + })], + cardinality: PinCardinality::One, + }, + ] + } + + fn output_pins(&self) -> Vec { + // Sink node — no outputs. + vec![] + } + + async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + let node_name = context.output_sender.node_name().to_string(); + + state_helpers::emit_initializing(&context.state_tx, &node_name); + + // ── Validate AAC config ────────────────────────────────────────── + validate_aac_config(&self.config).map_err(|e| { + let msg = format!("Invalid AAC config: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &msg); + StreamKitError::Configuration(msg) + })?; + + // ── Resolve stream key (env var takes precedence) ─────────────── + let full_url = resolve_rtmp_url(&self.config).map_err(|e| { + let msg = format!("RTMP URL resolution failed: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &msg); + StreamKitError::Configuration(msg) + })?; + + // Log without the stream key (it's effectively a bearer token). 
+ let masked_url = mask_stream_key(&full_url); + tracing::info!(%node_name, url = %masked_url, "RtmpPublishNode starting"); + + // ── Parse RTMP URL ────────────────────────────────────────────── + let rtmp_url: RtmpUrl = full_url.parse().map_err(|e| { + StreamKitError::Configuration(format!( + "Invalid RTMP URL '{}': {e}", + mask_stream_key(&full_url) + )) + })?; + + tracing::info!( + %node_name, + host = %rtmp_url.host, port = rtmp_url.port, + app = %rtmp_url.app, tls = rtmp_url.tls, + "Parsed RTMP URL" + ); + + // ── Connect TCP (+ optional TLS) ──────────────────────────────── + let mut stream = connect(&rtmp_url).await.map_err(|e| { + let msg = format!("Failed to connect to RTMP server: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &msg); + StreamKitError::Runtime(msg) + })?; + + tracing::info!(%node_name, "TCP connection established"); + + // ── Create RTMP connection and drive handshake ─────────────────── + let mut connection = RtmpPublishClientConnection::new(rtmp_url); + + drive_handshake(&mut connection, &mut stream, &node_name).await.map_err(|e| { + let msg = format!("RTMP handshake failed: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &msg); + StreamKitError::Runtime(msg) + })?; + + tracing::info!(%node_name, "RTMP connection in Publishing state"); + + state_helpers::emit_running(&context.state_tx, &node_name); + + // ── Obtain input receivers ────────────────────────────────────── + let mut video_rx = context.take_input("video")?; + let mut audio_rx = context.take_input("audio")?; + + // ── Stats / metrics ───────────────────────────────────────────── + let meter = opentelemetry::global::meter("streamkit"); + let packet_counter = meter.u64_counter("rtmp_publish.packets").build(); + let video_labels = + [KeyValue::new("node", node_name.clone()), KeyValue::new("track", "video")]; + let audio_labels = + [KeyValue::new("node", node_name.clone()), KeyValue::new("track", "audio")]; + let mut stats = 
NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); + + // ── Publishing state ──────────────────────────────────────────── + let mut audio_seq_header_sent = false; + let mut video_packet_count: u64 = 0; + let mut audio_packet_count: u64 = 0; + let mut tcp_read_buf = vec![0u8; 8192]; + + // Per-track timestamp rebase state. Source timestamps from + // mic + camera are synchronized (same browser epoch), but audio + // and video arrive through different pipeline paths that may + // start at different wall-clock times (e.g. compositor generates + // early frames before MoQ video arrives, while audio waits for + // the opus→AAC chain). + // + // To align the tracks in the RTMP stream we follow the same + // pattern as the WebM muxer: each track's first frame computes + // a rebase offset so its RTMP timestamp starts at the current + // global position. Subsequent frames preserve the source- + // timestamp cadence (which is correct because mic/camera are + // synchronized). Large backward jumps (compositor calibration) + // trigger an offset reset. + let mut ts_state = RtmpTimestampState::new(); + + // ── Main publishing loop ──────────────────────────────────────── + tracing::info!(%node_name, "Entering RTMP publishing loop"); + + let result: Result<(), StreamKitError> = async { + loop { + // Biased select: TCP read is checked FIRST every + // iteration so server ACKs / pings are always drained + // before we send more media. Without this, the + // video/audio arms can starve the read arm and cause + // an ACK window overflow (`unacked > window * 2`). + tokio::select! { + biased; + + // TCP read (server responses / keepalive) — highest priority. + // 30s timeout prevents hanging if the server becomes + // unresponsive while input channels are idle. 
+ read_result = tokio::time::timeout( + std::time::Duration::from_secs(30), + stream.read(&mut tcp_read_buf), + ) => { + let Ok(read_result) = read_result else { + // Timeout — server hasn't sent anything in 30s. + // This is normal during idle periods; just loop + // back to check the other select arms. + continue; + }; + match read_result { + Ok(0) => { + tracing::warn!(%node_name, "RTMP server closed connection"); + break; + } + Ok(n) => { + if let Err(e) = connection.feed_recv_buf(&tcp_read_buf[..n]) { + tracing::warn!(%node_name, error = %e, "Error feeding RTMP recv buffer"); + } + // Drain events (acks, pings, etc.) + if drain_events(&mut connection, &node_name) { + tracing::info!(%node_name, "Breaking loop: peer disconnected"); + break; + } + flush_send_buf(&mut connection, &mut stream, &mut tcp_read_buf, &node_name).await?; + } + Err(e) => { + tracing::warn!(%node_name, error = %e, "TCP read error"); + break; + } + } + } + + // Video input + maybe_pkt = video_rx.recv() => { + let Some(pkt) = maybe_pkt else { + tracing::info!(%node_name, "Video input channel closed"); + break; + }; + // Stop sending if the server has disconnected. 
+ if connection.state() != RtmpConnectionState::Publishing { + tracing::warn!(%node_name, state = %connection.state(), "Connection no longer publishing, exiting"); + break; + } + let timestamp_ms = ts_state.stamp(&pkt, Track::Video, &node_name); + if let Err(e) = process_video_packet( + &pkt, &mut connection, timestamp_ms, + &packet_counter, &video_labels, + &mut stats, &mut video_packet_count, &node_name, + ) { + tracing::warn!(%node_name, error = %e, "Error processing video packet"); + stats.errored(); + } + flush_send_buf(&mut connection, &mut stream, &mut tcp_read_buf, &node_name).await?; + } + + // Audio input + maybe_pkt = audio_rx.recv() => { + let Some(pkt) = maybe_pkt else { + tracing::info!(%node_name, "Audio input channel closed"); + break; + }; + if connection.state() != RtmpConnectionState::Publishing { + tracing::warn!(%node_name, state = %connection.state(), "Connection no longer publishing, exiting"); + break; + } + let timestamp_ms = ts_state.stamp(&pkt, Track::Audio, &node_name); + if let Err(e) = process_audio_packet( + &pkt, &mut connection, &mut audio_seq_header_sent, + timestamp_ms, + self.config.sample_rate, self.config.channels, + &packet_counter, &audio_labels, + &mut stats, &mut audio_packet_count, &node_name, + ) { + tracing::warn!(%node_name, error = %e, "Error processing audio packet"); + stats.errored(); + } + flush_send_buf(&mut connection, &mut stream, &mut tcp_read_buf, &node_name).await?; + } + + // Shutdown signal + Some(control_msg) = context.control_rx.recv() => { + if matches!(control_msg, streamkit_core::control::NodeControlMessage::Shutdown) { + tracing::info!(%node_name, "Received shutdown signal"); + break; + } + } + } + + stats.maybe_send(); + } + Ok(()) + } + .await; + + tracing::info!(%node_name, video_packets = video_packet_count, audio_packets = audio_packet_count, "RTMP publishing finished"); + + // Best-effort graceful TCP shutdown so the server sees a FIN + // rather than an abrupt RST. 
The rtmp_client module does not + // expose deleteStream/FCUnpublish on the publish client, so we + // cannot send a clean RTMP-level teardown; the TCP close is the + // next best signal. + let _ = stream.shutdown().await; + + match result { + Ok(()) => { + state_helpers::emit_stopped(&context.state_tx, &node_name, "completed"); + Ok(()) + }, + Err(e) => { + state_helpers::emit_failed(&context.state_tx, &node_name, e.to_string()); + Err(e) + }, + } + } +} + +// --------------------------------------------------------------------------- +// TCP / TLS connection helpers +// --------------------------------------------------------------------------- + +/// Unified async stream over plain TCP or TLS. +enum RtmpStream { + Plain(TcpStream), + Tls(Box>), +} + +impl RtmpStream { + async fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + Self::Plain(s) => s.read(buf).await, + Self::Tls(s) => s.read(buf).await, + } + } + + /// Non-blocking read that returns `WouldBlock` when no data is available. + /// + /// For plain TCP this calls `TcpStream::try_read`, a direct syscall that + /// bypasses the tokio reactor. For TLS there is no synchronous decrypt + /// path, so this always returns `WouldBlock` — the biased main select + /// loop handles TLS ACK draining instead. 
+ fn try_read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + Self::Plain(s) => s.try_read(buf), + Self::Tls(_) => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + } + } + + async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + match self { + Self::Plain(s) => s.write_all(buf).await, + Self::Tls(s) => s.write_all(buf).await, + } + } + + async fn flush(&mut self) -> std::io::Result<()> { + match self { + Self::Plain(s) => tokio::io::AsyncWriteExt::flush(s).await, + Self::Tls(s) => tokio::io::AsyncWriteExt::flush(s).await, + } + } + + async fn shutdown(&mut self) -> std::io::Result<()> { + match self { + Self::Plain(s) => tokio::io::AsyncWriteExt::shutdown(s).await, + Self::Tls(s) => tokio::io::AsyncWriteExt::shutdown(s).await, + } + } +} + +/// Mask the stream-key portion of an RTMP URL for safe logging. +/// +/// If the URL path has two or more segments (e.g. `/app/stream_key`), +/// the last segment is replaced with ``. If the path has +/// only one segment (e.g. `/app` — no key embedded), the URL is +/// returned as-is so the app name remains visible in logs. +fn mask_stream_key(url: &str) -> String { + // Find the start of the path portion (after ://host[:port]). + let path_start = url + .find("://") + .and_then(|scheme_end| url[scheme_end + 3..].find('/').map(|p| scheme_end + 3 + p)); + + path_start.map_or_else( + || "".to_string(), + |start| { + let path = &url[start..]; + // rfind('/') always succeeds (at least the leading `/`). + // If > 0 there is a second segment to redact. + match path.rfind('/') { + Some(last) if last > 0 => format!("{}/", &url[..start + last]), + _ => url.to_string(), + } + }, + ) +} + +/// Valid AAC sampling frequencies (ISO 14496-3 Table 1.18). 
+const AAC_SAMPLE_RATES: [u32; 13] = [ + 96_000, 88_200, 64_000, 48_000, 44_100, 32_000, 24_000, 22_050, 16_000, 12_000, 11_025, 8_000, + 7_350, +]; + +/// Validate AAC-related config fields at startup so we fail fast with a +/// clear error instead of producing a corrupt AudioSpecificConfig at runtime. +fn validate_aac_config(config: &RtmpPublishConfig) -> Result<(), String> { + if config.channels == 0 || config.channels > 7 { + return Err(format!( + "channels must be 1..=7 (AAC channelConfiguration is 4 bits), got {}", + config.channels + )); + } + if !AAC_SAMPLE_RATES.contains(&config.sample_rate) { + return Err(format!( + "sample_rate {} is not a standard AAC sampling frequency; \ + valid values: {:?}", + config.sample_rate, AAC_SAMPLE_RATES + )); + } + Ok(()) +} + +/// Resolve the final RTMP URL from config fields. +/// +/// Priority: +/// 1. `stream_key_env` — read the key from the named environment variable. +/// 2. `stream_key` — use the literal value. +/// 3. Neither set — use `url` as-is (key already embedded). +/// +/// The resolved key is appended to the base URL separated by `/`. +fn resolve_rtmp_url(config: &RtmpPublishConfig) -> Result { + resolve_rtmp_url_with_env(config, |name| std::env::var(name)) +} + +/// Inner implementation that accepts an env-var resolver, allowing tests to +/// avoid `std::env::set_var` (which is unsound in multi-threaded processes +/// since Rust 1.83). 
+fn resolve_rtmp_url_with_env(config: &RtmpPublishConfig, env_var: F) -> Result +where + F: Fn(&str) -> Result, +{ + let key = if let Some(ref env_name) = config.stream_key_env { + let val = env_var(env_name).map_err(|e| { + format!("stream_key_env references '{env_name}' but the variable is not set: {e}") + })?; + if val.is_empty() { + return Err(format!( + "stream_key_env references '{env_name}' but the variable is empty" + )); + } + Some(val) + } else { + config.stream_key.clone() + }; + + match key { + Some(k) if !k.is_empty() => Ok(format!("{}/{}", config.url.trim_end_matches('/'), k)), + _ => Ok(config.url.clone()), + } +} + +/// Connect to the RTMP server, using TLS if the URL scheme is `rtmps://`. +async fn connect(url: &RtmpUrl) -> Result { + let addr = format!("{}:{}", url.host, url.port); + let tcp = tokio::time::timeout(std::time::Duration::from_secs(10), TcpStream::connect(&addr)) + .await + .map_err(|_| format!("TCP connect to {addr} timed out after 10s"))? + .map_err(|e| format!("TCP connect to {addr} failed: {e}"))?; + tcp.set_nodelay(true).map_err(|e| format!("Failed to set TCP_NODELAY: {e}"))?; + + if url.tls { + use rustls_platform_verifier::BuilderVerifierExt; + + let config = rustls::ClientConfig::builder() + .with_platform_verifier() + .map_err(|e| format!("Failed to build TLS config with platform verifier: {e}"))? 
+ .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); + let server_name = rustls::pki_types::ServerName::try_from(url.host.clone()) + .map_err(|e| format!("Invalid TLS server name '{}': {e}", url.host))?; + let tls_stream = connector + .connect(server_name, tcp) + .await + .map_err(|e| format!("TLS handshake with {} failed: {e}", url.host))?; + Ok(RtmpStream::Tls(Box::new(tls_stream))) + } else { + Ok(RtmpStream::Plain(tcp)) + } +} + +// --------------------------------------------------------------------------- +// RTMP protocol helpers +// --------------------------------------------------------------------------- + +/// Drive the RTMP handshake until the connection reaches [`RtmpConnectionState::Publishing`]. +async fn drive_handshake( + connection: &mut RtmpPublishClientConnection, + stream: &mut RtmpStream, + node_name: &str, +) -> Result<(), String> { + let mut recv_buf = vec![0u8; 8192]; + + loop { + // Flush outgoing data first. + flush_send_buf_raw(connection, stream) + .await + .map_err(|e| format!("Handshake write failed: {e}"))?; + + if connection.state() == RtmpConnectionState::Publishing { + return Ok(()); + } + if connection.state() == RtmpConnectionState::Disconnecting { + return Err("RTMP server rejected the connection".to_string()); + } + + // Wait for data from the server (with timeout). + let read_result = + tokio::time::timeout(std::time::Duration::from_secs(10), stream.read(&mut recv_buf)) + .await; + + match read_result { + Ok(Ok(0)) => return Err("Server closed connection during handshake".to_string()), + Ok(Ok(n)) => { + connection + .feed_recv_buf(&recv_buf[..n]) + .map_err(|e| format!("Handshake feed error: {e}"))?; + }, + Ok(Err(e)) => return Err(format!("Handshake read error: {e}")), + Err(_) => return Err("Handshake timed out after 10s".to_string()), + } + + // Process events emitted by the handshake. 
+ while let Some(event) = connection.next_event() { + tracing::debug!(%node_name, ?event, "RTMP handshake event"); + } + } +} + +/// Flush the RTMP connection's send buffer to the TCP stream (no ACK drain). +/// +/// Used during the handshake phase where ACK window overflow is not a concern +/// because the handshake loop already reads server data between flushes. +async fn flush_send_buf_raw( + connection: &mut RtmpPublishClientConnection, + stream: &mut RtmpStream, +) -> std::io::Result<()> { + while !connection.send_buf().is_empty() { + let buf = connection.send_buf(); + stream.write_all(buf).await?; + let len = buf.len(); + connection.advance_send_buf(len); + } + stream.flush().await?; + Ok(()) +} + +/// Flush the RTMP connection's send buffer to the TCP stream. +/// +/// After flushing, performs a non-blocking drain of any pending server data +/// (ACK messages, pings, etc.) via `try_read` (a direct non-blocking +/// syscall that works for plain TCP). For TLS streams `try_read` returns +/// `WouldBlock` immediately because there is no synchronous decryption +/// path — the biased main `select!` loop handles TLS ACK draining instead +/// by always checking the TCP read arm first. +async fn flush_send_buf( + connection: &mut RtmpPublishClientConnection, + stream: &mut RtmpStream, + tcp_read_buf: &mut [u8], + node_name: &str, +) -> Result<(), StreamKitError> { + // Write all pending outbound data. + while !connection.send_buf().is_empty() { + let buf = connection.send_buf(); + stream + .write_all(buf) + .await + .map_err(|e| StreamKitError::Runtime(format!("RTMP send failed: {e}")))?; + let len = buf.len(); + connection.advance_send_buf(len); + } + // Explicit flush to ensure TLS buffered data is sent immediately. 
+ stream.flush().await.map_err(|e| StreamKitError::Runtime(format!("RTMP flush failed: {e}")))?; + + // Non-blocking drain: `try_read` does a direct non-blocking syscall + // (bypasses the tokio reactor) so it returns data that is already + // sitting in the OS receive buffer. This catches ACKs that arrived + // while we were writing. For TLS, `try_read` returns `WouldBlock` + // and the biased main loop handles draining instead. + loop { + match stream.try_read(tcp_read_buf) { + Ok(0) => { + return Err(StreamKitError::Runtime("RTMP server closed connection".to_string())); + }, + Ok(n) => { + if let Err(e) = connection.feed_recv_buf(&tcp_read_buf[..n]) { + tracing::warn!(%node_name, error = %e, "Error feeding RTMP recv buffer (flush drain)"); + } + if drain_events(connection, node_name) { + return Err(StreamKitError::Runtime( + "RTMP server disconnected during flush drain".to_string(), + )); + } + }, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // No data available right now — done draining. + break; + }, + Err(e) => { + return Err(StreamKitError::Runtime(format!( + "RTMP read failed during flush drain: {e}" + ))); + }, + } + } + + Ok(()) +} + +/// Drain and log any pending RTMP events (acks, pings, ignored commands). +/// +/// Returns `true` if the peer signalled a disconnect, indicating that the +/// publishing loop should exit. 
+fn drain_events(connection: &mut RtmpPublishClientConnection, node_name: &str) -> bool { + let mut disconnected = false; + while let Some(event) = connection.next_event() { + match &event { + super::rtmp_client::RtmpConnectionEvent::DisconnectedByPeer { reason } => { + tracing::warn!(%node_name, %reason, "RTMP server disconnected"); + disconnected = true; + }, + super::rtmp_client::RtmpConnectionEvent::StateChanged(state) => { + tracing::info!(%node_name, %state, "RTMP state changed"); + }, + } + } + disconnected +} + +// --------------------------------------------------------------------------- +// Per-track timestamp rebase (mirrors WebM muxer `stage_frame` logic) +// --------------------------------------------------------------------------- + +/// Backward timestamp jump threshold (ms). Jumps larger than this trigger +/// a rebase offset reset. Typically caused by the compositor calibrating +/// its running clock to a newly-arrived remote MoQ input. +const BACKWARD_JUMP_THRESHOLD_MS: u32 = 500; + +/// Identifies the media track for timestamp rebase bookkeeping. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Track { + Video, + Audio, +} + +/// Per-track rebase state for a single media track. +struct TrackTimestamp { + /// Offset (in ms) added to source timestamps so the track starts at + /// the current global RTMP position when it first produces output. + rebase_offset_ms: Option, + /// Last RTMP timestamp emitted for this track (for monotonicity). + last_ms: Option, +} + +impl TrackTimestamp { + const fn new() -> Self { + Self { rebase_offset_ms: None, last_ms: None } + } +} + +/// Manages RTMP timestamps for audio and video tracks. +/// +/// Source timestamps (from `PacketMetadata::timestamp_us`) are synchronized +/// because mic and camera are captured in the same browser epoch. However, +/// audio and video arrive through different pipeline paths that may start at +/// different wall-clock times (e.g. 
the compositor generates early video +/// frames before MoQ input arrives, while audio waits for the opus→AAC +/// chain). +/// +/// To align the tracks we apply the same per-track rebase pattern used by +/// the WebM muxer: each track's first frame computes an offset so its RTMP +/// timestamp starts at the current global position. Subsequent frames +/// preserve the source-timestamp cadence. Large backward jumps (compositor +/// calibration) trigger an offset reset so the track re-aligns. +struct RtmpTimestampState { + video: TrackTimestamp, + audio: TrackTimestamp, + /// The highest RTMP timestamp written across both tracks (ms). + global_last_ms: u32, +} + +impl RtmpTimestampState { + const fn new() -> Self { + Self { video: TrackTimestamp::new(), audio: TrackTimestamp::new(), global_last_ms: 0 } + } + + /// Compute the RTMP timestamp (u32 ms) for a packet, applying per-track + /// rebase and monotonicity enforcement. + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + // RTMP timestamps are u32 ms; wrapping after ~49 days is acceptable. + // Sign loss is guarded by `.max(0)` before each cast. + fn stamp(&mut self, packet: &Packet, track: Track, node_name: &str) -> u32 { + let timestamp_us = match packet { + // In practice this node only receives Binary packets (encoded + // H.264 / AAC), but the Video variant is included for + // completeness since the type system allows it. + Packet::Binary { metadata, .. } + | Packet::Video(streamkit_core::types::VideoFrame { metadata, .. }) => { + metadata.as_ref().and_then(|m| m.timestamp_us) + }, + _ => None, + }; + + let pkt_ms = timestamp_us.map_or(0i64, |us| i64::try_from(us / 1_000).unwrap_or(i64::MAX)); + + let ts = match track { + Track::Video => &mut self.video, + Track::Audio => &mut self.audio, + }; + + // First frame for this track: compute rebase offset so the track + // starts at the current global position. 
+ let is_new_offset = ts.rebase_offset_ms.is_none(); + let offset = + *ts.rebase_offset_ms.get_or_insert_with(|| i64::from(self.global_last_ms) - pkt_ms); + if is_new_offset { + tracing::info!( + %node_name, + track = ?track, + offset, + pkt_ms, + global_last_ms = self.global_last_ms, + "RTMP timestamp rebase initialized" + ); + } + + let mut rtmp_ms = pkt_ms.saturating_add(offset).max(0) as u32; + + // Handle large backward jumps — typically caused by the compositor + // calibrating its running clock to a remote MoQ input. Reset the + // rebase offset so the track re-aligns with the global position + // (same strategy as the WebM muxer). + if let Some(last) = ts.last_ms { + if rtmp_ms < last { + let gap_ms = last - rtmp_ms; + if gap_ms > BACKWARD_JUMP_THRESHOLD_MS { + let new_offset = i64::from(self.global_last_ms) - pkt_ms; + tracing::info!( + %node_name, + track = ?track, + gap_ms, + old_offset = offset, + new_offset, + "RTMP timestamp rebase reset (backward jump)" + ); + ts.rebase_offset_ms = Some(new_offset); + rtmp_ms = pkt_ms.saturating_add(new_offset).max(0) as u32; + } + // Enforce monotonicity for remaining small gaps / jitter. + if rtmp_ms <= last { + rtmp_ms = last.saturating_add(1); + } + } + } + + ts.last_ms = Some(rtmp_ms); + if rtmp_ms > self.global_last_ms { + self.global_last_ms = rtmp_ms; + } + + rtmp_ms + } +} + +// --------------------------------------------------------------------------- + +/// Process one encoded video packet and send it via RTMP. +/// +/// Converts H.264 Annex B to AVCC format, extracts SPS/PPS on keyframes +/// to send as an AVC sequence header, then sends the video frame. +/// +/// `timestamp_ms` is the rebased RTMP timestamp computed by the caller +/// via `RtmpTimestampState::stamp`, ensuring audio and video share a +/// common time base derived from source timestamps. 
+#[allow(clippy::too_many_arguments)] // Packet-processing context (connection, counters, stats) is passed individually; bundling into a struct is a future cleanup. +fn process_video_packet( + packet: &Packet, + connection: &mut RtmpPublishClientConnection, + timestamp_ms: u32, + counter: &opentelemetry::metrics::Counter, + labels: &[KeyValue], + stats: &mut NodeStatsTracker, + packet_count: &mut u64, + node_name: &str, +) -> Result<(), StreamKitError> { + let Packet::Binary { data, metadata, .. } = packet else { + tracing::debug!(%node_name, "Ignoring non-binary video packet"); + stats.discarded(); + return Ok(()); + }; + + stats.received(); + + let keyframe = metadata.as_ref().and_then(|m| m.keyframe).unwrap_or(false); + + // Convert H.264 Annex B → AVCC + let conv = convert_annexb_to_avcc(data); + + // On keyframes, send the AVC sequence header (SPS/PPS) first. + if keyframe && !conv.sps_list.is_empty() && !conv.pps_list.is_empty() { + let sps = &conv.sps_list[0]; + let (profile, compat, level) = if sps.len() >= 4 { + (sps[1], sps[2], sps[3]) + } else { + // Fallback: Constrained Baseline Level 3.1 + (0x42, 0xC0, 0x1F) + }; + + let seq_header = AvcSequenceHeader { + avc_profile_indication: profile, + profile_compatibility: compat, + avc_level_indication: level, + length_size_minus_one: 3, // 4-byte NAL unit lengths + sps_list: conv.sps_list.clone(), + pps_list: conv.pps_list.clone(), + }; + + let seq_data = seq_header.to_bytes().map_err(|e| { + StreamKitError::Runtime(format!("Failed to serialize AVC sequence header: {e}")) + })?; + + let seq_frame = RtmpVideoFrame { + timestamp: RtmpTimestamp::from_millis(timestamp_ms), + composition_timestamp_offset: RtmpTimestampDelta::ZERO, + frame_type: VideoFrameType::KeyFrame, + codec: RtmpVideoCodec::Avc, + avc_packet_type: Some(AvcPacketType::SequenceHeader), + data: seq_data, + }; + + connection.send_video(&seq_frame).map_err(|e| { + StreamKitError::Runtime(format!("Failed to send AVC sequence header: {e}")) + })?; + + 
tracing::debug!(%node_name, %timestamp_ms, "Sent AVC sequence header"); + } + + // Send the actual video data (AVCC-formatted), excluding SPS/PPS NALUs + // which are already conveyed in the sequence header above. + // Guard: if an access unit contained only SPS/PPS (no slice NALUs), + // video_data will be empty — skip the NalUnit frame to avoid sending + // a zero-length payload that some RTMP servers reject. + if !conv.video_data.is_empty() { + let frame = RtmpVideoFrame { + timestamp: RtmpTimestamp::from_millis(timestamp_ms), + composition_timestamp_offset: RtmpTimestampDelta::ZERO, + frame_type: if keyframe { + VideoFrameType::KeyFrame + } else { + VideoFrameType::InterFrame + }, + codec: RtmpVideoCodec::Avc, + avc_packet_type: Some(AvcPacketType::NalUnit), + data: conv.video_data, + }; + + connection + .send_video(&frame) + .map_err(|e| StreamKitError::Runtime(format!("Failed to send video frame: {e}")))?; + } + + *packet_count += 1; + counter.add(1, labels); + stats.sent(); + + // `% N == 0` instead of `.is_multiple_of(N)` for MSRV < 1.85 compat. + #[allow(clippy::manual_is_multiple_of)] + if *packet_count <= 5 || *packet_count % 100 == 0 { + tracing::debug!(%node_name, packet = *packet_count, %timestamp_ms, %keyframe, "Sent video"); + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Audio packet processing +// --------------------------------------------------------------------------- + +/// Process one encoded audio packet and send it via RTMP. +/// +/// On the first audio packet, sends an AAC `AudioSpecificConfig` as the +/// RTMP sequence header. Subsequent packets are sent as raw AAC frames. +/// +/// `timestamp_ms` is the rebased RTMP timestamp computed by the caller +/// via `RtmpTimestampState::stamp`, ensuring audio and video share a +/// common time base derived from source timestamps. 
+#[allow(clippy::too_many_arguments)] // Packet-processing context (connection, counters, stats) is passed individually; bundling into a struct is a future cleanup. +fn process_audio_packet( + packet: &Packet, + connection: &mut RtmpPublishClientConnection, + seq_header_sent: &mut bool, + timestamp_ms: u32, + sample_rate: u32, + channels: u8, + counter: &opentelemetry::metrics::Counter, + labels: &[KeyValue], + stats: &mut NodeStatsTracker, + packet_count: &mut u64, + node_name: &str, +) -> Result<(), StreamKitError> { + let Packet::Binary { data, .. } = packet else { + tracing::debug!(%node_name, "Ignoring non-binary audio packet"); + stats.discarded(); + return Ok(()); + }; + + stats.received(); + + // Send AAC sequence header (AudioSpecificConfig) on first audio packet. + if !*seq_header_sent { + let asc = build_aac_audio_specific_config(sample_rate, channels); + + let seq_frame = RtmpAudioFrame { + timestamp: RtmpTimestamp::from_millis(timestamp_ms), + format: RtmpAudioFormat::Aac, + sample_rate: RtmpAudioFrame::AAC_SAMPLE_RATE, + is_stereo: RtmpAudioFrame::AAC_STEREO, + is_8bit_sample: false, + is_aac_sequence_header: true, + data: asc, + }; + + connection.send_audio(&seq_frame).map_err(|e| { + StreamKitError::Runtime(format!("Failed to send AAC sequence header: {e}")) + })?; + + tracing::info!(%node_name, "Sent AAC sequence header (AudioSpecificConfig)"); + *seq_header_sent = true; + } + + // Send the raw AAC frame. 
+ let frame = RtmpAudioFrame { + timestamp: RtmpTimestamp::from_millis(timestamp_ms), + format: RtmpAudioFormat::Aac, + sample_rate: RtmpAudioFrame::AAC_SAMPLE_RATE, + is_stereo: RtmpAudioFrame::AAC_STEREO, + is_8bit_sample: false, + is_aac_sequence_header: false, + data: data.to_vec(), + }; + + connection + .send_audio(&frame) + .map_err(|e| StreamKitError::Runtime(format!("Failed to send audio frame: {e}")))?; + + *packet_count += 1; + counter.add(1, labels); + stats.sent(); + + // `% N == 0` instead of `.is_multiple_of(N)` for MSRV < 1.85 compat. + #[allow(clippy::manual_is_multiple_of)] + if *packet_count <= 5 || *packet_count % 200 == 0 { + tracing::debug!(%node_name, packet = *packet_count, %timestamp_ms, "Sent audio"); + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// H.264 Annex B → AVCC conversion +// --------------------------------------------------------------------------- +// +// These helpers mirror the logic in `containers/mp4.rs`. A shared +// `h264_utils` module could deduplicate them in a follow-up refactor. + +/// NAL unit type bitmask (lower 5 bits of NAL header byte). +const H264_NAL_TYPE_MASK: u8 = 0x1F; +/// NAL unit type: Sequence Parameter Set. +const H264_NAL_SPS: u8 = 7; +/// NAL unit type: Picture Parameter Set. +const H264_NAL_PPS: u8 = 8; + +/// Result of converting an H.264 Annex B access unit to AVCC format. +struct AvccConversion { + /// AVCC-formatted video data (4-byte length-prefixed NAL units), + /// excluding SPS/PPS parameter sets (those go in the sequence header). + video_data: Vec, + /// SPS NAL units found in this access unit. + sps_list: Vec>, + /// PPS NAL units found in this access unit. + pps_list: Vec>, +} + +/// Parse an H.264 Annex B bitstream into individual NAL unit payloads. +/// +/// NAL units are delimited by 3-byte (`00 00 01`) or 4-byte (`00 00 00 01`) +/// start codes. The returned slices exclude the start-code prefix. 
+/// +/// **Known limitation**: the 3-byte `00 00 01` pattern can theoretically +/// appear inside NAL payload data (e.g. quantized coefficient blocks). +/// Spec-compliant encoders insert emulation-prevention bytes (`00 00 03`) +/// to avoid this ambiguity, so OpenH264 output is safe. If this node +/// ever receives data from an external encoder that omits prevention +/// bytes, frames could be mis-split. This mirrors the Annex B parser in +/// `containers/mp4.rs` — extracting a shared `h264_utils` module is +/// tracked as follow-up work. +fn parse_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> { + let mut nals = Vec::new(); + let mut nal_start: Option = None; + let len = data.len(); + let mut i = 0; + + while i < len { + let sc_len = if i + 2 < len && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { + 3 + } else if i + 3 < len + && data[i] == 0 + && data[i + 1] == 0 + && data[i + 2] == 0 + && data[i + 3] == 1 + { + 4 + } else { + 0 + }; + + if sc_len > 0 { + if let Some(start) = nal_start { + if start < i { + nals.push(&data[start..i]); + } + } + i += sc_len; + nal_start = Some(i); + } else { + i += 1; + } + } + + if let Some(start) = nal_start { + if start < len { + nals.push(&data[start..len]); + } + } + + nals +} + +/// Convert an H.264 Annex B bitstream to AVCC format. +/// +/// Each NAL unit's start code is replaced with a 4-byte big-endian length +/// prefix. SPS and PPS NAL units are extracted separately so the caller +/// can build the RTMP `AvcSequenceHeader`. +fn convert_annexb_to_avcc(data: &[u8]) -> AvccConversion { + let nals = parse_annexb_nal_units(data); + let mut video_data = Vec::with_capacity(data.len()); + let mut sps_list = Vec::new(); + let mut pps_list = Vec::new(); + + for nal in nals { + if nal.is_empty() { + continue; + } + + // Classify and extract parameter sets. 
+ let nal_type = nal[0] & H264_NAL_TYPE_MASK; + if nal_type == H264_NAL_SPS { + sps_list.push(nal.to_vec()); + continue; // SPS goes in the sequence header, not the NalUnit data. + } else if nal_type == H264_NAL_PPS { + pps_list.push(nal.to_vec()); + continue; // PPS goes in the sequence header, not the NalUnit data. + } + + // 4-byte big-endian length prefix. + let len = u32::try_from(nal.len()).unwrap_or(u32::MAX); + video_data.extend_from_slice(&len.to_be_bytes()); + video_data.extend_from_slice(nal); + } + + AvccConversion { video_data, sps_list, pps_list } +} + +// --------------------------------------------------------------------------- +// AAC AudioSpecificConfig builder +// --------------------------------------------------------------------------- + +/// Build a 2-byte AAC-LC `AudioSpecificConfig` for the RTMP sequence header. +/// +/// Layout (ISO 14496-3 §1.6.2.1): +/// +/// ```text +/// 5 bits audioObjectType (2 = AAC-LC) +/// 4 bits samplingFrequencyIndex +/// 4 bits channelConfiguration +/// 3 bits GASpecificConfig (frameLengthFlag=0, dependsOnCoreCoder=0, extensionFlag=0) +/// ``` +/// +/// # Panics +/// +/// Never — `sample_rate` and `channels` are validated at node startup by +/// [`validate_aac_config`]. If an unrecognized rate somehow reaches here +/// the index defaults to 3 (48 kHz) and a warning is logged. +fn build_aac_audio_specific_config(sample_rate: u32, channels: u8) -> Vec { + // The array has 13 entries (indices 0..=12), so the position always + // fits in a u8. Fallback to index 3 (48 kHz) if not found — callers + // are expected to validate beforehand, but we avoid panicking here. + let freq_index: u8 = AAC_SAMPLE_RATES.iter().position(|&r| r == sample_rate).map_or_else( + || { + tracing::warn!(sample_rate, "Unrecognized AAC sample rate, defaulting to 48 kHz index"); + 3 + }, + |i| { + // Safe: AAC_SAMPLE_RATES has 13 entries, index ≤ 12. + // unwrap_or(3) is unreachable but avoids clippy::expect_used. 
+ u8::try_from(i).unwrap_or(3) + }, + ); + + // AAC-LC object type = 2 + let object_type: u8 = 2; + + // Pack: 5 bits objectType | 4 bits freqIndex | 4 bits channels | 3 bits zeros + let byte0 = (object_type << 3) | (freq_index >> 1); + let byte1 = (freq_index << 7) | (channels << 3); + + vec![byte0, byte1] +} + +// --------------------------------------------------------------------------- +// Node registration +// --------------------------------------------------------------------------- + +/// Registers all RTMP transport nodes with the engine's registry. +/// +/// # Panics +/// +/// Panics if `RtmpPublishConfig`'s JSON schema fails to serialize, which +/// should never happen for a valid `schemars`-derived type. +#[allow(clippy::expect_used)] // Schema serialization should never fail for valid types +pub fn register_rtmp_nodes(registry: &mut NodeRegistry) { + let default_node = RtmpPublishNode::new(RtmpPublishConfig { + url: String::new(), + stream_key: None, + stream_key_env: None, + sample_rate: default_sample_rate(), + channels: default_channels(), + }); + + registry.register_static_with_description( + "transport::rtmp::publish", + |params| { + let config = config_helpers::parse_config_required(params)?; + Ok(Box::new(RtmpPublishNode::new(config))) + }, + serde_json::to_value(schema_for!(RtmpPublishConfig)) + .expect("RtmpPublishConfig schema should serialize to JSON"), + StaticPins { inputs: default_node.input_pins(), outputs: default_node.output_pins() }, + vec!["transport".to_string(), "rtmp".to_string()], + false, + "Publishes encoded H.264 video and AAC audio to an RTMP endpoint. \ + Accepts Annex B H.264 on the 'video' pin and raw AAC frames on the 'audio' pin, \ + converting to the RTMP/FLV wire format. 
Supports both RTMP and RTMPS (TLS).", + ); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use streamkit_core::types::PacketMetadata; + + // Note: env-var tests use unique variable names per test (prefixed + // `_SK_TEST_RTMP_*`) so they are safe to run in parallel without + // `#[serial]`. If a test is added that shares a variable name, + // add the `serial_test` crate. + + #[test] + fn parse_annexb_single_nal_4byte_sc() { + let data = [0x00, 0x00, 0x00, 0x01, 0x67, 0xAA, 0xBB]; + let nals = parse_annexb_nal_units(&data); + assert_eq!(nals.len(), 1); + assert_eq!(nals[0], &[0x67, 0xAA, 0xBB]); + } + + #[test] + fn parse_annexb_single_nal_3byte_sc() { + let data = [0x00, 0x00, 0x01, 0x68, 0xCC, 0xDD]; + let nals = parse_annexb_nal_units(&data); + assert_eq!(nals.len(), 1); + assert_eq!(nals[0], &[0x68, 0xCC, 0xDD]); + } + + #[test] + fn parse_annexb_multiple_nals() { + let mut data = Vec::new(); + data.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); // SPS start code + data.extend_from_slice(&[0x67, 0x42, 0xC0, 0x1F]); // SPS NAL + data.extend_from_slice(&[0x00, 0x00, 0x01]); // PPS start code + data.extend_from_slice(&[0x68, 0xCE, 0x38, 0x80]); // PPS NAL + data.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); // IDR start code + data.extend_from_slice(&[0x65, 0x88, 0x84]); // IDR NAL + + let nals = parse_annexb_nal_units(&data); + assert_eq!(nals.len(), 3); + assert_eq!(nals[0], &[0x67, 0x42, 0xC0, 0x1F]); // SPS + assert_eq!(nals[1], &[0x68, 0xCE, 0x38, 0x80]); // PPS + assert_eq!(nals[2], &[0x65, 0x88, 0x84]); // IDR + } + + #[test] + fn parse_annexb_empty_input() { + let nals = parse_annexb_nal_units(&[]); + assert!(nals.is_empty()); + } + + #[test] + fn convert_annexb_extracts_sps_pps() { + let mut annexb = Vec::new(); + annexb.extend_from_slice(&[0x00, 0x00, 
0x00, 0x01]); + let sps = [0x67, 0x42, 0xC0, 0x1F]; // SPS NAL (type 7) + annexb.extend_from_slice(&sps); + annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); + let pps = [0x68, 0xCE, 0x38, 0x80]; // PPS NAL (type 8) + annexb.extend_from_slice(&pps); + annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); + let idr = [0x65, 0x88, 0x84]; // IDR NAL (type 5) + annexb.extend_from_slice(&idr); + + let result = convert_annexb_to_avcc(&annexb); + + assert_eq!(result.sps_list.len(), 1); + assert_eq!(result.pps_list.len(), 1); + assert_eq!(result.sps_list[0], sps.to_vec()); + assert_eq!(result.pps_list[0], pps.to_vec()); + + // Verify AVCC video_data contains only the IDR NAL (SPS/PPS excluded). + let avcc = &result.video_data; + let len = u32::from_be_bytes([avcc[0], avcc[1], avcc[2], avcc[3]]) as usize; + assert_eq!(len, idr.len()); + assert_eq!(&avcc[4..4 + len], &idr[..]); + assert_eq!(avcc.len(), 4 + idr.len()); + } + + #[test] + fn aac_audio_specific_config_48khz_stereo() { + let asc = build_aac_audio_specific_config(48_000, 2); + assert_eq!(asc.len(), 2); + // AAC-LC=2 (00010), freqIdx=3 (0011), channels=2 (0010), GASpec=000 + // 00010 0011 0010 000 = 0x11 0x90 + assert_eq!(asc[0], 0x11); + assert_eq!(asc[1], 0x90); + } + + #[test] + fn aac_audio_specific_config_44100_mono() { + let asc = build_aac_audio_specific_config(44_100, 1); + assert_eq!(asc.len(), 2); + // AAC-LC=2 (00010), freqIdx=4 (0100), channels=1 (0001), GASpec=000 + // 00010 0100 0001 000 = 0x12 0x08 + assert_eq!(asc[0], 0x12); + assert_eq!(asc[1], 0x08); + } + + // ── AAC config validation tests ───────────────────────────────────── + + #[test] + fn validate_aac_config_valid() { + let cfg = RtmpPublishConfig { + url: String::new(), + stream_key: None, + stream_key_env: None, + sample_rate: 48_000, + channels: 2, + }; + assert!(validate_aac_config(&cfg).is_ok()); + } + + #[test] + fn validate_aac_config_channels_zero_rejected() { + let cfg = RtmpPublishConfig { + url: String::new(), + stream_key: 
None, + stream_key_env: None, + sample_rate: 48_000, + channels: 0, + }; + let err = validate_aac_config(&cfg).unwrap_err(); + assert!(err.contains("channels"), "{err}"); + } + + #[test] + fn validate_aac_config_channels_overflow_rejected() { + let cfg = RtmpPublishConfig { + url: String::new(), + stream_key: None, + stream_key_env: None, + sample_rate: 48_000, + channels: 8, + }; + let err = validate_aac_config(&cfg).unwrap_err(); + assert!(err.contains("channels"), "{err}"); + } + + #[test] + fn validate_aac_config_invalid_sample_rate_rejected() { + let cfg = RtmpPublishConfig { + url: String::new(), + stream_key: None, + stream_key_env: None, + sample_rate: 22_000, + channels: 2, + }; + let err = validate_aac_config(&cfg).unwrap_err(); + assert!(err.contains("sample_rate"), "{err}"); + } + + #[test] + fn mask_stream_key_hides_key() { + let url = "rtmp://a.rtmp.youtube.com/live2/xxxx-xxxx-xxxx-xxxx"; + let masked = mask_stream_key(url); + assert_eq!(masked, "rtmp://a.rtmp.youtube.com/live2/"); + assert!(!masked.contains("xxxx")); + } + + #[test] + fn mask_stream_key_bare_url_not_over_redacted() { + // When no stream key is embedded, the app name should remain visible. + let url = "rtmp://a.rtmp.youtube.com/live2"; + let masked = mask_stream_key(url); + assert_eq!(masked, url, "bare URL without key should not be redacted"); + } + + #[test] + fn mask_stream_key_no_scheme() { + let masked = mask_stream_key("no-scheme-at-all"); + assert_eq!(masked, ""); + } + + #[test] + fn convert_annexb_sps_pps_not_in_video_data() { + // Regression test: SPS/PPS NALUs must NOT appear in the AVCC video_data + // field — they belong only in the AVC sequence header. 
+ let mut annexb = Vec::new(); + // SPS + annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0xC0, 0x1F]); + // PPS + annexb.extend_from_slice(&[0x00, 0x00, 0x01, 0x68, 0xCE, 0x38, 0x80]); + // IDR slice + annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x65, 0x11, 0x22]); + + let result = convert_annexb_to_avcc(&annexb); + + // SPS/PPS should be extracted. + assert_eq!(result.sps_list.len(), 1); + assert_eq!(result.pps_list.len(), 1); + + // video_data should contain only the IDR NAL, not SPS/PPS. + // Verify no NAL in video_data has type 7 (SPS) or 8 (PPS). + let avcc = &result.video_data; + let mut offset = 0; + while offset + 4 <= avcc.len() { + let len = u32::from_be_bytes([ + avcc[offset], + avcc[offset + 1], + avcc[offset + 2], + avcc[offset + 3], + ]) as usize; + offset += 4; + assert!(offset + len <= avcc.len(), "AVCC data truncated"); + let nal_type = avcc[offset] & H264_NAL_TYPE_MASK; + assert_ne!(nal_type, H264_NAL_SPS, "SPS should not be in video_data"); + assert_ne!(nal_type, H264_NAL_PPS, "PPS should not be in video_data"); + offset += len; + } + } + + // ── resolve_rtmp_url tests ────────────────────────────────────────── + + fn make_config(url: &str, key: Option<&str>, key_env: Option<&str>) -> RtmpPublishConfig { + RtmpPublishConfig { + url: url.to_string(), + stream_key: key.map(String::from), + stream_key_env: key_env.map(String::from), + sample_rate: default_sample_rate(), + channels: default_channels(), + } + } + + #[test] + fn resolve_url_no_key_uses_url_as_is() { + let cfg = make_config("rtmp://host/app/inline_key", None, None); + assert_eq!(resolve_rtmp_url(&cfg).unwrap(), "rtmp://host/app/inline_key"); + } + + #[test] + fn resolve_url_with_stream_key() { + let cfg = make_config("rtmp://a.rtmp.youtube.com/live2", Some("my-key"), None); + assert_eq!(resolve_rtmp_url(&cfg).unwrap(), "rtmp://a.rtmp.youtube.com/live2/my-key"); + } + + #[test] + fn resolve_url_strips_trailing_slash() { + let cfg = make_config("rtmp://host/app/", 
Some("key"), None); + assert_eq!(resolve_rtmp_url(&cfg).unwrap(), "rtmp://host/app/key"); + } + + #[test] + fn resolve_url_env_takes_precedence() { + let cfg = make_config("rtmp://host/app", Some("literal-key"), Some("MY_KEY")); + let result = resolve_rtmp_url_with_env(&cfg, |name| { + assert_eq!(name, "MY_KEY"); + Ok("env-key".to_string()) + }) + .unwrap(); + assert_eq!(result, "rtmp://host/app/env-key"); + } + + #[test] + fn resolve_url_env_var_set() { + let cfg = make_config("rtmp://host/app", None, Some("MY_KEY")); + let result = resolve_rtmp_url_with_env(&cfg, |_| Ok("secret123".to_string())).unwrap(); + assert_eq!(result, "rtmp://host/app/secret123"); + } + + #[test] + fn resolve_url_env_var_not_set() { + let cfg = make_config("rtmp://host/app", None, Some("MISSING")); + let err = + resolve_rtmp_url_with_env(&cfg, |_| Err(std::env::VarError::NotPresent)).unwrap_err(); + assert!(err.contains("not set"), "error should mention 'not set': {err}"); + } + + #[test] + fn resolve_url_env_var_empty() { + let cfg = make_config("rtmp://host/app", None, Some("MY_KEY")); + let err = resolve_rtmp_url_with_env(&cfg, |_| Ok(String::new())).unwrap_err(); + assert!(err.contains("empty"), "error should mention 'empty': {err}"); + } + + // ── RtmpTimestampState rebase tests ─────────────────────────────── + + /// Helper: build a `Packet::Binary` with a given `timestamp_us`. 
+ fn make_packet(timestamp_us: Option) -> Packet { + Packet::Binary { + data: bytes::Bytes::from_static(&[0]), + metadata: timestamp_us.map(|ts| PacketMetadata { + timestamp_us: Some(ts), + duration_us: None, + sequence: None, + keyframe: None, + }), + content_type: None, + } + } + + #[test] + fn rebase_first_video_starts_at_zero() { + let mut state = RtmpTimestampState::new(); + let pkt = make_packet(Some(0)); + let ts = state.stamp(&pkt, Track::Video, "test"); + assert_eq!(ts, 0); + } + + #[test] + fn rebase_video_preserves_cadence() { + let mut state = RtmpTimestampState::new(); + let ts0 = state.stamp(&make_packet(Some(0)), Track::Video, "test"); + let ts1 = state.stamp(&make_packet(Some(33_000)), Track::Video, "test"); + let ts2 = state.stamp(&make_packet(Some(66_000)), Track::Video, "test"); + assert_eq!(ts0, 0); + assert_eq!(ts1, 33); + assert_eq!(ts2, 66); + } + + #[test] + fn rebase_late_audio_aligns_to_video() { + // Video has been running for 3 seconds. + let mut state = RtmpTimestampState::new(); + for i in 0..90 { + // 30fps video for 3 seconds (90 frames). + state.stamp(&make_packet(Some(i * 33_333)), Track::Video, "test"); + } + // 89 * 33_333us = 2_966_637us → global_last_ms ≈ 2966. + // Audio arrives with source_ts=0 (MoQ normalized). It should + // start at the current global position. + let audio_ts0 = state.stamp(&make_packet(Some(0)), Track::Audio, "test"); + let audio_ts1 = state.stamp(&make_packet(Some(20_000)), Track::Audio, "test"); + // Audio should start near video's current position (~2966ms). + assert!( + (2900..=3100).contains(&audio_ts0), + "audio should start near video position, got {audio_ts0}" + ); + // Cadence preserved: 20ms between audio frames. + assert_eq!(audio_ts1 - audio_ts0, 20); + } + + #[test] + fn rebase_backward_jump_resets_offset() { + // Simulate compositor calibration: video starts at running clock + // ts=0, then after calibration jumps backward to MoQ origin. 
+ let mut state = RtmpTimestampState::new(); + + // Pre-calibration: compositor running clock 0..~4000ms. + for i in 0..120 { + state.stamp(&make_packet(Some(i * 33_333)), Track::Video, "test"); + } + // 119 * 33_333us = 3_966_627us → global_last_ms ≈ 3966. + let global_before = state.global_last_ms; + + // Post-calibration: compositor jumps to MoQ timestamp ~100ms + // (a large backward jump). + let ts = state.stamp(&make_packet(Some(100_000)), Track::Video, "test"); + // Should have reset and re-aligned near the global position. + assert!( + ts >= global_before, + "after rebase reset, ts ({ts}) should be >= global_before ({global_before})" + ); + } + + #[test] + fn rebase_monotonicity_enforced() { + let mut state = RtmpTimestampState::new(); + // First packet at 0ms to establish the offset. + let _ = state.stamp(&make_packet(Some(0)), Track::Video, "test"); + let ts0 = state.stamp(&make_packet(Some(100_000)), Track::Video, "test"); + // Small backward jitter (< 500ms threshold). + let ts1 = state.stamp(&make_packet(Some(99_000)), Track::Video, "test"); + assert!(ts1 > ts0, "timestamps must be monotonically increasing: ts0={ts0}, ts1={ts1}"); + } + + #[test] + fn convert_annexb_sps_pps_only_yields_empty_video_data() { + // An access unit containing only SPS+PPS (no slice NALUs) should + // produce empty video_data so the caller can skip the NalUnit frame. 
+ let mut annexb = Vec::new(); + annexb.extend_from_slice(&[0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0xC0, 0x1F]); // SPS + annexb.extend_from_slice(&[0x00, 0x00, 0x01, 0x68, 0xCE, 0x38, 0x80]); // PPS + + let result = convert_annexb_to_avcc(&annexb); + + assert_eq!(result.sps_list.len(), 1); + assert_eq!(result.pps_list.len(), 1); + assert!( + result.video_data.is_empty(), + "video_data should be empty for SPS/PPS-only access units" + ); + } +} diff --git a/crates/nodes/src/transport/rtmp_client.rs b/crates/nodes/src/transport/rtmp_client.rs new file mode 100644 index 00000000..da1de8d3 --- /dev/null +++ b/crates/nodes/src/transport/rtmp_client.rs @@ -0,0 +1,2464 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Minimal sans-I/O RTMP publish client. +//! +//! Implements just enough of the RTMP protocol to connect to an RTMP/RTMPS +//! server and publish H.264 video + AAC audio. No server-side handling, +//! play/subscribe, or AMF3 support. +//! +//! This module replaces the external `shiguredo_rtmp` crate, fixing two +//! spec-compliance issues: +//! +//! 1. **Chunk stream ID assignment** — protocol control on csid 2, commands +//! and media on csid 3+ (the old library used csid 2 for everything, +//! which Twitch rejects). +//! 2. **Server-assigned stream ID** — the `createStream` response's stream +//! ID is stored and used for publish/media (the old library hardcoded 2, +//! but Twitch assigns 1). +//! +//! Additionally, the client does **not** enforce ACK windows on the send +//! side (matching OBS/FFmpeg behaviour), eliminating the need for the +//! `override_ack_window` hack. + +use std::collections::{HashMap, VecDeque}; +use std::fmt; + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +/// Error type for the RTMP client module. 
+#[derive(Debug)] +pub(super) struct Error { + message: String, +} + +impl Error { + fn new(msg: impl Into) -> Self { + Self { message: msg.into() } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for Error {} + +// --------------------------------------------------------------------------- +// RtmpUrl +// --------------------------------------------------------------------------- + +/// Parsed RTMP URL. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RtmpUrl { + pub host: String, + pub port: u16, + pub app: String, + pub stream_name: String, + pub tls: bool, +} + +impl RtmpUrl { + /// Parse `rtmp[s]://host[:port]/app[/extra_segments]/stream_name`. + /// + /// The path is split on the **last** `/` into `app` and `stream_name`. + /// Default ports: 1935 (rtmp), 443 (rtmps). + pub fn parse(s: &str) -> Result { + let (tls, rest) = if let Some(r) = s.strip_prefix("rtmps://") { + (true, r) + } else if let Some(r) = s.strip_prefix("rtmp://") { + (false, r) + } else { + return Err(Error::new("URL must start with rtmp:// or rtmps://")); + }; + + let default_port: u16 = if tls { 443 } else { 1935 }; + + // Split host[:port] from /path. + let (authority, path) = rest.find('/').map_or((rest, ""), |i| (&rest[..i], &rest[i + 1..])); + + let (host, port) = if let Some(colon) = authority.rfind(':') { + let port_str = &authority[colon + 1..]; + let port = port_str + .parse::() + .map_err(|_| Error::new(format!("Invalid port: {port_str}")))?; + (authority[..colon].to_string(), port) + } else { + (authority.to_string(), default_port) + }; + + if host.is_empty() { + return Err(Error::new("Empty host")); + } + + // Split path on last `/` into app and stream_name. 
+ let (app, stream_name) = + path.rfind('/').map_or(("", path), |i| (&path[..i], &path[i + 1..])); + + // The "app" is everything before the last segment; if there's only + // one segment it becomes the stream_name and app is the whole path + // portion before the stream_name (which would be empty). But RTMP + // requires both, so we handle the single-segment case: the single + // segment is the app with an empty stream_name. + if app.is_empty() && !stream_name.is_empty() { + // Single path segment: treat it as app, stream_name empty. + // The caller (rtmp.rs) appends the stream key separately. + return Ok(Self { + host, + port, + app: stream_name.to_string(), + stream_name: String::new(), + tls, + }); + } + + if app.is_empty() { + return Err(Error::new("Empty app name in RTMP URL")); + } + + Ok(Self { host, port, app: app.to_string(), stream_name: stream_name.to_string(), tls }) + } + + /// Build the `tcUrl` for the RTMP connect command. + /// + /// Format: `rtmp[s]://host/app` — deliberately omits the default port + /// because Twitch returns a degraded response when the port is included. 
+ fn tc_url(&self) -> String { + let scheme = if self.tls { "rtmps" } else { "rtmp" }; + let default_port = if self.tls { 443 } else { 1935 }; + if self.port == default_port { + format!("{scheme}://{}/{}", self.host, self.app) + } else { + format!("{scheme}://{}:{}/{}", self.host, self.port, self.app) + } + } +} + +impl std::str::FromStr for RtmpUrl { + type Err = Error; + fn from_str(s: &str) -> Result { + Self::parse(s) + } +} + +impl fmt::Display for RtmpUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let scheme = if self.tls { "rtmps" } else { "rtmp" }; + write!(f, "{scheme}://{}:{}/{}/{}", self.host, self.port, self.app, self.stream_name) + } +} + +// --------------------------------------------------------------------------- +// RtmpTimestamp / RtmpTimestampDelta +// --------------------------------------------------------------------------- + +/// RTMP timestamp (milliseconds, u32). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) struct RtmpTimestamp(u32); + +impl RtmpTimestamp { + pub const fn from_millis(ms: u32) -> Self { + Self(ms) + } + pub const fn millis(self) -> u32 { + self.0 + } +} + +/// RTMP timestamp delta (milliseconds, i32). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) struct RtmpTimestampDelta(i32); + +impl RtmpTimestampDelta { + pub const ZERO: Self = Self(0); + pub const fn millis(self) -> i32 { + self.0 + } +} + +// --------------------------------------------------------------------------- +// Media types (public API for the rtmp.rs node) +// --------------------------------------------------------------------------- + +/// Video frame type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum VideoFrameType { + KeyFrame, + InterFrame, +} + +/// Video codec identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum VideoCodec { + Avc, +} + +/// AVC packet type (H.264). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum AvcPacketType { + SequenceHeader, + NalUnit, +} + +/// Encoded video frame for RTMP publishing. +pub(super) struct VideoFrame { + pub timestamp: RtmpTimestamp, + pub composition_timestamp_offset: RtmpTimestampDelta, + pub frame_type: VideoFrameType, + pub codec: VideoCodec, + pub avc_packet_type: Option, + pub data: Vec, +} + +/// Audio format identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum AudioFormat { + Aac, +} + +/// Audio sample rate (FLV header field). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum AudioSampleRate { + Khz44, +} + +/// Encoded audio frame for RTMP publishing. +pub(super) struct AudioFrame { + pub timestamp: RtmpTimestamp, + pub format: AudioFormat, + pub sample_rate: AudioSampleRate, + pub is_8bit_sample: bool, + pub is_stereo: bool, + pub is_aac_sequence_header: bool, + pub data: Vec, +} + +impl AudioFrame { + /// FLV-spec fixed sample rate for AAC (value ignored by decoder). + pub const AAC_SAMPLE_RATE: AudioSampleRate = AudioSampleRate::Khz44; + /// FLV-spec fixed stereo flag for AAC (value ignored by decoder). + pub const AAC_STEREO: bool = true; +} + +/// AVC Sequence Header (`AVCDecoderConfigurationRecord`). +pub(super) struct AvcSequenceHeader { + pub avc_profile_indication: u8, + pub profile_compatibility: u8, + pub avc_level_indication: u8, + pub length_size_minus_one: u8, + pub sps_list: Vec>, + pub pps_list: Vec>, +} + +impl AvcSequenceHeader { + /// Serialize to `AVCDecoderConfigurationRecord` bytes. 
+ #[allow(clippy::cast_possible_truncation)] + pub fn to_bytes(&self) -> Result, Error> { + if self.sps_list.is_empty() { + return Err(Error::new("AvcSequenceHeader: no SPS")); + } + if self.pps_list.is_empty() { + return Err(Error::new("AvcSequenceHeader: no PPS")); + } + if self.sps_list.len() > 31 { + return Err(Error::new("AvcSequenceHeader: too many SPS (max 31)")); + } + if self.pps_list.len() > 255 { + return Err(Error::new("AvcSequenceHeader: too many PPS (max 255)")); + } + + let mut buf = Vec::with_capacity(64); + // configurationVersion = 1 + buf.push(1); + buf.push(self.avc_profile_indication); + buf.push(self.profile_compatibility); + buf.push(self.avc_level_indication); + // lengthSizeMinusOne (6 bits reserved=0b111111 | 2 bits) + buf.push(0xFC | (self.length_size_minus_one & 0x03)); + // numOfSequenceParameterSets (3 bits reserved=0b111 | 5 bits count) + buf.push(0xE0 | (self.sps_list.len() as u8 & 0x1F)); + for sps in &self.sps_list { + let len = + u16::try_from(sps.len()).map_err(|_| Error::new("SPS too large for u16 length"))?; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(sps); + } + // numOfPictureParameterSets + buf.push(self.pps_list.len() as u8); + for pps in &self.pps_list { + let len = + u16::try_from(pps.len()).map_err(|_| Error::new("PPS too large for u16 length"))?; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(pps); + } + Ok(buf) + } +} + +// --------------------------------------------------------------------------- +// AMF0 codec (subset) +// --------------------------------------------------------------------------- + +/// AMF0 value — only the types needed for RTMP publish commands. +#[derive(Debug, Clone, PartialEq)] +enum Amf0Value { + Number(f64), + Boolean(bool), + String(String), + Object(Vec<(String, Self)>), + Null, +} + +// AMF0 type markers. 
+const AMF0_NUMBER: u8 = 0x00; +const AMF0_BOOLEAN: u8 = 0x01; +const AMF0_STRING: u8 = 0x02; +const AMF0_OBJECT: u8 = 0x03; +const AMF0_NULL: u8 = 0x05; +const AMF0_OBJECT_END: [u8; 3] = [0x00, 0x00, 0x09]; + +/// Encode an AMF0 value, appending bytes to `buf`. +fn amf0_encode(val: &Amf0Value, buf: &mut Vec) -> Result<(), Error> { + match val { + Amf0Value::Number(n) => { + buf.push(AMF0_NUMBER); + buf.extend_from_slice(&n.to_be_bytes()); + }, + Amf0Value::Boolean(b) => { + buf.push(AMF0_BOOLEAN); + buf.push(u8::from(*b)); + }, + Amf0Value::String(s) => { + buf.push(AMF0_STRING); + amf0_encode_string_payload(s, buf)?; + }, + Amf0Value::Object(props) => { + buf.push(AMF0_OBJECT); + for (key, val) in props { + amf0_encode_string_payload(key, buf)?; + amf0_encode(val, buf)?; + } + buf.extend_from_slice(&AMF0_OBJECT_END); + }, + Amf0Value::Null => { + buf.push(AMF0_NULL); + }, + } + Ok(()) +} + +/// Encode an AMF0 string payload (u16 length + UTF-8, no type marker). +fn amf0_encode_string_payload(s: &str, buf: &mut Vec) -> Result<(), Error> { + let len = u16::try_from(s.len()) + .map_err(|_| Error::new(format!("AMF0 string too long ({} bytes, max 65535)", s.len())))?; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(s.as_bytes()); + Ok(()) +} + +/// Decode one AMF0 value from a byte slice. +/// +/// Returns `(value, bytes_consumed)` or an error. 
fn amf0_decode(data: &[u8]) -> Result<(Amf0Value, usize), Error> {
    if data.is_empty() {
        return Err(Error::new("AMF0: unexpected end of data"));
    }

    // First byte is the type marker; the payload follows immediately.
    let marker = data[0];
    let rest = &data[1..];

    match marker {
        AMF0_NUMBER => {
            if rest.len() < 8 {
                return Err(Error::new("AMF0 Number: need 8 bytes"));
            }
            // IEEE-754 double, big-endian.
            let n = f64::from_be_bytes(
                rest[..8]
                    .try_into()
                    .map_err(|_| Error::new("AMF0 Number: slice conversion failed"))?,
            );
            // 1 marker byte + 8 payload bytes consumed.
            Ok((Amf0Value::Number(n), 9))
        },
        AMF0_BOOLEAN => {
            if rest.is_empty() {
                return Err(Error::new("AMF0 Boolean: need 1 byte"));
            }
            // Any non-zero byte is `true`.
            Ok((Amf0Value::Boolean(rest[0] != 0), 2))
        },
        AMF0_STRING => {
            let (s, consumed) = amf0_decode_string_payload(rest)?;
            // +1 accounts for the marker byte not seen by the payload decoder.
            Ok((Amf0Value::String(s), 1 + consumed))
        },
        AMF0_OBJECT => {
            let mut props = Vec::new();
            let mut offset = 1; // past the marker
            loop {
                if data.len() < offset + 3 {
                    return Err(Error::new("AMF0 Object: unexpected end"));
                }
                // Check for object-end marker (00 00 09).
                if data[offset] == 0 && data[offset + 1] == 0 && data[offset + 2] == 0x09 {
                    offset += 3;
                    break;
                }
                // Each property is a bare string key followed by a full value
                // (decoded recursively — nested objects are supported).
                let (key, key_consumed) = amf0_decode_string_payload(&data[offset..])?;
                offset += key_consumed;
                let (val, val_consumed) = amf0_decode(&data[offset..])?;
                offset += val_consumed;
                props.push((key, val));
            }
            // `offset` already includes the marker and end sentinel.
            Ok((Amf0Value::Object(props), offset))
        },
        AMF0_NULL => Ok((Amf0Value::Null, 1)),
        _ => Err(Error::new(format!("AMF0: unsupported type marker 0x{marker:02X}"))),
    }
}

/// Decode an AMF0 string payload (u16 length + UTF-8, no type marker).
+fn amf0_decode_string_payload(data: &[u8]) -> Result<(String, usize), Error> { + if data.len() < 2 { + return Err(Error::new("AMF0 string: need 2 bytes for length")); + } + let len = u16::from_be_bytes([data[0], data[1]]) as usize; + if data.len() < 2 + len { + return Err(Error::new("AMF0 string: truncated")); + } + let s = std::str::from_utf8(&data[2..2 + len]) + .map_err(|e| Error::new(format!("AMF0 string: invalid UTF-8: {e}")))? + .to_string(); + Ok((s, 2 + len)) +} + +// --------------------------------------------------------------------------- +// RTMP Messages +// --------------------------------------------------------------------------- + +/// A fully decoded inbound RTMP message. +struct InboundMessage { + #[cfg(test)] + timestamp: u32, + msg_type_id: u8, + #[cfg(test)] + stream_id: u32, + payload: Vec, +} + +/// An outbound RTMP message to be chunk-encoded. +struct OutboundMessage { + csid: u16, + timestamp: u32, + msg_type_id: u8, + stream_id: u32, + payload: Vec, +} + +// RTMP message type IDs. +const MSG_SET_CHUNK_SIZE: u8 = 1; +const MSG_ABORT: u8 = 2; +const MSG_ACK: u8 = 3; +const MSG_USER_CONTROL: u8 = 4; +const MSG_WIN_ACK_SIZE: u8 = 5; +const MSG_SET_PEER_BANDWIDTH: u8 = 6; +const MSG_AUDIO: u8 = 8; +const MSG_VIDEO: u8 = 9; +const MSG_COMMAND_AMF0: u8 = 20; + +// User control event types. +const UC_STREAM_BEGIN: u16 = 0; +const UC_STREAM_EOF: u16 = 1; +const UC_PING_REQUEST: u16 = 6; + +// Chunk stream IDs (RTMP spec-compliant assignment). +const CSID_PROTOCOL_CONTROL: u16 = 2; +const CSID_COMMAND: u16 = 3; + +/// Chunk stream ID for commands/media on a given message stream. +/// Stream 0 uses csid=3, stream N uses csid=3+N. +#[allow(clippy::cast_possible_truncation)] +fn csid_for_stream(stream_id: u32) -> u16 { + // Clamp to avoid overflow — in practice stream IDs are small. 
+ CSID_COMMAND + (stream_id.min(u32::from(u16::MAX) - u32::from(CSID_COMMAND)) as u16) +} + +// --------------------------------------------------------------------------- +// Chunk Encoder +// --------------------------------------------------------------------------- + +/// Per-csid state for outbound header compression. +#[derive(Default)] +struct ChunkEncoderCsidState { + prev_timestamp: u32, + prev_msg_length: u32, + prev_msg_type_id: u8, + prev_stream_id: u32, + prev_timestamp_delta: u32, + initialized: bool, +} + +/// Encodes RTMP messages into chunked wire format. +struct ChunkEncoder { + chunk_size: u32, + csid_states: HashMap, +} + +impl ChunkEncoder { + fn new() -> Self { + Self { chunk_size: 128, csid_states: HashMap::new() } + } + + const fn set_chunk_size(&mut self, size: u32) { + self.chunk_size = size; + } + + /// Encode a complete RTMP message into chunks, appending to `out`. + #[allow(clippy::cast_possible_truncation)] + fn encode_message(&mut self, msg: &OutboundMessage, out: &mut Vec) { + let payload_len = msg.payload.len() as u32; + let state = self.csid_states.entry(msg.csid).or_default(); + + // Determine fmt and compute the timestamp / delta. + let (fmt, timestamp_field) = if !state.initialized || msg.stream_id != state.prev_stream_id + { + // fmt=0: full header. + (0u8, msg.timestamp) + } else { + let delta = msg.timestamp.wrapping_sub(state.prev_timestamp); + if payload_len == state.prev_msg_length && msg.msg_type_id == state.prev_msg_type_id { + if delta == state.prev_timestamp_delta { + // fmt=3: all fields match including delta. + (3u8, delta) + } else { + // fmt=2: only timestamp delta differs. + (2u8, delta) + } + } else { + // fmt=1: stream_id matches, but length/type differ. + (1u8, delta) + } + }; + + // Update state. 
+ if fmt == 0 || fmt == 1 { + state.prev_timestamp_delta = if fmt == 0 { msg.timestamp } else { timestamp_field }; + } else if fmt == 2 { + state.prev_timestamp_delta = timestamp_field; + } + state.prev_timestamp = msg.timestamp; + state.prev_msg_length = payload_len; + state.prev_msg_type_id = msg.msg_type_id; + state.prev_stream_id = msg.stream_id; + state.initialized = true; + + let extended = timestamp_field >= 0x00FF_FFFF; + let ts_wire = if extended { 0x00FF_FFFFu32 } else { timestamp_field }; + + // Write the first chunk header. + encode_basic_header(fmt, msg.csid, out); + encode_message_header(fmt, ts_wire, payload_len, msg.msg_type_id, msg.stream_id, out); + if extended { + out.extend_from_slice(×tamp_field.to_be_bytes()); + } + + // Write payload, splitting at chunk_size boundaries. + let chunk_size = self.chunk_size as usize; + let payload = &msg.payload; + let first_chunk = payload.len().min(chunk_size); + out.extend_from_slice(&payload[..first_chunk]); + + let mut offset = first_chunk; + while offset < payload.len() { + // Continuation chunk: fmt=3 header. + encode_basic_header(3, msg.csid, out); + if extended { + out.extend_from_slice(×tamp_field.to_be_bytes()); + } + let end = (offset + chunk_size).min(payload.len()); + out.extend_from_slice(&payload[offset..end]); + offset = end; + } + } +} + +/// Encode the basic header (fmt + csid). +fn encode_basic_header(fmt: u8, csid: u16, out: &mut Vec) { + let fmt_bits = fmt << 6; + if csid < 64 { + #[allow(clippy::cast_possible_truncation)] + out.push(fmt_bits | (csid as u8)); + } else if csid < 320 { + out.push(fmt_bits); // csid field = 0 → 2-byte form + #[allow(clippy::cast_possible_truncation)] + out.push((csid - 64) as u8); + } else { + out.push(fmt_bits | 1); // csid field = 1 → 3-byte form + let val = csid - 64; + #[allow(clippy::cast_possible_truncation)] + { + out.push(val as u8); + out.push((val >> 8) as u8); + } + } +} + +/// Encode the message header portion based on fmt. 
+fn encode_message_header( + fmt: u8, + ts_wire: u32, + msg_length: u32, + msg_type_id: u8, + stream_id: u32, + out: &mut Vec, +) { + match fmt { + 0 => { + // 11 bytes: timestamp(3) + msg_length(3) + msg_type_id(1) + stream_id(4 LE) + out.extend_from_slice(&ts_wire.to_be_bytes()[1..4]); // 3 bytes + out.extend_from_slice(&msg_length.to_be_bytes()[1..4]); // 3 bytes + out.push(msg_type_id); + out.extend_from_slice(&stream_id.to_le_bytes()); // 4 bytes LE + }, + 1 => { + // 7 bytes: timestamp_delta(3) + msg_length(3) + msg_type_id(1) + out.extend_from_slice(&ts_wire.to_be_bytes()[1..4]); + out.extend_from_slice(&msg_length.to_be_bytes()[1..4]); + out.push(msg_type_id); + }, + 2 => { + // 3 bytes: timestamp_delta(3) + out.extend_from_slice(&ts_wire.to_be_bytes()[1..4]); + }, + // fmt=3 and any other value: 0 bytes. + _ => {}, + } +} + +// --------------------------------------------------------------------------- +// Chunk Decoder +// --------------------------------------------------------------------------- + +/// Per-csid state for inbound chunk reassembly. +#[derive(Default, Clone)] +struct ChunkDecoderCsidState { + timestamp: u32, + msg_length: u32, + msg_type_id: u8, + stream_id: u32, + timestamp_delta: u32, + payload: Vec, + bytes_remaining: u32, + has_prev: bool, +} + +/// Decodes the chunked wire format into complete RTMP messages. +struct ChunkDecoder { + chunk_size: u32, + csid_states: HashMap, + buf: Vec, +} + +impl ChunkDecoder { + fn new() -> Self { + Self { chunk_size: 128, csid_states: HashMap::new(), buf: Vec::with_capacity(8192) } + } + + const fn set_chunk_size(&mut self, size: u32) { + self.chunk_size = size; + } + + fn push(&mut self, data: &[u8]) { + self.buf.extend_from_slice(data); + } + + /// Try to decode the next complete message from the buffer. + /// + /// Internally loops over continuation chunks so that a multi-chunk + /// message whose chunks are all present in the buffer is fully + /// reassembled in a single call. 
Returns `Ok(None)` only when the + /// buffer is empty or contains an incomplete chunk. + #[allow(clippy::cast_possible_truncation)] + fn decode_message(&mut self) -> Result, Error> { + loop { + if self.buf.is_empty() { + return Ok(None); + } + + let mut pos = 0; + + // ── Basic header ──────────────────────────────────────── + if pos >= self.buf.len() { + return Ok(None); + } + let first_byte = self.buf[pos]; + pos += 1; + let fmt = first_byte >> 6; + let csid_low = first_byte & 0x3F; + + let csid: u16 = match csid_low { + 0 => { + // 2-byte form. + if pos >= self.buf.len() { + return Ok(None); + } + let c = u16::from(self.buf[pos]) + 64; + pos += 1; + c + }, + 1 => { + // 3-byte form. + if pos + 1 >= self.buf.len() { + return Ok(None); + } + let c = u16::from(self.buf[pos]) + u16::from(self.buf[pos + 1]) * 256 + 64; + pos += 2; + c + }, + _ => u16::from(csid_low), + }; + + // ── Message header ────────────────────────────────────── + let header_len: usize = match fmt { + 0 => 11, + 1 => 7, + 2 => 3, + 3 => 0, + _ => return Err(Error::new(format!("Invalid chunk fmt: {fmt}"))), + }; + + if pos + header_len > self.buf.len() { + return Ok(None); // need more data + } + + let state = self.csid_states.entry(csid).or_default(); + + match fmt { + 0 => { + let ts = u32::from(self.buf[pos]) << 16 + | u32::from(self.buf[pos + 1]) << 8 + | u32::from(self.buf[pos + 2]); + let ml = u32::from(self.buf[pos + 3]) << 16 + | u32::from(self.buf[pos + 4]) << 8 + | u32::from(self.buf[pos + 5]); + let mt = self.buf[pos + 6]; + let si = u32::from(self.buf[pos + 7]) + | u32::from(self.buf[pos + 8]) << 8 + | u32::from(self.buf[pos + 9]) << 16 + | u32::from(self.buf[pos + 10]) << 24; + pos += 11; + state.timestamp = ts; + state.msg_length = ml; + state.msg_type_id = mt; + state.stream_id = si; + state.timestamp_delta = ts; // for fmt=0, delta equals timestamp + }, + 1 => { + let td = u32::from(self.buf[pos]) << 16 + | u32::from(self.buf[pos + 1]) << 8 + | u32::from(self.buf[pos + 2]); + 
let ml = u32::from(self.buf[pos + 3]) << 16 + | u32::from(self.buf[pos + 4]) << 8 + | u32::from(self.buf[pos + 5]); + let mt = self.buf[pos + 6]; + pos += 7; + state.timestamp_delta = td; + if state.has_prev { + state.timestamp = state.timestamp.wrapping_add(td); + } else { + state.timestamp = td; + } + state.msg_length = ml; + state.msg_type_id = mt; + // stream_id inherited + }, + 2 => { + let td = u32::from(self.buf[pos]) << 16 + | u32::from(self.buf[pos + 1]) << 8 + | u32::from(self.buf[pos + 2]); + pos += 3; + state.timestamp_delta = td; + if state.has_prev { + state.timestamp = state.timestamp.wrapping_add(td); + } else { + state.timestamp = td; + } + // msg_length, msg_type_id, stream_id inherited + }, + 3 => { + // All inherited. Apply delta for continuation of a new message + // (not a continuation chunk of the same message). + if state.bytes_remaining == 0 && state.has_prev { + state.timestamp = state.timestamp.wrapping_add(state.timestamp_delta); + } + }, + _ => unreachable!(), + } + + // Extended timestamp. + let is_extended = if fmt == 0 { + state.timestamp == 0x00FF_FFFF + } else { + state.timestamp_delta == 0x00FF_FFFF + }; + + if is_extended { + if pos + 4 > self.buf.len() { + return Ok(None); + } + let ext = u32::from_be_bytes([ + self.buf[pos], + self.buf[pos + 1], + self.buf[pos + 2], + self.buf[pos + 3], + ]); + pos += 4; + state.timestamp = if fmt == 0 { + ext + } else { + // For fmt 1/2/3 with extended timestamp, the ext field + // replaces the delta. + state.timestamp.wrapping_sub(state.timestamp_delta).wrapping_add(ext) + }; + state.timestamp_delta = ext; + } + + // ── Payload ───────────────────────────────────────────── + // If bytes_remaining == 0, this is the first chunk of a new message. 
+ if state.bytes_remaining == 0 { + state.payload.clear(); + state.bytes_remaining = state.msg_length; + } + + let chunk_data_len = (state.bytes_remaining).min(self.chunk_size) as usize; + if pos + chunk_data_len > self.buf.len() { + return Ok(None); // need more data + } + + state.payload.extend_from_slice(&self.buf[pos..pos + chunk_data_len]); + state.bytes_remaining -= chunk_data_len as u32; + pos += chunk_data_len; + + // Consume the bytes we've processed. + self.buf.drain(..pos); + + // Check if the message is complete. + if state.bytes_remaining == 0 { + state.has_prev = true; + let msg = InboundMessage { + #[cfg(test)] + timestamp: state.timestamp, + msg_type_id: state.msg_type_id, + #[cfg(test)] + stream_id: state.stream_id, + payload: std::mem::take(&mut state.payload), + }; + return Ok(Some(msg)); + } + // Message not yet fully assembled — loop back to try the + // next continuation chunk from the buffer. + } + } +} + +// --------------------------------------------------------------------------- +// Handshake +// --------------------------------------------------------------------------- + +/// Client-side RTMP handshake state machine. +struct Handshake { + state: HandshakeState, + recv_buf: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HandshakeState { + WaitingForS0S1, + WaitingForS2, + Complete, +} + +const HANDSHAKE_SIZE: usize = 1536; + +impl Handshake { + /// Create a new handshake and return `(self, c0c1_bytes)`. + /// + /// C0 = version byte (0x03). + /// C1 = 1536 bytes: timestamp(4) + zero(4) + random(1528). + fn new() -> (Self, Vec) { + let mut c1 = vec![0u8; HANDSHAKE_SIZE]; + // Timestamp = 0 (first 4 bytes already zero). + // Version = 0 (next 4 bytes already zero). + // Random data for bytes 8..1536. 
+ fill_random(&mut c1[8..]); + + let mut c0c1 = Vec::with_capacity(1 + HANDSHAKE_SIZE); + c0c1.push(0x03); // RTMP version + c0c1.extend_from_slice(&c1); + + ( + Self { + state: HandshakeState::WaitingForS0S1, + recv_buf: Vec::with_capacity(1 + HANDSHAKE_SIZE * 2), + }, + c0c1, + ) + } + + /// Feed received bytes from the server. + /// + /// Returns: + /// - `None` — need more data. + /// - `Some((c2, leftover))` — handshake complete, send C2. `leftover` + /// contains any post-S2 bytes that arrived in the same TCP segment + /// and must be forwarded to the chunk decoder. + fn feed(&mut self, data: &[u8]) -> Option<(Vec, Vec)> { + self.recv_buf.extend_from_slice(data); + + match self.state { + HandshakeState::WaitingForS0S1 => { + // Need S0 (1 byte) + S1 (1536 bytes) = 1537 bytes. + if self.recv_buf.len() < 1 + HANDSHAKE_SIZE { + return None; + } + + let s0 = self.recv_buf[0]; + if s0 != 0x03 { + tracing::warn!(s0, "RTMP server version is not 3, continuing anyway"); + } + + // S1 is bytes 1..1537. We'll need it for C2. + // Move to waiting for S2. + self.state = HandshakeState::WaitingForS2; + + // Check if S2 is also already here. + if self.recv_buf.len() > HANDSHAKE_SIZE * 2 { + return Some(self.complete_handshake()); + } + None + }, + HandshakeState::WaitingForS2 => { + if self.recv_buf.len() <= HANDSHAKE_SIZE * 2 { + return None; + } + Some(self.complete_handshake()) + }, + HandshakeState::Complete => None, + } + } + + /// Validate S2 and produce C2, returning any leftover bytes. + fn complete_handshake(&mut self) -> (Vec, Vec) { + // C2 = echo of S1 (bytes 1..=HANDSHAKE_SIZE of recv_buf). + let s1 = &self.recv_buf[1..=HANDSHAKE_SIZE]; + let c2 = s1.to_vec(); + + // Bytes beyond S0(1) + S1(1536) + S2(1536) = 3073 are + // post-handshake protocol messages (e.g. WinAckSize, + // SetPeerBandwidth) that the server pipelined in the same + // TCP segment. Return them so the caller can forward them + // to the chunk decoder. 
+ let handshake_total = 1 + HANDSHAKE_SIZE * 2; + let leftover = if self.recv_buf.len() > handshake_total { + self.recv_buf[handshake_total..].to_vec() + } else { + Vec::new() + }; + + self.state = HandshakeState::Complete; + // Free the receive buffer — no longer needed. + self.recv_buf = Vec::new(); + + (c2, leftover) + } +} + +/// Fill a buffer with pseudo-random bytes. +/// +/// Uses a simple xorshift64 PRNG seeded from the current timestamp to avoid +/// all-zero handshakes (which some servers may fingerprint). Cryptographic +/// strength is not required here. +fn fill_random(buf: &mut [u8]) { + // Seed from the current time. We mix in a fixed constant to avoid + // degenerate seeds (e.g. zero). + let mut state: u64 = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).map_or( + 0x517E_A45D_1234_5678, + |d| { + #[allow(clippy::cast_possible_truncation)] + // Truncation is intentional: we only need 64 bits of entropy for a PRNG seed. + let nanos = d.as_nanos() as u64; + nanos + }, + ); + if state == 0 { + state = 0x517E_A45D_1234_5678; + } + for byte in buf.iter_mut() { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + #[allow(clippy::cast_possible_truncation)] + { + *byte = state as u8; + } + } +} + +// --------------------------------------------------------------------------- +// Connection State +// --------------------------------------------------------------------------- + +/// RTMP connection states (publish-client subset). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum RtmpConnectionState { + Handshaking, + Connecting, + Connected, + MediaStreamCreated, + PublishPending, + Publishing, + Disconnecting, +} + +impl fmt::Display for RtmpConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Handshaking => f.write_str("Handshaking"), + Self::Connecting => f.write_str("Connecting"), + Self::Connected => f.write_str("Connected"), + Self::MediaStreamCreated => f.write_str("MediaStreamCreated"), + Self::PublishPending => f.write_str("PublishPending"), + Self::Publishing => f.write_str("Publishing"), + Self::Disconnecting => f.write_str("Disconnecting"), + } + } +} + +/// Events emitted by the connection state machine. +#[derive(Debug)] +pub(super) enum RtmpConnectionEvent { + StateChanged(RtmpConnectionState), + DisconnectedByPeer { reason: String }, +} + +// --------------------------------------------------------------------------- +// RtmpPublishClientConnection +// --------------------------------------------------------------------------- + +/// Sans-I/O RTMP publish client connection. +/// +/// Manages the full lifecycle from handshake through publish, providing +/// the same API surface as the previous `shiguredo_rtmp` library. +pub(super) struct RtmpPublishClientConnection { + url: RtmpUrl, + state: RtmpConnectionState, + handshake: Option, + encoder: ChunkEncoder, + decoder: ChunkDecoder, + send_buf: Vec, + events: VecDeque, + /// Server-assigned stream ID from createStream `_result`. + media_stream_id: u32, + /// Transaction ID counter for AMF0 commands. + next_transaction_id: f64, + /// Total bytes received (for ACK tracking). + total_bytes_received: u64, + /// Peer's requested ACK window size. + peer_ack_window_size: u32, + /// Byte count at which we last sent an ACK. + last_ack_sent_at: u64, + /// Chunk size we announce to the server. 
    local_chunk_size: u32,
}

impl fmt::Debug for RtmpPublishClientConnection {
    // Manual Debug: only state and URL are interesting; buffers/codecs are
    // elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RtmpPublishClientConnection")
            .field("state", &self.state)
            .field("url", &self.url.to_string())
            .finish_non_exhaustive()
    }
}

impl RtmpPublishClientConnection {
    /// The chunk size we announce to the server (4096 bytes — matches
    /// OBS/FFmpeg; the RTMP default of 128 is too small for video).
    const LOCAL_CHUNK_SIZE: u32 = 4096;

    /// Maximum send buffer size (8 MB). If the TCP socket stalls and
    /// the buffer exceeds this, we refuse to enqueue more media so the
    /// caller can detect backpressure and disconnect gracefully.
    const MAX_SEND_BUF: usize = 8 * 1024 * 1024;

    /// Create a new RTMP publish client. C0+C1 are queued in the send
    /// buffer immediately.
    pub fn new(url: RtmpUrl) -> Self {
        let (handshake, c0c1) = Handshake::new();

        Self {
            url,
            state: RtmpConnectionState::Handshaking,
            handshake: Some(handshake),
            encoder: ChunkEncoder::new(),
            decoder: ChunkDecoder::new(),
            // The send buffer starts pre-loaded with C0+C1.
            send_buf: c0c1,
            events: VecDeque::new(),
            media_stream_id: 0,
            next_transaction_id: 1.0,
            total_bytes_received: 0,
            peer_ack_window_size: 0,
            last_ack_sent_at: 0,
            local_chunk_size: Self::LOCAL_CHUNK_SIZE,
        }
    }

    /// Feed received bytes from the server. Drives the state machine
    /// (handshake → connect → createStream → publish).
    pub fn feed_recv_buf(&mut self, buf: &[u8]) -> Result<(), Error> {
        // ── Handshake phase ─────────────────────────────────────────
        // ACK sequence numbers are based on post-handshake bytes only
        // (RTMP spec §5.4), so we defer the counter increment.
        if let Some(ref mut hs) = self.handshake {
            if let Some((c2, leftover)) = hs.feed(buf) {
                self.send_buf.extend_from_slice(&c2);

                // Handshake complete — send the RTMP connect sequence.
                self.handshake = None;
                self.send_connect_sequence()?;

                // Forward any post-S2 bytes (e.g. WinAckSize,
                // SetPeerBandwidth pipelined in the same TCP segment)
                // to the chunk decoder so they aren't silently lost.
                if !leftover.is_empty() {
                    self.total_bytes_received += leftover.len() as u64;
                    self.decoder.push(&leftover);
                    while let Some(msg) = self.decoder.decode_message()? {
                        self.handle_message(&msg)?;
                    }
                    self.maybe_send_ack();
                }
                return Ok(());
            }
            // Still handshaking, need more data.
            return Ok(());
        }

        // ── Post-handshake: decode chunks ───────────────────────────
        self.total_bytes_received += buf.len() as u64;
        self.decoder.push(buf);
        while let Some(msg) = self.decoder.decode_message()? {
            self.handle_message(&msg)?;
        }

        // ── ACK tracking ────────────────────────────────────────────
        self.maybe_send_ack();

        Ok(())
    }

    /// Bytes waiting to be sent to the server.
    pub fn send_buf(&self) -> &[u8] {
        &self.send_buf
    }

    /// Mark `n` bytes as sent.
    pub fn advance_send_buf(&mut self, n: usize) {
        // drain(..n) shifts the remaining bytes to the front.
        self.send_buf.drain(..n);
    }

    /// Current connection state.
    pub const fn state(&self) -> RtmpConnectionState {
        self.state
    }

    /// Send a video frame (only valid in `Publishing` state).
    #[allow(clippy::cast_possible_truncation)]
    pub fn send_video(&mut self, frame: &VideoFrame) -> Result<(), Error> {
        if self.state != RtmpConnectionState::Publishing {
            return Err(Error::new(format!("Cannot send video in state {}", self.state)));
        }
        if self.send_buf.len() > Self::MAX_SEND_BUF {
            return Err(Error::new(format!(
                "Send buffer exceeded {} bytes — backpressure (TCP stall?)",
                Self::MAX_SEND_BUF
            )));
        }

        // Build the FLV video tag payload.
        let mut payload = Vec::with_capacity(5 + frame.data.len());

        // FLV video header byte: frame_type(4 bits) | codec_id(4 bits)
        let frame_type_nibble: u8 = match frame.frame_type {
            VideoFrameType::KeyFrame => 1,
            VideoFrameType::InterFrame => 2,
        };
        let codec_nibble: u8 = match frame.codec {
            VideoCodec::Avc => 7, // 7 = AVC per the FLV spec
        };
        payload.push((frame_type_nibble << 4) | codec_nibble);

        // AVC packet type + composition time offset (3 bytes, signed 24-bit)
        if let Some(ref pkt_type) = frame.avc_packet_type {
            payload.push(match pkt_type {
                AvcPacketType::SequenceHeader => 0,
                AvcPacketType::NalUnit => 1,
            });
            let cto = frame.composition_timestamp_offset.millis();
            let cto_bytes = cto.to_be_bytes();
            // 24-bit signed: take lower 3 bytes of i32
            payload.extend_from_slice(&cto_bytes[1..4]);
        }

        payload.extend_from_slice(&frame.data);

        let msg = OutboundMessage {
            csid: csid_for_stream(self.media_stream_id),
            timestamp: frame.timestamp.millis(),
            msg_type_id: MSG_VIDEO,
            stream_id: self.media_stream_id,
            payload,
        };
        self.encoder.encode_message(&msg, &mut self.send_buf);
        Ok(())
    }

    /// Send an audio frame (only valid in `Publishing` state).
    pub fn send_audio(&mut self, frame: &AudioFrame) -> Result<(), Error> {
        if self.state != RtmpConnectionState::Publishing {
            return Err(Error::new(format!("Cannot send audio in state {}", self.state)));
        }
        // Same backpressure guard as send_video.
        if self.send_buf.len() > Self::MAX_SEND_BUF {
            return Err(Error::new(format!(
                "Send buffer exceeded {} bytes — backpressure (TCP stall?)",
                Self::MAX_SEND_BUF
            )));
        }

        // Build the FLV audio tag payload.
+ let mut payload = Vec::with_capacity(2 + frame.data.len()); + + // FLV audio header byte: + // soundFormat(4) | soundRate(2) | soundSize(1) | soundType(1) + let format_nibble: u8 = match frame.format { + AudioFormat::Aac => 10, + }; + let rate_bits: u8 = match frame.sample_rate { + AudioSampleRate::Khz44 => 3, // 44 kHz + }; + let size_bit: u8 = u8::from(!frame.is_8bit_sample); // 0=8bit, 1=16bit + let type_bit: u8 = u8::from(frame.is_stereo); + payload.push((format_nibble << 4) | (rate_bits << 2) | (size_bit << 1) | type_bit); + + // AAC packet type: 0 = sequence header, 1 = raw + if matches!(frame.format, AudioFormat::Aac) { + payload.push(u8::from(!frame.is_aac_sequence_header)); + } + + payload.extend_from_slice(&frame.data); + + let msg = OutboundMessage { + csid: csid_for_stream(self.media_stream_id), + timestamp: frame.timestamp.millis(), + msg_type_id: MSG_AUDIO, + stream_id: self.media_stream_id, + payload, + }; + self.encoder.encode_message(&msg, &mut self.send_buf); + Ok(()) + } + + /// Retrieve the next event, if any. + pub fn next_event(&mut self) -> Option { + self.events.pop_front() + } + + // ------------------------------------------------------------------- + // Internal: connect sequence + // ------------------------------------------------------------------- + + /// Send the initial RTMP connect command sequence after handshake. + fn send_connect_sequence(&mut self) -> Result<(), Error> { + // 1. WinAckSize (server should ACK every 2.5 MB). + self.send_protocol_message(MSG_WIN_ACK_SIZE, &2_500_000u32.to_be_bytes()); + + // 2. SetChunkSize. + self.send_protocol_message(MSG_SET_CHUNK_SIZE, &self.local_chunk_size.to_be_bytes()); + self.encoder.set_chunk_size(self.local_chunk_size); + + // 3. connect command. 
        let tid = self.next_tid();
        let tc_url = self.url.tc_url();
        let app = self.url.app.clone();

        // AMF0 command: "connect" + transaction id + command object.
        let mut payload = Vec::with_capacity(256);
        amf0_encode(&Amf0Value::String("connect".to_string()), &mut payload)?;
        amf0_encode(&Amf0Value::Number(tid), &mut payload)?;
        amf0_encode(
            &Amf0Value::Object(vec![
                ("app".to_string(), Amf0Value::String(app)),
                ("type".to_string(), Amf0Value::String("nonprivate".to_string())),
                ("flashVer".to_string(), Amf0Value::String("FMLE/3.0".to_string())),
                ("tcUrl".to_string(), Amf0Value::String(tc_url)),
            ]),
            &mut payload,
        )?;

        let msg = OutboundMessage {
            csid: CSID_COMMAND,
            timestamp: 0,
            msg_type_id: MSG_COMMAND_AMF0,
            stream_id: 0,
            payload,
        };
        self.encoder.encode_message(&msg, &mut self.send_buf);

        self.set_state(RtmpConnectionState::Connecting);
        Ok(())
    }

    // -------------------------------------------------------------------
    // Internal: message handling
    // -------------------------------------------------------------------

    /// Handle a fully assembled inbound RTMP message.
    fn handle_message(&mut self, msg: &InboundMessage) -> Result<(), Error> {
        match msg.msg_type_id {
            MSG_SET_CHUNK_SIZE => {
                if msg.payload.len() >= 4 {
                    let size = u32::from_be_bytes([
                        msg.payload[0],
                        msg.payload[1],
                        msg.payload[2],
                        msg.payload[3],
                    ]) & 0x7FFF_FFFF; // high bit must be 0
                    tracing::debug!(chunk_size = size, "Server SetChunkSize");
                    self.decoder.set_chunk_size(size);
                }
            },
            MSG_ABORT => {
                // Abort message for a chunk stream — clear partial state.
                if msg.payload.len() >= 4 {
                    let abort_csid = u32::from_be_bytes([
                        msg.payload[0],
                        msg.payload[1],
                        msg.payload[2],
                        msg.payload[3],
                    ]);
                    #[allow(clippy::cast_possible_truncation)]
                    let csid = abort_csid as u16;
                    if let Some(state) = self.decoder.csid_states.get_mut(&csid) {
                        state.payload.clear();
                        state.bytes_remaining = 0;
                    }
                }
            },
            MSG_ACK => {
                // Server acknowledgement — we don't enforce ACK windows
                // on the send side, so just log it.
                tracing::debug!("Server ACK received");
            },
            MSG_USER_CONTROL => self.handle_user_control(&msg.payload),
            MSG_WIN_ACK_SIZE => {
                if msg.payload.len() >= 4 {
                    let size = u32::from_be_bytes([
                        msg.payload[0],
                        msg.payload[1],
                        msg.payload[2],
                        msg.payload[3],
                    ]);
                    tracing::debug!(window_size = size, "Server WinAckSize");
                    self.peer_ack_window_size = size;
                }
            },
            MSG_SET_PEER_BANDWIDTH => {
                // 5 bytes: window size (4) + limit type (1).
                if msg.payload.len() >= 5 {
                    let size = u32::from_be_bytes([
                        msg.payload[0],
                        msg.payload[1],
                        msg.payload[2],
                        msg.payload[3],
                    ]);
                    tracing::debug!(
                        window_size = size,
                        limit_type = msg.payload[4],
                        "Server SetPeerBandwidth"
                    );
                    // Respond with WinAckSize to acknowledge.
                    self.send_protocol_message(MSG_WIN_ACK_SIZE, &size.to_be_bytes());
                    self.peer_ack_window_size = size;
                }
            },
            MSG_COMMAND_AMF0 => self.handle_command(&msg.payload)?,
            MSG_AUDIO | MSG_VIDEO => {
                // We're a publisher, not a subscriber — ignore inbound media.
            },
            _ => {
                tracing::debug!(msg_type = msg.msg_type_id, "Ignoring unknown RTMP message type");
            },
        }
        Ok(())
    }

    /// Handle a User Control event message (type 4).
+ fn handle_user_control(&mut self, payload: &[u8]) { + if payload.len() < 2 { + return; + } + let event_type = u16::from_be_bytes([payload[0], payload[1]]); + + match event_type { + UC_STREAM_BEGIN => { + tracing::debug!("User control: StreamBegin"); + }, + UC_STREAM_EOF => { + tracing::debug!("User control: StreamEof"); + }, + UC_PING_REQUEST => { + // Respond with PingResponse (event type 7). + if payload.len() >= 6 { + let mut response = Vec::with_capacity(6); + response.extend_from_slice(&7u16.to_be_bytes()); // PingResponse + response.extend_from_slice(&payload[2..6]); // echo timestamp + self.send_protocol_message(MSG_USER_CONTROL, &response); + tracing::debug!("Responded to PingRequest"); + } + }, + _ => { + tracing::debug!(event_type, "User control event ignored"); + }, + } + } + + /// Handle an AMF0 command message. + fn handle_command(&mut self, payload: &[u8]) -> Result<(), Error> { + // Decode command name. + let (name_val, mut offset) = amf0_decode(payload)?; + let name = match &name_val { + Amf0Value::String(s) => s.as_str(), + _ => return Ok(()), // not a command + }; + + // Decode transaction ID. + let (tid_val, consumed) = amf0_decode(&payload[offset..])?; + offset += consumed; + let _tid = match &tid_val { + Amf0Value::Number(n) => *n, + _ => 0.0, + }; + + match name { + "_result" => self.handle_result(&payload[offset..])?, + "_error" => self.handle_error(&payload[offset..]), + "onStatus" => self.handle_on_status(&payload[offset..]), + _ => { + tracing::debug!(command = name, "Ignoring unknown RTMP command"); + }, + } + + Ok(()) + } + + /// Handle a `_result` response. + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + fn handle_result(&mut self, payload: &[u8]) -> Result<(), Error> { + match self.state { + RtmpConnectionState::Connecting => { + // connect _result — success. + tracing::info!("RTMP connect succeeded"); + self.set_state(RtmpConnectionState::Connected); + + // Send createStream. 
+ let tid = self.next_tid(); + let mut cmd_payload = Vec::with_capacity(32); + amf0_encode(&Amf0Value::String("createStream".to_string()), &mut cmd_payload)?; + amf0_encode(&Amf0Value::Number(tid), &mut cmd_payload)?; + amf0_encode(&Amf0Value::Null, &mut cmd_payload)?; + + let msg = OutboundMessage { + csid: CSID_COMMAND, + timestamp: 0, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: 0, + payload: cmd_payload, + }; + self.encoder.encode_message(&msg, &mut self.send_buf); + }, + RtmpConnectionState::Connected => { + // createStream _result — extract stream ID. + // The result payload is: Null (properties) + Number (stream_id). + let mut off = 0; + // Skip the Null/Object properties field. + if !payload.is_empty() { + let (_, consumed) = amf0_decode(payload)?; + off += consumed; + } + // Read the stream ID. + if off < payload.len() { + let (val, _) = amf0_decode(&payload[off..])?; + if let Amf0Value::Number(n) = val { + self.media_stream_id = n as u32; + tracing::info!(stream_id = self.media_stream_id, "createStream succeeded"); + } + } + + self.set_state(RtmpConnectionState::MediaStreamCreated); + + // Send publish command on the media stream's csid. 
+ let tid = self.next_tid(); + let stream_name = self.url.stream_name.clone(); + let mut cmd_payload = Vec::with_capacity(64); + amf0_encode(&Amf0Value::String("publish".to_string()), &mut cmd_payload)?; + amf0_encode(&Amf0Value::Number(tid), &mut cmd_payload)?; + amf0_encode(&Amf0Value::Null, &mut cmd_payload)?; + amf0_encode(&Amf0Value::String(stream_name), &mut cmd_payload)?; + amf0_encode(&Amf0Value::String("live".to_string()), &mut cmd_payload)?; + + let msg = OutboundMessage { + csid: csid_for_stream(self.media_stream_id), + timestamp: 0, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: self.media_stream_id, + payload: cmd_payload, + }; + self.encoder.encode_message(&msg, &mut self.send_buf); + + self.set_state(RtmpConnectionState::PublishPending); + }, + _ => { + tracing::debug!(state = %self.state, "Unexpected _result"); + }, + } + Ok(()) + } + + /// Handle a `_error` response. + fn handle_error(&mut self, payload: &[u8]) { + // Try to extract a description. + let desc = extract_info_description(payload).unwrap_or_else(|| "unknown error".to_string()); + tracing::warn!(description = %desc, state = %self.state, "RTMP _error"); + self.events.push_back(RtmpConnectionEvent::DisconnectedByPeer { reason: desc }); + self.set_state(RtmpConnectionState::Disconnecting); + } + + /// Handle an `onStatus` notification. + fn handle_on_status(&mut self, payload: &[u8]) { + // Skip Null (command object), then decode the info object. 
+ let mut off = 0; + if !payload.is_empty() { + if let Ok((_, consumed)) = amf0_decode(payload) { + off += consumed; + } + } + + let code = if off < payload.len() { + if let Ok((val, _)) = amf0_decode(&payload[off..]) { + extract_object_field(&val, "code") + } else { + None + } + } else { + None + }; + + let code_str = code.as_deref().unwrap_or(""); + tracing::info!(code = code_str, state = %self.state, "onStatus"); + + match code_str { + "NetStream.Publish.Start" => { + self.set_state(RtmpConnectionState::Publishing); + }, + s if s.contains("Error") || s.contains("Failed") || s.contains("Rejected") => { + let desc = + extract_info_description(payload).unwrap_or_else(|| code_str.to_string()); + self.events.push_back(RtmpConnectionEvent::DisconnectedByPeer { reason: desc }); + self.set_state(RtmpConnectionState::Disconnecting); + }, + _ => { + // Other status codes (e.g. NetStream.Play.Start) — ignore. + }, + } + } + + // ------------------------------------------------------------------- + // Internal: helpers + // ------------------------------------------------------------------- + + /// Set state and emit a `StateChanged` event. + fn set_state(&mut self, new_state: RtmpConnectionState) { + if self.state != new_state { + tracing::debug!(from = %self.state, to = %new_state, "RTMP state transition"); + self.state = new_state; + self.events.push_back(RtmpConnectionEvent::StateChanged(new_state)); + } + } + + /// Allocate the next transaction ID. + fn next_tid(&mut self) -> f64 { + let tid = self.next_transaction_id; + self.next_transaction_id += 1.0; + tid + } + + /// Send a protocol control message on csid=2, stream_id=0. + fn send_protocol_message(&mut self, msg_type_id: u8, payload: &[u8]) { + let msg = OutboundMessage { + csid: CSID_PROTOCOL_CONTROL, + timestamp: 0, + msg_type_id, + stream_id: 0, + payload: payload.to_vec(), + }; + self.encoder.encode_message(&msg, &mut self.send_buf); + } + + /// Send an ACK if we've received enough bytes since the last one. 
+ fn maybe_send_ack(&mut self) {
+ if self.peer_ack_window_size == 0 {
+ return;
+ }
+ let since_last = self.total_bytes_received - self.last_ack_sent_at;
+ if since_last >= u64::from(self.peer_ack_window_size) {
+ #[allow(clippy::cast_possible_truncation)]
+ let seq = self.total_bytes_received as u32;
+ self.send_protocol_message(MSG_ACK, &seq.to_be_bytes());
+ self.last_ack_sent_at = self.total_bytes_received;
+ }
+ }
+}
+
+/// Extract the "description" field from an AMF0 info object payload.
+///
+/// The payload typically starts with Null (command object) then an Object
+/// containing `code`, `level`, `description` fields.
+fn extract_info_description(payload: &[u8]) -> Option<String> {
+ let mut off = 0;
+ // Skip Null/command object.
+ if !payload.is_empty() {
+ let (_, consumed) = amf0_decode(payload).ok()?;
+ off += consumed;
+ }
+ if off >= payload.len() {
+ return None;
+ }
+ let (val, _) = amf0_decode(&payload[off..]).ok()?;
+ extract_object_field(&val, "description")
+}
+
+/// Extract a string field from an AMF0 Object value. 
+fn extract_object_field(val: &Amf0Value, field: &str) -> Option<String> {
+ if let Amf0Value::Object(props) = val {
+ for (key, v) in props {
+ if key == field {
+ if let Amf0Value::String(s) = v {
+ return Some(s.clone());
+ }
+ }
+ }
+ }
+ None
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+ use super::*;
+
+ // ── URL parsing ─────────────────────────────────────────────────
+
+ #[test]
+ fn parse_rtmp_url_basic() {
+ let url = RtmpUrl::parse("rtmp://live.example.com/app/stream_key").unwrap();
+ assert_eq!(url.host, "live.example.com");
+ assert_eq!(url.port, 1935);
+ assert_eq!(url.app, "app");
+ assert_eq!(url.stream_name, "stream_key");
+ assert!(!url.tls);
+ }
+
+ #[test]
+ fn parse_rtmps_url_with_port() {
+ let url = RtmpUrl::parse("rtmps://live.twitch.tv:8443/app/key").unwrap();
+ assert_eq!(url.host, "live.twitch.tv");
+ assert_eq!(url.port, 8443);
+ assert_eq!(url.app, "app");
+ assert_eq!(url.stream_name, "key");
+ assert!(url.tls);
+ }
+
+ #[test]
+ fn parse_rtmps_default_port() {
+ let url = RtmpUrl::parse("rtmps://live.twitch.tv/app/key").unwrap();
+ assert_eq!(url.port, 443);
+ }
+
+ #[test]
+ fn parse_rtmp_multi_segment_path() {
+ let url = RtmpUrl::parse("rtmp://host/live/extra/stream_key").unwrap();
+ assert_eq!(url.app, "live/extra");
+ assert_eq!(url.stream_name, "stream_key");
+ }
+
+ #[test]
+ fn parse_rtmp_single_segment_is_app() {
+ // When there's only one path segment, it's the app name.
+ // Stream name will be appended by the caller (resolve_rtmp_url). 
+ let url = RtmpUrl::parse("rtmp://host/live2").unwrap(); + assert_eq!(url.app, "live2"); + assert_eq!(url.stream_name, ""); + } + + #[test] + fn parse_rtmp_invalid_scheme() { + assert!(RtmpUrl::parse("http://host/app/key").is_err()); + } + + #[test] + fn parse_rtmp_empty_host() { + assert!(RtmpUrl::parse("rtmp:///app/key").is_err()); + } + + #[test] + fn tc_url_omits_default_port() { + let url = RtmpUrl::parse("rtmp://a.rtmp.youtube.com/live2/key").unwrap(); + assert_eq!(url.tc_url(), "rtmp://a.rtmp.youtube.com/live2"); + } + + #[test] + fn tc_url_includes_custom_port() { + let url = RtmpUrl::parse("rtmp://host:9999/app/key").unwrap(); + assert_eq!(url.tc_url(), "rtmp://host:9999/app"); + } + + #[test] + fn url_from_str_works() { + let url: RtmpUrl = "rtmp://host/app/key".parse().unwrap(); + assert_eq!(url.host, "host"); + } + + // ── AMF0 ──────────────────────────────────────────────────────── + + #[test] + fn amf0_number_roundtrip() { + let val = Amf0Value::Number(42.5); + let mut buf = Vec::new(); + amf0_encode(&val, &mut buf).unwrap(); + let (decoded, consumed) = amf0_decode(&buf).unwrap(); + assert_eq!(decoded, val); + assert_eq!(consumed, buf.len()); + } + + #[test] + fn amf0_string_roundtrip() { + let val = Amf0Value::String("hello RTMP".to_string()); + let mut buf = Vec::new(); + amf0_encode(&val, &mut buf).unwrap(); + let (decoded, consumed) = amf0_decode(&buf).unwrap(); + assert_eq!(decoded, val); + assert_eq!(consumed, buf.len()); + } + + #[test] + fn amf0_boolean_roundtrip() { + for b in [true, false] { + let val = Amf0Value::Boolean(b); + let mut buf = Vec::new(); + amf0_encode(&val, &mut buf).unwrap(); + let (decoded, consumed) = amf0_decode(&buf).unwrap(); + assert_eq!(decoded, val); + assert_eq!(consumed, buf.len()); + } + } + + #[test] + fn amf0_null_roundtrip() { + let val = Amf0Value::Null; + let mut buf = Vec::new(); + amf0_encode(&val, &mut buf).unwrap(); + let (decoded, consumed) = amf0_decode(&buf).unwrap(); + assert_eq!(decoded, val); + 
assert_eq!(consumed, buf.len()); + } + + #[test] + fn amf0_object_roundtrip() { + let val = Amf0Value::Object(vec![ + ("app".to_string(), Amf0Value::String("live".to_string())), + ("version".to_string(), Amf0Value::Number(3.0)), + ("flag".to_string(), Amf0Value::Boolean(true)), + ]); + let mut buf = Vec::new(); + amf0_encode(&val, &mut buf).unwrap(); + let (decoded, consumed) = amf0_decode(&buf).unwrap(); + assert_eq!(decoded, val); + assert_eq!(consumed, buf.len()); + } + + // ── Chunk encoder ─────────────────────────────────────────────── + + #[test] + fn chunk_encoder_fmt0_basic() { + let mut enc = ChunkEncoder::new(); + let msg = OutboundMessage { + csid: 3, + timestamp: 100, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: 0, + payload: vec![0xAA; 10], + }; + let mut out = Vec::new(); + enc.encode_message(&msg, &mut out); + + // Basic header: 1 byte (fmt=0, csid=3). + assert_eq!(out[0], 0x03); // fmt=0 (00) | csid=3 (000011) + // Message header: 11 bytes. + // Total header: 12 bytes + 10 payload = 22 bytes. + assert_eq!(out.len(), 12 + 10); + } + + #[test] + fn chunk_encoder_splits_at_chunk_size() { + let mut enc = ChunkEncoder::new(); + enc.set_chunk_size(10); + let msg = OutboundMessage { + csid: 3, + timestamp: 0, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: 0, + payload: vec![0xBB; 25], // 3 chunks: 10 + 10 + 5 + }; + let mut out = Vec::new(); + enc.encode_message(&msg, &mut out); + + // First chunk: 12 (header) + 10 (data) = 22 + // Second chunk: 1 (fmt=3 header) + 10 (data) = 11 + // Third chunk: 1 (fmt=3 header) + 5 (data) = 6 + assert_eq!(out.len(), 22 + 11 + 6); + } + + #[test] + fn chunk_encoder_fmt_progression() { + let mut enc = ChunkEncoder::new(); + let mut out = Vec::new(); + + // First message: fmt=0 (12 bytes header). 
+ let msg1 = OutboundMessage { + csid: 3, + timestamp: 100, + msg_type_id: MSG_AUDIO, + stream_id: 1, + payload: vec![0; 5], + }; + enc.encode_message(&msg1, &mut out); + assert_eq!(out[0] >> 6, 0); // fmt=0 + + // Second message: same stream_id, different length → fmt=1. + out.clear(); + let msg2 = OutboundMessage { + csid: 3, + timestamp: 120, + msg_type_id: MSG_AUDIO, + stream_id: 1, + payload: vec![0; 10], + }; + enc.encode_message(&msg2, &mut out); + assert_eq!(out[0] >> 6, 1); // fmt=1 + + // Third message: same length/type, different delta → fmt=2. + out.clear(); + let msg3 = OutboundMessage { + csid: 3, + timestamp: 150, + msg_type_id: MSG_AUDIO, + stream_id: 1, + payload: vec![0; 10], + }; + enc.encode_message(&msg3, &mut out); + assert_eq!(out[0] >> 6, 2); // fmt=2 + + // Fourth message: same delta → fmt=3. + out.clear(); + let msg4 = OutboundMessage { + csid: 3, + timestamp: 180, + msg_type_id: MSG_AUDIO, + stream_id: 1, + payload: vec![0; 10], + }; + enc.encode_message(&msg4, &mut out); + assert_eq!(out[0] >> 6, 3); // fmt=3 + } + + #[test] + fn chunk_encoder_extended_timestamp() { + let mut enc = ChunkEncoder::new(); + let msg = OutboundMessage { + csid: 3, + timestamp: 0x01FF_FFFF, // > 0xFFFFFF + msg_type_id: MSG_VIDEO, + stream_id: 1, + payload: vec![0; 5], + }; + let mut out = Vec::new(); + enc.encode_message(&msg, &mut out); + + // Timestamp field in header should be 0xFFFFFF. + assert_eq!(out[1], 0xFF); + assert_eq!(out[2], 0xFF); + assert_eq!(out[3], 0xFF); + // Extended timestamp (4 bytes) follows the 11-byte message header. + // Position 12..16 = extended timestamp. + let ext = u32::from_be_bytes([out[12], out[13], out[14], out[15]]); + assert_eq!(ext, 0x01FF_FFFF); + } + + #[test] + fn chunk_encoder_csid_assignment() { + // Protocol control → csid=2. + assert_eq!(CSID_PROTOCOL_CONTROL, 2); + // Commands on stream 0 → csid=3. + assert_eq!(csid_for_stream(0), 3); + // Media on stream 1 → csid=4. 
+ assert_eq!(csid_for_stream(1), 4); + // Media on stream 2 → csid=5. + assert_eq!(csid_for_stream(2), 5); + } + + // ── Chunk decoder ─────────────────────────────────────────────── + + #[test] + fn chunk_decode_fmt0_single_chunk() { + // Encode a message, then decode it. + let mut enc = ChunkEncoder::new(); + let msg = OutboundMessage { + csid: 3, + timestamp: 42, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: 0, + payload: vec![0x11, 0x22, 0x33], + }; + let mut wire = Vec::new(); + enc.encode_message(&msg, &mut wire); + + let mut dec = ChunkDecoder::new(); + dec.push(&wire); + let decoded = dec.decode_message().unwrap().unwrap(); + + assert_eq!(decoded.timestamp, 42); + assert_eq!(decoded.msg_type_id, MSG_COMMAND_AMF0); + assert_eq!(decoded.stream_id, 0); + assert_eq!(decoded.payload, vec![0x11, 0x22, 0x33]); + } + + #[test] + fn chunk_decode_multi_chunk_reassembly() { + let mut enc = ChunkEncoder::new(); + enc.set_chunk_size(5); + let payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let msg = OutboundMessage { + csid: 3, + timestamp: 0, + msg_type_id: MSG_AUDIO, + stream_id: 1, + payload: payload.clone(), + }; + let mut wire = Vec::new(); + enc.encode_message(&msg, &mut wire); + + let mut dec = ChunkDecoder::new(); + dec.set_chunk_size(5); + dec.push(&wire); + // decode_message() internally loops over continuation chunks, + // so a single call assembles the full multi-chunk message. + let decoded = loop { + if let Some(msg) = dec.decode_message().unwrap() { + break msg; + } + }; + assert_eq!(decoded.payload, payload); + } + + #[test] + fn chunk_decode_partial_reads() { + // Feed one byte at a time. 
+ let mut enc = ChunkEncoder::new(); + let msg = OutboundMessage { + csid: 3, + timestamp: 100, + msg_type_id: MSG_VIDEO, + stream_id: 1, + payload: vec![0xAA, 0xBB, 0xCC], + }; + let mut wire = Vec::new(); + enc.encode_message(&msg, &mut wire); + + let mut dec = ChunkDecoder::new(); + for (i, &byte) in wire.iter().enumerate() { + dec.push(&[byte]); + let result = dec.decode_message().unwrap(); + if i < wire.len() - 1 { + assert!(result.is_none(), "Should not have a message yet at byte {i}"); + } else { + let decoded = result.unwrap(); + assert_eq!(decoded.payload, vec![0xAA, 0xBB, 0xCC]); + } + } + } + + // ── Handshake ─────────────────────────────────────────────────── + + #[test] + fn handshake_c0c1_length() { + let (_, c0c1) = Handshake::new(); + assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE); + assert_eq!(c0c1[0], 0x03); // version + } + + #[test] + fn handshake_c1_not_all_zeros() { + let (_, c0c1) = Handshake::new(); + // Random portion (bytes 9..1537) should not be all zeros. + let random_portion = &c0c1[9..]; + assert!(random_portion.iter().any(|&b| b != 0), "C1 random data should not be all zeros"); + } + + #[test] + fn handshake_full_flow() { + let (mut hs, c0c1) = Handshake::new(); + + // Simulate server sending S0+S1+S2. + let mut server_response = Vec::new(); + server_response.push(0x03); // S0 + server_response.extend_from_slice(&vec![0xAA; HANDSHAKE_SIZE]); // S1 + // S2 = echo of C1. 
+ server_response.extend_from_slice(&c0c1[1..=HANDSHAKE_SIZE]); // S2 + + let (c2, leftover) = hs.feed(&server_response).unwrap(); + assert_eq!(hs.state, HandshakeState::Complete); + assert_eq!(c2, vec![0xAA; HANDSHAKE_SIZE]); + assert!(leftover.is_empty()); + } + + #[test] + fn handshake_incremental_feed() { + let (mut hs, c0c1) = Handshake::new(); + + let mut server_response = Vec::new(); + server_response.push(0x03); + server_response.extend_from_slice(&vec![0xBB; HANDSHAKE_SIZE]); + server_response.extend_from_slice(&c0c1[1..=HANDSHAKE_SIZE]); + + // Feed in small increments. + let half = server_response.len() / 2; + assert!(hs.feed(&server_response[..half]).is_none()); + assert_ne!(hs.state, HandshakeState::Complete); + + let (c2, leftover) = hs.feed(&server_response[half..]).unwrap(); + assert_eq!(hs.state, HandshakeState::Complete); + assert_eq!(c2.len(), HANDSHAKE_SIZE); + assert!(leftover.is_empty()); + } + + #[test] + fn handshake_preserves_leftover_bytes() { + // Simulate a server that pipelines S0+S1+S2 plus initial protocol + // messages (e.g. WinAckSize) in the same TCP segment. The leftover + // bytes after the 3073-byte handshake must be returned so the caller + // can forward them to the chunk decoder. 
+ let (mut hs, c0c1) = Handshake::new(); + + let extra = b"\x02\x00\x00\x00\x00\x00\x04\x05\x00\x00\x00\x00\x00\x26\x25\xa0"; + + let mut server_response = Vec::new(); + server_response.push(0x03); // S0 + server_response.extend_from_slice(&vec![0xCC; HANDSHAKE_SIZE]); // S1 + server_response.extend_from_slice(&c0c1[1..=HANDSHAKE_SIZE]); // S2 + server_response.extend_from_slice(extra); // extra post-handshake data + + let (c2, leftover) = hs.feed(&server_response).unwrap(); + assert_eq!(hs.state, HandshakeState::Complete); + assert_eq!(c2, vec![0xCC; HANDSHAKE_SIZE]); + assert_eq!(leftover, extra); + } + + // ── AvcSequenceHeader ─────────────────────────────────────────── + + #[test] + fn avc_sequence_header_to_bytes() { + let header = AvcSequenceHeader { + avc_profile_indication: 0x42, + profile_compatibility: 0xC0, + avc_level_indication: 0x1F, + length_size_minus_one: 3, + sps_list: vec![vec![0x67, 0x42, 0xC0, 0x1F]], + pps_list: vec![vec![0x68, 0xCE, 0x38, 0x80]], + }; + let bytes = header.to_bytes().unwrap(); + + assert_eq!(bytes[0], 1); // configurationVersion + assert_eq!(bytes[1], 0x42); // profile + assert_eq!(bytes[2], 0xC0); // compatibility + assert_eq!(bytes[3], 0x1F); // level + assert_eq!(bytes[4], 0xFF); // 111111 | 11 (length_size_minus_one=3) + assert_eq!(bytes[5] & 0x1F, 1); // numSPS = 1 + // SPS length (2 bytes) + SPS data (4 bytes). + assert_eq!(bytes[6], 0); + assert_eq!(bytes[7], 4); + assert_eq!(&bytes[8..12], &[0x67, 0x42, 0xC0, 0x1F]); + // numPPS = 1 + assert_eq!(bytes[12], 1); + // PPS length (2 bytes) + PPS data (4 bytes). 
+ assert_eq!(bytes[13], 0); + assert_eq!(bytes[14], 4); + assert_eq!(&bytes[15..19], &[0x68, 0xCE, 0x38, 0x80]); + } + + #[test] + fn avc_sequence_header_no_sps_errors() { + let header = AvcSequenceHeader { + avc_profile_indication: 0x42, + profile_compatibility: 0xC0, + avc_level_indication: 0x1F, + length_size_minus_one: 3, + sps_list: vec![], + pps_list: vec![vec![0x68]], + }; + assert!(header.to_bytes().is_err()); + } + + // ── FLV header byte construction ──────────────────────────────── + + #[test] + fn flv_video_header_keyframe_avc() { + // KeyFrame (1) << 4 | AVC (7) = 0x17. + let frame_type: u8 = 1; + let codec: u8 = 7; + assert_eq!((frame_type << 4) | codec, 0x17); + } + + #[test] + fn flv_video_header_interframe_avc() { + // InterFrame (2) << 4 | AVC (7) = 0x27. + let frame_type: u8 = 2; + let codec: u8 = 7; + assert_eq!((frame_type << 4) | codec, 0x27); + } + + #[test] + fn flv_audio_header_aac() { + // AAC (10) << 4 | 44kHz (3) << 2 | 16bit (1) << 1 | stereo (1) = 0xAF. + let format: u8 = 10; + let rate: u8 = 3; + let size: u8 = 1; + let channels: u8 = 1; + assert_eq!((format << 4) | (rate << 2) | (size << 1) | channels, 0xAF); + } + + // ── State machine ─────────────────────────────────────────────── + + #[test] + fn connection_starts_in_handshaking() { + let url = RtmpUrl::parse("rtmp://127.0.0.1/live/key").unwrap(); + let conn = RtmpPublishClientConnection::new(url); + assert_eq!(conn.state(), RtmpConnectionState::Handshaking); + } + + #[test] + fn connection_c0c1_in_send_buf() { + let url = RtmpUrl::parse("rtmp://127.0.0.1/live/key").unwrap(); + let conn = RtmpPublishClientConnection::new(url); + let buf = conn.send_buf(); + assert_eq!(buf.len(), 1 + HANDSHAKE_SIZE); + assert_eq!(buf[0], 0x03); + } + + #[test] + fn connection_advance_send_buf() { + let url = RtmpUrl::parse("rtmp://127.0.0.1/live/key").unwrap(); + let mut conn = RtmpPublishClientConnection::new(url); + let initial_len = conn.send_buf().len(); + conn.advance_send_buf(10); + 
assert_eq!(conn.send_buf().len(), initial_len - 10); + } + + #[test] + fn connection_handshake_transitions_to_connecting() { + let url = RtmpUrl::parse("rtmp://127.0.0.1/live/key").unwrap(); + let mut conn = RtmpPublishClientConnection::new(url); + + // Get C0+C1 from send buf. + let c0c1 = conn.send_buf().to_vec(); + + // Simulate S0+S1+S2. + let mut server = Vec::new(); + server.push(0x03); // S0 + server.extend_from_slice(&vec![0xCC; HANDSHAKE_SIZE]); // S1 + server.extend_from_slice(&c0c1[1..=HANDSHAKE_SIZE]); // S2 = echo C1 + + conn.feed_recv_buf(&server).unwrap(); + assert_eq!(conn.state(), RtmpConnectionState::Connecting); + + // Send buf should have: C0+C1 + C2 + WinAckSize + SetChunkSize + connect + assert!(conn.send_buf().len() > 1 + HANDSHAKE_SIZE); + } + + #[test] + fn connection_send_video_before_publishing_errors() { + let url = RtmpUrl::parse("rtmp://127.0.0.1/live/key").unwrap(); + let mut conn = RtmpPublishClientConnection::new(url); + let frame = VideoFrame { + timestamp: RtmpTimestamp::from_millis(0), + composition_timestamp_offset: RtmpTimestampDelta::ZERO, + frame_type: VideoFrameType::KeyFrame, + codec: VideoCodec::Avc, + avc_packet_type: Some(AvcPacketType::NalUnit), + data: vec![0; 10], + }; + assert!(conn.send_video(&frame).is_err()); + } + + #[test] + fn connection_display_impl() { + let state = RtmpConnectionState::Publishing; + assert_eq!(format!("{state}"), "Publishing"); + } + + // ── Encode/decode roundtrip ───────────────────────────────────── + + #[test] + fn encode_decode_roundtrip_various_messages() { + let messages = vec![ + OutboundMessage { + csid: 2, + timestamp: 0, + msg_type_id: MSG_WIN_ACK_SIZE, + stream_id: 0, + payload: 2_500_000u32.to_be_bytes().to_vec(), + }, + OutboundMessage { + csid: 3, + timestamp: 100, + msg_type_id: MSG_COMMAND_AMF0, + stream_id: 0, + payload: vec![0x02, 0x00, 0x07, b'c', b'o', b'n', b'n', b'e', b'c', b't'], + }, + OutboundMessage { + csid: 4, + timestamp: 1000, + msg_type_id: MSG_VIDEO, + 
stream_id: 1, + payload: vec![0x17, 0x00, 0x00, 0x00, 0x00, 0xAA, 0xBB], + }, + ]; + + for orig in &messages { + let mut enc = ChunkEncoder::new(); + let mut wire = Vec::new(); + enc.encode_message(orig, &mut wire); + + let mut dec = ChunkDecoder::new(); + dec.push(&wire); + let decoded = dec.decode_message().unwrap().unwrap(); + + assert_eq!( + decoded.timestamp, orig.timestamp, + "timestamp mismatch for csid={}", + orig.csid + ); + assert_eq!( + decoded.msg_type_id, orig.msg_type_id, + "type mismatch for csid={}", + orig.csid + ); + assert_eq!( + decoded.stream_id, orig.stream_id, + "stream_id mismatch for csid={}", + orig.csid + ); + assert_eq!(decoded.payload, orig.payload, "payload mismatch for csid={}", orig.csid); + } + } + + // ── Basic header encoding ─────────────────────────────────────── + + #[test] + fn basic_header_1byte_form() { + let mut out = Vec::new(); + encode_basic_header(0, 2, &mut out); + assert_eq!(out.len(), 1); + assert_eq!(out[0], 0x02); // fmt=0, csid=2 + } + + #[test] + fn basic_header_2byte_form() { + let mut out = Vec::new(); + encode_basic_header(0, 64, &mut out); + assert_eq!(out.len(), 2); + assert_eq!(out[0], 0x00); // fmt=0, csid=0 (2-byte marker) + assert_eq!(out[1], 0); // 64 - 64 = 0 + } + + #[test] + fn basic_header_3byte_form() { + let mut out = Vec::new(); + encode_basic_header(0, 320, &mut out); + assert_eq!(out.len(), 3); + assert_eq!(out[0], 0x01); // fmt=0, csid=1 (3-byte marker) + let val = u16::from(out[1]) + u16::from(out[2]) * 256 + 64; + assert_eq!(val, 320); + } + + // ── Full server simulation (YouTube-like flow) ──────────────── + + /// Helper: build an RTMP chunk from scratch using our encoder, simulating + /// a server sending a message. Returns the raw bytes ready to feed into + /// a client connection's `feed_recv_buf`. 
+ fn server_encode(
+ encoder: &mut ChunkEncoder,
+ csid: u16,
+ msg_type_id: u8,
+ stream_id: u32,
+ payload: Vec<u8>,
+ ) -> Vec<u8> {
+ let mut out = Vec::new();
+ encoder.encode_message(
+ &OutboundMessage { csid, timestamp: 0, msg_type_id, stream_id, payload },
+ &mut out,
+ );
+ out
+ }
+
+ /// Simulate the complete YouTube RTMP server flow from handshake
+ /// through to Publishing state. This catches regressions in the
+ /// state machine, AMF0 codec, and chunk encoder/decoder interop.
+ #[test]
+ fn full_youtube_server_simulation() {
+ let url = RtmpUrl::parse("rtmp://x.rtmp.youtube.com/live2/stream-key").unwrap();
+ let mut conn = RtmpPublishClientConnection::new(url);
+ assert_eq!(conn.state(), RtmpConnectionState::Handshaking);
+
+ // ── Step 1: client sends C0+C1 ──────────────────────────────
+ let c0c1 = conn.send_buf().to_vec();
+ assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);
+ conn.advance_send_buf(c0c1.len());
+
+ // ── Step 2: server sends S0+S1+S2 (no leftover bytes) ───────
+ let mut s0s1s2 = Vec::with_capacity(1 + HANDSHAKE_SIZE * 2);
+ s0s1s2.push(0x03); // S0
+ s0s1s2.extend_from_slice(&vec![0xBB; HANDSHAKE_SIZE]); // S1
+ s0s1s2.extend_from_slice(&c0c1[1..=HANDSHAKE_SIZE]); // S2 = echo C1
+
+ conn.feed_recv_buf(&s0s1s2).unwrap();
+ assert_eq!(conn.state(), RtmpConnectionState::Connecting);
+
+ // Send buf now has: C2 + WinAckSize + SetChunkSize + connect
+ assert!(conn.send_buf().len() > HANDSHAKE_SIZE);
+ conn.advance_send_buf(conn.send_buf().len()); // simulate flush
+
+ // ── Step 3: server sends WinAckSize + SetPeerBandwidth ──────
+ let mut srv_enc = ChunkEncoder::new();
+ let win_ack = server_encode(
+ &mut srv_enc,
+ 2,
+ MSG_WIN_ACK_SIZE,
+ 0,
+ 2_500_000u32.to_be_bytes().to_vec(),
+ );
+ let mut set_bw_payload = 59_768_832u32.to_be_bytes().to_vec();
+ set_bw_payload.push(2); // limit_type = Dynamic
+ let set_bw = server_encode(&mut srv_enc, 2, MSG_SET_PEER_BANDWIDTH, 0, set_bw_payload);
+
+ let mut server_msg = Vec::new();
+ 
server_msg.extend_from_slice(&win_ack); + server_msg.extend_from_slice(&set_bw); + conn.feed_recv_buf(&server_msg).unwrap(); + // Still Connecting — waiting for _result + assert_eq!(conn.state(), RtmpConnectionState::Connecting); + + // Client should have queued a WinAckSize response to SetPeerBandwidth + assert!(!conn.send_buf().is_empty()); + conn.advance_send_buf(conn.send_buf().len()); + + // ── Step 4: server sends connect _result ──────────────────── + let mut result_payload = Vec::new(); + amf0_encode(&Amf0Value::String("_result".to_string()), &mut result_payload).unwrap(); + amf0_encode(&Amf0Value::Number(1.0), &mut result_payload).unwrap(); + amf0_encode( + &Amf0Value::Object(vec![ + ("fmsVer".to_string(), Amf0Value::String("FMS/3,5,7,7009".to_string())), + ("capabilities".to_string(), Amf0Value::Number(31.0)), + ]), + &mut result_payload, + ) + .unwrap(); + amf0_encode( + &Amf0Value::Object(vec![ + ("level".to_string(), Amf0Value::String("status".to_string())), + ( + "code".to_string(), + Amf0Value::String("NetConnection.Connect.Success".to_string()), + ), + ("description".to_string(), Amf0Value::String("Connection succeeded".to_string())), + ("objectEncoding".to_string(), Amf0Value::Number(0.0)), + ]), + &mut result_payload, + ) + .unwrap(); + let result_msg = server_encode(&mut srv_enc, 3, MSG_COMMAND_AMF0, 0, result_payload); + + conn.feed_recv_buf(&result_msg).unwrap(); + // After _result → Connected → auto-sends createStream + assert_eq!(conn.state(), RtmpConnectionState::Connected); + conn.advance_send_buf(conn.send_buf().len()); + + // ── Step 5: server sends createStream _result ─────────────── + let mut cs_payload = Vec::new(); + amf0_encode(&Amf0Value::String("_result".to_string()), &mut cs_payload).unwrap(); + amf0_encode(&Amf0Value::Number(2.0), &mut cs_payload).unwrap(); + amf0_encode(&Amf0Value::Null, &mut cs_payload).unwrap(); + amf0_encode(&Amf0Value::Number(1.0), &mut cs_payload).unwrap(); // stream_id=1 + let cs_msg = server_encode(&mut 
srv_enc, 3, MSG_COMMAND_AMF0, 0, cs_payload); + + conn.feed_recv_buf(&cs_msg).unwrap(); + // MediaStreamCreated → auto-sends publish → PublishPending + assert_eq!(conn.state(), RtmpConnectionState::PublishPending); + assert_eq!(conn.media_stream_id, 1); + conn.advance_send_buf(conn.send_buf().len()); + + // ── Step 6: server sends onStatus(NetStream.Publish.Start) ── + let mut status_payload = Vec::new(); + amf0_encode(&Amf0Value::String("onStatus".to_string()), &mut status_payload).unwrap(); + amf0_encode(&Amf0Value::Number(0.0), &mut status_payload).unwrap(); + amf0_encode(&Amf0Value::Null, &mut status_payload).unwrap(); + amf0_encode( + &Amf0Value::Object(vec![ + ("level".to_string(), Amf0Value::String("status".to_string())), + ("code".to_string(), Amf0Value::String("NetStream.Publish.Start".to_string())), + ("description".to_string(), Amf0Value::String("Publishing stream-key".to_string())), + ]), + &mut status_payload, + ) + .unwrap(); + let status_msg = server_encode(&mut srv_enc, 4, MSG_COMMAND_AMF0, 1, status_payload); + + conn.feed_recv_buf(&status_msg).unwrap(); + assert_eq!(conn.state(), RtmpConnectionState::Publishing); + + // ── Step 7: verify we can send media ──────────────────────── + let video = VideoFrame { + timestamp: RtmpTimestamp::from_millis(0), + composition_timestamp_offset: RtmpTimestampDelta::ZERO, + frame_type: VideoFrameType::KeyFrame, + codec: VideoCodec::Avc, + avc_packet_type: Some(AvcPacketType::NalUnit), + data: vec![0x00, 0x00, 0x01, 0x67, 0x42], + }; + conn.send_video(&video).unwrap(); + assert!(!conn.send_buf().is_empty()); + } +} diff --git a/crates/nodes/src/video/openh264.rs b/crates/nodes/src/video/openh264.rs index 5b02fd5f..2b7bdd97 100644 --- a/crates/nodes/src/video/openh264.rs +++ b/crates/nodes/src/video/openh264.rs @@ -11,7 +11,9 @@ use async_trait::async_trait; use bytes::Bytes; -use openh264::encoder::{BitRate, EncoderConfig, FrameRate, FrameType, RateControlMode}; +use openh264::encoder::{ + BitRate, EncoderConfig, 
FrameRate, FrameType, IntraFramePeriod, RateControlMode, +}; use openh264::formats::YUVSlices; use schemars::JsonSchema; use serde::Deserialize; @@ -34,6 +36,7 @@ use super::H264_CONTENT_TYPE; const H264_DEFAULT_BITRATE_KBPS: u32 = 2000; const H264_DEFAULT_MAX_FRAME_RATE: f32 = 30.0; +const H264_DEFAULT_GOP_SIZE: u32 = 60; // --------------------------------------------------------------------------- // Configuration @@ -50,6 +53,16 @@ pub struct OpenH264EncoderConfig { pub bitrate_kbps: u32, /// Maximum frame rate in Hz. Must be greater than zero. pub max_frame_rate: f32, + /// GOP size: number of frames between IDR (keyframe) insertions. + /// + /// 0 = let the encoder decide (OpenH264 "auto" mode — may produce very + /// few keyframes). For RTMP streaming to platforms like YouTube Live or + /// Twitch, set this to `2 × max_frame_rate` (e.g. 60 for 30fps) to get + /// a keyframe every 2 seconds, which is within the 2–4 s range most CDNs + /// require. + /// + /// Defaults to 60 (≈ 2 s at 30 fps). 
+ pub gop_size: u32, } impl Default for OpenH264EncoderConfig { @@ -57,6 +70,7 @@ impl Default for OpenH264EncoderConfig { Self { bitrate_kbps: H264_DEFAULT_BITRATE_KBPS, max_frame_rate: H264_DEFAULT_MAX_FRAME_RATE, + gop_size: H264_DEFAULT_GOP_SIZE, } } } @@ -205,6 +219,7 @@ impl OpenH264Encoder { .bitrate(BitRate::from_bps(config.bitrate_kbps.saturating_mul(1000))) .max_frame_rate(FrameRate::from_hz(config.max_frame_rate)) .rate_control_mode(RateControlMode::Bitrate) + .intra_frame_period(IntraFramePeriod::from_num_frames(config.gop_size)) .skip_frames(false); let encoder = openh264::encoder::Encoder::with_api_config( @@ -408,7 +423,11 @@ mod tests { enc_inputs.insert("in".to_string(), enc_input_rx); let (enc_context, enc_sender, mut enc_state_rx) = create_test_context(enc_inputs, 10); - let encoder_config = OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: 30.0 }; + let encoder_config = OpenH264EncoderConfig { + bitrate_kbps: 2000, + max_frame_rate: 30.0, + ..Default::default() + }; let encoder = OpenH264EncoderNode::new(encoder_config).unwrap(); let enc_handle = tokio::spawn(async move { Box::new(encoder).run(enc_context).await }); @@ -539,6 +558,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 0, max_frame_rate: 30.0, + ..Default::default() }); assert!(result.is_err(), "bitrate_kbps=0 should be rejected"); } @@ -548,6 +568,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: -1.0, + ..Default::default() }); assert!(result.is_err(), "negative max_frame_rate should be rejected"); } @@ -557,6 +578,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: 0.0, + ..Default::default() }); assert!(result.is_err(), "zero max_frame_rate should be rejected"); } @@ -566,6 +588,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: f32::NAN, + 
..Default::default() }); assert!(result.is_err(), "NaN max_frame_rate should be rejected"); } @@ -575,12 +598,14 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: f32::INFINITY, + ..Default::default() }); assert!(result.is_err(), "INFINITY max_frame_rate should be rejected"); let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 2000, max_frame_rate: f32::NEG_INFINITY, + ..Default::default() }); assert!(result.is_err(), "NEG_INFINITY max_frame_rate should be rejected"); } @@ -590,6 +615,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 500_001, max_frame_rate: 30.0, + ..Default::default() }); assert!(result.is_err(), "bitrate_kbps above 500_000 should be rejected"); @@ -597,6 +623,7 @@ mod tests { let result = OpenH264EncoderNode::new(OpenH264EncoderConfig { bitrate_kbps: 500_000, max_frame_rate: 30.0, + ..Default::default() }); assert!(result.is_ok(), "bitrate_kbps=500_000 should be accepted"); } diff --git a/samples/pipelines/dynamic/moq_to_rtmp_composite.yml b/samples/pipelines/dynamic/moq_to_rtmp_composite.yml new file mode 100644 index 00000000..dac56f2e --- /dev/null +++ b/samples/pipelines/dynamic/moq_to_rtmp_composite.yml @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# +# SPDX-License-Identifier: MPL-2.0 + +# MoQ-to-RTMP compositing pipeline. +# +# Receives audio and video from a WebTransport publisher via MoQ, +# composites the video (main + PiP + logo overlay), re-encodes to +# H.264 (OpenH264) + AAC, and publishes the result to an RTMP endpoint +# such as YouTube Live or Twitch. +# +# Requires: +# - aac-encoder native plugin (just install-plugin aac-encoder) +# - Set the SKIT_RTMP_STREAM_KEY env var before starting the server, +# or replace stream_key_env with a literal stream_key value. 
+ +name: MoQ to RTMP (Composited) +description: | + Receives audio+video via MoQ peer, composites with PiP and logo overlay, + re-encodes to H.264+AAC, and publishes to an RTMP endpoint (e.g. YouTube Live). +mode: dynamic +client: + gateway_path: /moq/rtmp-out + publish: + broadcast: input + tracks: + - kind: audio + source: microphone + - kind: video + source: camera + watch: + broadcast: monitor + audio: false + video: true + +nodes: + # ── MoQ input ────────────────────────────────────────────────────── + moq_peer: + kind: transport::moq::peer + params: + gateway_path: /moq/rtmp-out + input_broadcasts: + - input + output_broadcast: monitor + allow_reconnect: true + needs: + - vp9_monitor + + # ── Audio path: Opus → PCM → AAC ────────────────────────────────── + opus_decoder: + kind: audio::opus::decoder + needs: + in: moq_peer.audio/data + + aac_encoder: + kind: plugin::native::aac_encoder + params: + bitrate: 128000 + needs: opus_decoder + + # ── Video path: decode → composite → encode ──────────────────────── + vp9_decoder: + kind: video::vp9::decoder + needs: + # @moq/publish uses "video/hd" as the track name (not "video/data") + in: moq_peer.video/hd + + colorbars_pip: + kind: video::colorbars + params: + width: 1280 + height: 720 + fps: 30 + pixel_format: rgba8 + draw_time: true + + compositor: + kind: video::compositor + params: + width: 1280 + height: 720 + num_inputs: 2 + layers: + in_0: + rect: + x: 920 + y: 20 + width: 320 + height: 240 + aspect_fit: true + opacity: 1.0 + z_index: 1 + mirror_horizontal: true + mirror_vertical: false + crop_zoom: 1.8 + crop_x: 0.5 + crop_y: 0.4 + crop_shape: circle + in_1: + rect: + x: 0 + y: 0 + width: 1280 + height: 720 + opacity: 1 + z_index: 0 + image_overlays: + - id: logo + asset_path: samples/images/system/streamkit-logo.png + rect: + x: 1190 + y: 630 + width: 46 + height: 80 + opacity: 0.7 + z_index: 3 + needs: + - vp9_decoder + - colorbars_pip + + pixel_convert: + kind: video::pixel_convert + params: + 
output_format: nv12 + needs: compositor + + h264_encoder: + kind: video::openh264::encoder + params: + bitrate_kbps: 2500 + max_frame_rate: 30.0 + gop_size: 60 # keyframe every 2s at 30 fps — explicit for RTMP clarity + needs: pixel_convert + + # ── Monitor output (view composited result via MoQ) ──────────────── + vp9_monitor: + kind: video::vp9::encoder + needs: pixel_convert + + # ── RTMP output ──────────────────────────────────────────────────── + rtmp_publish: + kind: transport::rtmp::publish + params: + url: "rtmps://x.rtmps.youtube.com/live2" + # Twitch: + # url: "rtmps://ingest.global-contribute.live-video.net/app" + # + # Read the stream key from the SKIT_RTMP_STREAM_KEY environment variable. + # Alternatively, use `stream_key: "your-key-here"` for a literal value. + # The env var name is fully user-controlled — use any name you like + # (e.g. SKIT_TWITCH_KEY, SKIT_YT_KEY) to support multiple RTMP outputs. + stream_key_env: "SKIT_RTMP_STREAM_KEY" + needs: + video: h264_encoder + audio: aac_encoder