Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 147 additions & 18 deletions crates/plugin-native/src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ use streamkit_core::control::NodeControlMessage;
use streamkit_core::telemetry::TelemetryEvent;
use streamkit_core::types::Packet;
use streamkit_core::{
InputPin, NodeContext, NodeState, NodeStateUpdate, OutputPin, ProcessorNode, StopReason,
StreamKitError,
AudioFramePool, InputPin, NodeContext, NodeState, NodeStateUpdate, OutputPin, ProcessorNode,
StopReason, StreamKitError, VideoFramePool,
};
use streamkit_plugin_sdk_native::{
conversions,
types::{CNativePluginAPI, CPacket, CPluginHandle, CResult},
types::{
CAllocAudioResult, CAllocVideoResult, CNativePluginAPI, CNodeCallbacks, CPacket,
CPluginHandle, CResult,
},
};
use tracing::{error, info, warn};

Expand Down Expand Up @@ -282,6 +285,8 @@ impl NativeNodeWrapper {
let (merged_tx, mut merged_rx) =
tokio::sync::mpsc::channel::<(usize, Packet)>(context.batch_size.max(1));
let cancellation_token = context.cancellation_token.clone();
let video_pool = context.video_pool.clone();
let audio_pool = context.audio_pool.clone();

for (pin_name, mut rx) in inputs.drain() {
let pin_cstr = CString::new(pin_name.as_str()).map_err(|e| {
Expand Down Expand Up @@ -432,18 +437,18 @@ impl NativeNodeWrapper {
telemetry_tx,
session_id,
node_id,
video_pool,
audio_pool,
};

let callback_data = (&raw mut callback_ctx).cast::<c_void>();
let node_callbacks = build_node_callbacks(callback_data);

// Call plugin's flush function
tracing::info!("Calling api.flush()");
let result = (api.flush)(
handle,
output_callback_shim,
callback_data,
Some(telemetry_callback_shim),
callback_data,
&raw const node_callbacks,
);
tracing::info!(success = result.success, "Flush returned");

Expand Down Expand Up @@ -488,6 +493,8 @@ impl NativeNodeWrapper {
let session_id = context.session_id.clone();
let node_id = node_name.clone();
let pin_cstr = Arc::clone(&input_pin_cstrs[pin_index]);
let video_pool = video_pool.clone();
let audio_pool = audio_pool.clone();
let (outputs, error) = tokio::task::spawn_blocking(move || {
let Some(handle) = state.begin_call() else {
return (Vec::new(), None);
Expand All @@ -505,19 +512,19 @@ impl NativeNodeWrapper {
telemetry_tx,
session_id,
node_id,
video_pool,
audio_pool,
};

let callback_data = (&raw mut callback_ctx).cast::<c_void>();
let node_callbacks = build_node_callbacks(callback_data);

// Call plugin's process function (BLOCKING - but we're in spawn_blocking)
let result = (api.process_packet)(
handle,
pin_cstr.as_ptr(),
&raw const packet_repr.packet,
output_callback_shim,
callback_data,
Some(telemetry_callback_shim),
callback_data,
&raw const node_callbacks,
);

// Check for errors
Expand Down Expand Up @@ -612,6 +619,8 @@ impl NativeNodeWrapper {
}

let node_name = context.output_sender.node_name().to_string();
let video_pool = context.video_pool.clone();
let audio_pool = context.audio_pool.clone();

tracing::info!(node = %node_name, "Native source plugin wrapper starting");

Expand Down Expand Up @@ -845,6 +854,8 @@ impl NativeNodeWrapper {
let telemetry_tx = context.telemetry_tx.clone();
let session_id = context.session_id.clone();
let node_id = node_name.clone();
let video_pool = video_pool.clone();
let audio_pool = audio_pool.clone();
let outcome = tokio::task::spawn_blocking(move || {
let Some(handle) = state.begin_call() else {
return TickOutcome {
Expand All @@ -863,17 +874,14 @@ impl NativeNodeWrapper {
telemetry_tx,
session_id,
node_id,
video_pool,
audio_pool,
};

let callback_data = (&raw mut callback_ctx).cast::<c_void>();
let node_callbacks = build_node_callbacks(callback_data);

let result = tick_fn(
handle,
output_callback_shim,
callback_data,
Some(telemetry_callback_shim),
callback_data,
);
let result = tick_fn(handle, &raw const node_callbacks);

// Extract error string while pointers are still valid.
let error_msg = if result.result.success {
Expand Down Expand Up @@ -1032,6 +1040,43 @@ struct CallbackContext {
telemetry_tx: Option<tokio::sync::mpsc::Sender<TelemetryEvent>>,
session_id: Option<String>,
node_id: String,
video_pool: Option<Arc<VideoFramePool>>,
audio_pool: Option<Arc<AudioFramePool>>,
}

/// Free any pool-allocated `buffer_handle` embedded in a raw [`CPacket`].
///
/// This is the safety net for error paths in [`output_callback_shim`]: if
/// `packet_from_c` is never called (e.g. invalid pin name) or if it fails
/// before reclaiming the handle, the pooled buffer would leak because the
/// SDK already marked it as consumed (suppressing `Drop`).
///
/// # Safety
///
/// `c_packet` must be a valid, non-null pointer to a [`CPacket`].
unsafe fn free_packet_buffer_handle(c_packet: *const CPacket) {
use streamkit_core::frame_pool::{PooledSamples, PooledVideoData};
use streamkit_plugin_sdk_native::types::CPacketType;

let pkt = &*c_packet;
if pkt.data.is_null() {
return;
}
match pkt.packet_type {
CPacketType::RawVideo => {
let frame = &*pkt.data.cast::<streamkit_plugin_sdk_native::types::CVideoFrame>();
if !frame.buffer_handle.is_null() {
drop(Box::from_raw(frame.buffer_handle.cast::<PooledVideoData>()));
}
},
CPacketType::RawAudio => {
let frame = &*pkt.data.cast::<streamkit_plugin_sdk_native::types::CAudioFrame>();
if !frame.buffer_handle.is_null() {
drop(Box::from_raw(frame.buffer_handle.cast::<PooledSamples>()));
}
},
_ => {},
}
}

/// C callback function for sending output packets
Expand All @@ -1054,6 +1099,8 @@ extern "C" fn output_callback_shim(
Ok(s) => s,
Err(e) => {
ctx.error = Some(format!("Invalid pin name: {e}"));
// Free any pooled buffer the plugin already consumed.
unsafe { free_packet_buffer_handle(c_packet) };
return CResult::error(std::ptr::null());
},
};
Expand All @@ -1062,6 +1109,8 @@ extern "C" fn output_callback_shim(
let packet = match unsafe { conversions::packet_from_c(c_packet) } {
Ok(p) => p,
Err(e) => {
// packet_from_c already frees the buffer_handle on its own error
// paths (Critical #1), so no extra cleanup needed here.
ctx.error = Some(format!("Failed to convert packet: {e}"));
return CResult::error(std::ptr::null());
},
Expand Down Expand Up @@ -1151,3 +1200,83 @@ extern "C" fn telemetry_callback_shim(

CResult::success()
}

// ── Frame pool allocation shims (v6) ─────────────────────────────────────

/// Allocate a video buffer from the host's frame pool.
extern "C" fn alloc_video_shim(min_bytes: usize, user_data: *mut c_void) -> CAllocVideoResult {
use streamkit_core::frame_pool::PooledVideoData;

if user_data.is_null() {
return CAllocVideoResult::null();
}

let ctx = unsafe { &*user_data.cast::<CallbackContext>() };
let Some(pool) = ctx.video_pool.as_ref() else {
return CAllocVideoResult::null();
};

let mut pooled: PooledVideoData = pool.get(min_bytes);
let data_ptr = pooled.as_mut_ptr();
let len = pooled.len();
let handle = Box::into_raw(Box::new(pooled)).cast::<c_void>();

CAllocVideoResult { data: data_ptr, len, handle, free_fn: Some(free_video_buffer) }
}

/// Free a video buffer without sending it (error/discard path).
extern "C" fn free_video_buffer(handle: *mut c_void) {
use streamkit_core::frame_pool::PooledVideoData;

if !handle.is_null() {
// SAFETY: handle was created by alloc_video_shim via Box::into_raw.
let _ = unsafe { Box::from_raw(handle.cast::<PooledVideoData>()) };
}
}

/// Allocate an audio buffer from the host's frame pool.
extern "C" fn alloc_audio_shim(min_samples: usize, user_data: *mut c_void) -> CAllocAudioResult {
use streamkit_core::frame_pool::PooledSamples;

if user_data.is_null() {
return CAllocAudioResult::null();
}

let ctx = unsafe { &*user_data.cast::<CallbackContext>() };
let Some(pool) = ctx.audio_pool.as_ref() else {
return CAllocAudioResult::null();
};

let mut pooled: PooledSamples = pool.get(min_samples);
let data_ptr = pooled.as_mut_ptr();
let sample_count = pooled.len();
let handle = Box::into_raw(Box::new(pooled)).cast::<c_void>();

CAllocAudioResult { data: data_ptr, sample_count, handle, free_fn: Some(free_audio_buffer) }
}

/// Free an audio buffer without sending it (error/discard path).
extern "C" fn free_audio_buffer(handle: *mut c_void) {
use streamkit_core::frame_pool::PooledSamples;

if !handle.is_null() {
let _ = unsafe { Box::from_raw(handle.cast::<PooledSamples>()) };
}
}

/// Build a `CNodeCallbacks` struct from a `CallbackContext` pointer.
///
/// The returned struct borrows `callback_data` — it must not outlive the
/// `CallbackContext`.
fn build_node_callbacks(callback_data: *mut c_void) -> CNodeCallbacks {
CNodeCallbacks {
struct_size: std::mem::size_of::<CNodeCallbacks>(),
output_callback: output_callback_shim,
output_user_data: callback_data,
telemetry_callback: Some(telemetry_callback_shim),
telemetry_user_data: callback_data,
alloc_video: Some(alloc_video_shim),
alloc_audio: Some(alloc_audio_shim),
alloc_user_data: callback_data,
}
}
32 changes: 16 additions & 16 deletions examples/plugins/gain-native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading