diff --git a/benchmark/logic/parametrized_select.js b/benchmark/logic/parametrized_select.js index caef6876c..52cb4e4f8 100644 --- a/benchmark/logic/parametrized_select.js +++ b/benchmark/logic/parametrized_select.js @@ -15,7 +15,7 @@ function selectWithRows(number) { utils.prepareDatabase(client, utils.tableSchemaBasic, next); }, async function insert(next) { - utils.insertSimple(client, 10, next); + utils.insertSimple(client, number, next); }, async function query(next) { await utils.queryWithRowCheck(client, number, iterCnt, next); diff --git a/lib/client.js b/lib/client.js index dfff6fc8c..dacf6b324 100644 --- a/lib/client.js +++ b/lib/client.js @@ -29,6 +29,11 @@ const { HostMap } = require("./host.js"); // eslint-disable-next-line no-unused-vars const { QueryOptions } = require("./query-options.js"); +// Initialize the direct-poll bridge once per process. +// This sets up the Tokio reactor thread and the wake mechanism used by all +// bridged async Rust functions (session queries, paging, etc.). +rust.initPollBridge(); + /** * Represents a database client that maintains multiple connections to the cluster nodes, providing methods to * execute CQL statements. diff --git a/src/casync.rs b/src/casync.rs new file mode 100644 index 000000000..7f9292a3a --- /dev/null +++ b/src/casync.rs @@ -0,0 +1,388 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::ffi::CString; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Wake, Waker}; + +// While check_status macro is doc(hidden), it implements a simple checks that convert c errors into Rust Results +// Implementation: https://github.com/napi-rs/napi-rs/blob/f2178312d0e3e07beecc19836b91716a229107d3/crates/napi/src/error.rs#L35 +use napi::bindgen_prelude::{ToNapiValue, check_status}; +use napi::threadsafe_function::ThreadsafeFunctionCallMode; +use napi::{Env, Error, Result, Status, sys}; +use napi_derive::napi; + +use crate::errors::{ConvertedError, ConvertedResult, JsResult, with_custom_error_sync}; +use crate::napi_helpers::{DeferredPtr, ResolveOrReject}; + +/// JsPromise — lightweight wrapper over the promise pointer that indicates the type used to resolve the promise +/// The promise can be either resolved with type T or rejected with any error value (`ConvertedError` when used with `submit_future`). +pub struct JsPromise(sys::napi_value, PhantomData); + +impl ToNapiValue for JsPromise { + /// # Safety + /// No constrains on safety. The unsafe is required by the trait. + unsafe fn to_napi_value(_: sys::napi_env, val: Self) -> Result { + Ok(val.0) + } +} + +type SettleCallback = Box; +type BridgedFuture = Pin + Send>>; + +struct FutureEntry { + future: BridgedFuture, + /// Raw deferred handle — resolved/rejected in `poll_woken` on the + /// main thread where we have a valid `napi_env`. + deferred: DeferredPtr, + waker: Waker, +} + +/// No argument no return value, weak ThreadSafeFunction type. +type Tsfn = napi::threadsafe_function::ThreadsafeFunction<(), (), (), Status, false, true>; + +/// Single Thread safe function, coalesced wake signals +struct WakerBridge { + woken_ids: Arc>>, + signaled: Arc, + /// The TSFN lives here (behind a Mutex) so it's reachable from any + /// thread — including the Tokio worker thread that fires wakers. + tsfn: Mutex>, +} + +impl WakerBridge { + fn new() -> Self { + Self { + woken_ids: Arc::new(Mutex::new(Vec::new())), + signaled: Arc::new(AtomicBool::new(false)), + tsfn: Mutex::new(None), + } + } + + /// Set the TSFN after creation (called once from `init_poll_bridge`). + fn set_tsfn(&self, tsfn: Tsfn) { + *self.tsfn.lock().unwrap() = Some(tsfn); + } + + /// Signal the TSFN if not already signaled. + fn signal(&self) { + if !self.signaled.swap(true, Ordering::AcqRel) { + let guard = self.tsfn.lock().unwrap(); + if let Some(ref tsfn) = *guard { + tsfn.call((), ThreadsafeFunctionCallMode::NonBlocking); + } // Else branches can happen only during shutdown + } + } + + /// Called from any thread by a Waker. + fn wake(&self, future_id: u64) { + let mut ids = self.woken_ids.lock().unwrap(); + ids.push(future_id); + self.signal(); + } +} + +/// Per-future waker internals +struct WakerInner { + future_id: u64, + bridge: Arc, +} + +impl Wake for WakerInner { + fn wake(self: Arc) { + self.bridge.wake(self.future_id); + } + + fn wake_by_ref(self: &Arc) { + self.bridge.wake(self.future_id); + } +} + +/// FutureRegistry — thread-local, lives on the Node main thread +struct FutureRegistry { + futures: HashMap, + next_id: u64, + bridge: Arc, + tokio_rt: Option, +} + +static INITIALIZED: AtomicBool = AtomicBool::new(false); + +impl FutureRegistry { + fn new() -> Self { + let bridge = Arc::new(WakerBridge::new()); + Self { + futures: HashMap::new(), + next_id: 0, + bridge, + tokio_rt: None, + } + } + + fn insert(&mut self, env: &Env, future: BridgedFuture, deferred: DeferredPtr) -> Result { + let was_empty = self.futures.is_empty(); + + let id = self.next_id; + self.next_id += 1; + + let waker = Waker::from(Arc::new(WakerInner { + future_id: id, + bridge: Arc::clone(&self.bridge), + })); + + self.futures.insert( + id, + FutureEntry { + future, + deferred, + waker, + }, + ); + + // If this is the first outstanding future, ref the TSFN so Node + // keeps its event loop alive until all futures have settled. + if was_empty { + let guard = self.bridge.tsfn.lock().unwrap(); + if let Some(ref tsfn) = *guard { + // SAFETY: Env guarantees a valid `napi_env` for the current call. + unsafe { check_status!(sys::napi_ref_threadsafe_function(env.raw(), tsfn.raw()))? }; + } // Else branches can happen only during shutdown + } + + // Schedule the mandatory first poll. + self.bridge.wake(id); + + Ok(id) + } + + /// Called on the main thread when the TSFN fires. + /// `raw_env` is valid only for this invocation (from the TSFN callback). + fn poll_woken(&mut self, env: Env) { + self.bridge.signaled.store(false, Ordering::Release); + + let woken: Vec = { + let mut ids = self.bridge.woken_ids.lock().unwrap(); + std::mem::take(&mut *ids) + }; + + // Take-and-process: remove entries before polling so that a polled + // future can register *new* futures without hitting RefCell deadlock. + let entries: Vec<(u64, FutureEntry)> = woken + .iter() + .filter_map(|&id| self.futures.remove(&id).map(|e| (id, e))) + .collect(); + + // Enter the Tokio runtime context so tokio::net, tokio::time, etc. + // register with the reactor when polled. + let _guard = self.tokio_rt.as_ref().map(|rt| rt.enter()); + + for (id, mut entry) in entries { + let mut cx = Context::from_waker(&entry.waker); + match entry.future.as_mut().poll(&mut cx) { + Poll::Ready(settle_fn) => { + settle_fn(env, entry.deferred); + } + Poll::Pending => { + self.futures.insert(id, entry); + } + } + } + + // If every future has settled, unref the TSFN so Node can exit + // naturally. The check happens *after* all polls so that a future + // completing synchronously and submitting a new future in its settle + // callback won't cause a premature unref. + if self.futures.is_empty() { + let guard = self.bridge.tsfn.lock().unwrap(); + if let Some(ref tsfn) = *guard { + // SAFETY: Env guarantees a valid `napi_env` for the current call. + // `tsfn.raw()` is live because we hold the Mutex lock. + let status = unsafe { + check_status!(sys::napi_unref_threadsafe_function(env.raw(), tsfn.raw())) + }; + if let Err(e) = status { + // We should fail here only in extreme cases (e.g. TSFN already unrefed, env invalid, etc.) — panic is warranted. + panic!( + "Failed to unref TSFN in poll_woken. This may indicate either a bug in the driver or a severe runtime error.\nRoot cause:\n {}", + e.reason + ); + } + } + } + } + + // This function is registered in the startup to be called during node cleanup process. + fn shutdown(&mut self) { + self.futures.clear(); + *self.bridge.tsfn.lock().unwrap() = None; + if let Some(rt) = self.tokio_rt.take() { + rt.shutdown_background(); + } + } +} + +thread_local! { + static REGISTRY: RefCell = RefCell::new(FutureRegistry::new()); +} + +fn create_promise(env: &Env) -> Result<(DeferredPtr, sys::napi_value)> { + let mut deferred = ptr::null_mut(); + let mut promise = ptr::null_mut(); + // SAFETY: `raw_env` is taken from Env, which is guaranteed to be valid for the lifetime of the current napi call. + unsafe { + check_status!(sys::napi_create_promise( + env.raw(), + &mut deferred, + &mut promise + ))? + }; + // SAFETY: deferred is assigned to valid value in napi_create_promise call, that have just succeeded. + // This promise had no chance to be resolved yet. + let deferred_ptr = unsafe { DeferredPtr::new(deferred) }; + Ok((deferred_ptr, promise)) +} + +fn reject_with_reason(env: Env, deferred: DeferredPtr, reason: &str) -> Result<()> { + // We can unwrap in the second place, because the only case when Cstring::new can fail is when the string contains a null byte. + let c_reason = CString::new(reason).unwrap_or_else(|_| { + CString::new("[Unknown error] Error message contained illegal null byte").unwrap() + }); + let mut msg: sys::napi_value = std::ptr::null_mut(); + let mut error: sys::napi_value = std::ptr::null_mut(); + + // SAFETY: Env guarantees that raw pointer is a valid main-thread env. + // Remaining arguments are created in this function and are valid for the whole duration. + unsafe { + check_status!(sys::napi_create_string_utf8( + env.raw(), + c_reason.as_ptr(), + c_reason.to_bytes().len() as isize, + &mut msg, + ))?; + check_status!(sys::napi_create_error( + env.raw(), + ptr::null_mut(), + msg, + &mut error + ))?; + deferred.resolve(env, error, ResolveOrReject::Reject)?; + } + Ok(()) +} + +#[napi(no_export)] +fn noop_callback() { + // No-op callback for creating the ThreadsafeFunction. +} + +/// Initialize the direct-poll bridge. Must be called once before any +/// bridged async function. This function must be called only once. +/// +/// Creates a dedicated `multi_thread(1)` Tokio runtime whose single worker +/// thread drives the reactor (epoll/kqueue). A single weak TSFN is used +/// as the cross-thread wake mechanism — ABI-stable, cross-platform, no +/// direct libuv dependency. +#[napi] +pub fn init_poll_bridge(env: Env) -> JsResult<()> { + with_custom_error_sync(|| { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?; + + // Create the TSFN from any c callback. This callback will be replaced in the build_callback step, + // but we still need to provide c function, to use napi-rs callback builder. + // We could do this directly through node-api interface, but here napi-rs simplifies this process. + // We also have to use callback witch matching type, to ensure everything runs correctly. + let noop_fn = env.create_function::<(), ()>("pollBridgeNoop", noop_callback_c_callback)?; + + let tsfn = noop_fn + .build_threadsafe_function::<()>() + // We will manually ref/unref this tsfn based on whether we have outstanding futures. + .weak::() + .build_callback(|ctx| { + let raw_env = ctx.env; + REGISTRY.with(|r| { + r.borrow_mut().poll_woken(raw_env); + }); + Ok(()) + })?; + + REGISTRY.with(|r| { + let mut reg = r.borrow_mut(); + reg.tokio_rt = Some(rt); + reg.bridge.set_tsfn(tsfn); + }); + + // Cleanup hook — shut down the runtime when Node exits. + env.add_env_cleanup_hook((), |_| { + REGISTRY.with(|r| { + r.borrow_mut().shutdown(); + }); + })?; + + if INITIALIZED.swap(true, Ordering::SeqCst) { + return Err(Error::from_reason( + "init_poll_bridge can only be called once", + )); + } + + Ok(()) + }) +} + +/// Submit a typed Rust future to be polled directly by the Node event loop. +/// +/// Future can return a typed value `T` on success +/// or a `ConvertedError` on failure. Both `T` and `ConvertedError` are converted to JS values via +/// `ToNapiValue` on the main thread when the future settles. +pub fn submit_future(env: &Env, fut: F) -> ConvertedResult> +where + F: Future> + Send + 'static, + T: napi::bindgen_prelude::ToNapiValue + Send + 'static, +{ + // This is a driver error, so panic is warranted here. There is no reasonable way to recover. + assert!( + INITIALIZED.load(Ordering::Relaxed), + "init_poll_bridge must be called before submit_future. This is a bug in the driver." + ); + + let (deferred, promise) = create_promise(env)?; + + let boxed: BridgedFuture = Box::pin(async move { + let result = fut.await; + Box::new(move |env: Env, deferred: DeferredPtr| unsafe { + // SAFETY: This closure is only ever invoked from `poll_woken`, which runs + // on the Node main thread inside the TSFN callback - the only place where + // `env` is a valid napi_env. `deferred` is consumed exactly once here, + // satisfying the napi contract that each deferred is resolved or rejected + // exactly once. `to_napi_value` receives the same valid `env`. + let (js_val, resolve) = match result { + Ok(val) => (T::to_napi_value(env.raw(), val), ResolveOrReject::Resolve), + Err(err) => ( + ConvertedError::to_napi_value(env.raw(), err), + ResolveOrReject::Reject, + ), + }; + + let status = match js_val { + Ok(v) => deferred.resolve(env, v, resolve), + Err(e) => reject_with_reason(env, deferred, &e.reason), + }; + + if let Err(e) = status { + panic!( + "Failed to settle promise in TSFN callback. This may indicate either a bug in the driver or a severe runtime error.\nRoot cause:\n {}", + e.reason + ); + } + }) as SettleCallback + }); + + REGISTRY.with(|r| r.borrow_mut().insert(env, boxed, deferred))?; + Ok(JsPromise(promise, PhantomData)) +} diff --git a/src/errors.rs b/src/errors.rs index ef332f261..e30d83bdb 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -130,6 +130,12 @@ where } } +impl std::fmt::Display for ConvertedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.name, self.msg) + } +} + impl ToNapiValue for ConvertedError { /// # Safety /// @@ -147,19 +153,6 @@ impl ToNapiValue for ConvertedError { } } -/// Allows to run a block of code that returns Result, -/// with automatic conversion to JsResult. This allows to use the `?` operator, -/// while still returning JsResult from the function. -/// Version for async functions -pub(crate) async fn with_custom_error_async(code: C) -> JsResult -where - C: AsyncFnOnce() -> In, - In: IntoConvertedResult, -{ - let c = code().await; - c.into_converted_result().into() -} - /// Allows to run a block of code that returns Result, /// with automatic conversion to JsResult. This allows to use the `?` operator, /// while still returning JsResult from the function. diff --git a/src/lib.rs b/src/lib.rs index b178c3f8f..409662b65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,10 @@ extern crate napi_derive; // Link other files +pub mod casync; pub mod errors; pub mod metadata; +pub mod napi_helpers; pub mod options; pub mod paging; pub mod requests; diff --git a/src/napi_helpers.rs b/src/napi_helpers.rs new file mode 100644 index 000000000..5bce47889 --- /dev/null +++ b/src/napi_helpers.rs @@ -0,0 +1,50 @@ +use std::{marker::PhantomData, rc::Rc}; + +use napi::{Env, Result, bindgen_prelude::check_status, sys}; + +/// Wrapper over napi_deferred pointer, that ensures safe usage of the pointer and prevents double resolve/reject. +pub(crate) struct DeferredPtr { + ptr: sys::napi_deferred, + // We want to block DeferredPtr from being Send or Sync, + // as we can use napi_deferred pointer can be used only in the main nodejs thread. + _not_send_sync: PhantomData>, +} + +pub(crate) enum ResolveOrReject { + Resolve, + Reject, +} + +impl DeferredPtr { + /// # Safety + /// The pointer must not have been resolved or rejected yet, and must point to a valid napi_deferred. + pub(crate) unsafe fn new(ptr: sys::napi_deferred) -> Self { + Self { + ptr, + _not_send_sync: PhantomData, + } + } + + /// # Safety + /// Valid pointer to value must be provided + pub(crate) unsafe fn resolve( + self, + env: Env, + value: sys::napi_value, + mode: ResolveOrReject, + ) -> Result<()> { + // We can use the napi_deferred only once, as per napi documentation, + // any calls to resolve it, will free the value: https://nodejs.org/api/n-api.html#promises + // While there is no specification what happens if the call fails, it's safer to assume + // the pointer is no longer valid, and we are in non-recoverable state. + + // SAFETY: Constraints of this class ensure validity of the deref pointer, + // and Env ensures validity of the napi_env. + // Caller ensures validity of the value pointer. + if let ResolveOrReject::Resolve = mode { + unsafe { check_status!(sys::napi_resolve_deferred(env.raw(), self.ptr, value)) } + } else { + unsafe { check_status!(sys::napi_reject_deferred(env.raw(), self.ptr, value)) } + } + } +} diff --git a/src/session.rs b/src/session.rs index 190dca995..d1451d2c3 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,23 +1,22 @@ pub mod config; use std::sync::Arc; -use config::SessionOptions; -use scylla::client::caching_session::CachingSession; -use scylla::response::{PagingState, PagingStateResponse}; -use scylla::statement::batch::Batch; -use scylla::statement::{Consistency, SerialConsistency, Statement}; - +use crate::casync::{JsPromise, submit_future}; use crate::errors::{ - ConvertedError, ConvertedResult, JsResult, make_js_error, with_custom_error_async, - with_custom_error_sync, + ConvertedError, ConvertedResult, JsResult, make_js_error, with_custom_error_sync, }; use crate::paging::{PagingResult, PagingResultWithExecutor, PagingStateWrapper}; use crate::requests::request::{QueryOptionsObj, QueryOptionsWrapper}; -use crate::session::config::configure_session_builder; +use crate::session::config::{SessionOptions, configure_session_builder}; use crate::types::encoded_data::EncodedValuesWrapper; use crate::types::type_wrappers::ComplexType; use crate::utils::bigint_to_i64; use crate::{requests::request::PreparedStatementWrapper, result::QueryResultWrapper}; +use napi::bindgen_prelude::Env; +use scylla::client::caching_session::CachingSession; +use scylla::response::{PagingState, PagingStateResponse}; +use scylla::statement::batch::Batch; +use scylla::statement::{Consistency, SerialConsistency, Statement}; const DEFAULT_CACHE_SIZE: u32 = 512; @@ -28,7 +27,7 @@ pub struct BatchWrapper { #[napi] pub struct SessionWrapper { - pub(crate) inner: CachingSession, + pub(crate) inner: Arc, } /// This object allows executing queries for following pages of the result, @@ -55,58 +54,52 @@ impl QueryExecutor { } } -impl QueryExecutor { - async fn fetch_next_page_internal( - &self, - session: &SessionWrapper, - paging_state: Option<&PagingStateWrapper>, - ) -> ConvertedResult { - let paging_state = paging_state - .map(|e| e.inner.clone()) - .unwrap_or(PagingState::start()); - - let (result, paging_state_response) = if self.is_prepared { - session - .inner - .execute_single_page( - Statement::clone(self.statement.as_ref()), - self.params.as_ref(), - paging_state, - ) - .await - } else { - session - .inner - .get_session() - .query_single_page( - Statement::clone(self.statement.as_ref()), - self.params.as_ref(), - paging_state, - ) - .await - }?; - - Ok(PagingResult { - result: QueryResultWrapper::from_query(result)?, - paging_state: match paging_state_response { - PagingStateResponse::HasMorePages { state } => { - Some(PagingStateWrapper { inner: state }) - } - PagingStateResponse::NoMorePages => None, - }, - }) - } -} #[napi] impl QueryExecutor { #[napi(ts_return_type = "Promise")] - pub async fn fetch_next_page( + pub fn fetch_next_page( &self, + env: Env, session: &SessionWrapper, paging_state: Option<&PagingStateWrapper>, - ) -> JsResult { - with_custom_error_async(async || self.fetch_next_page_internal(session, paging_state).await) - .await + ) -> JsResult> { + with_custom_error_sync(|| { + let params = Arc::clone(&self.params); + let statement = Arc::clone(&self.statement); + let is_prepared = self.is_prepared; + let session_inner = Arc::clone(&session.inner); + let paging_state_inner = paging_state.map(|p| p.inner.clone()); + submit_future(&env, async move { + let paging_state = paging_state_inner.unwrap_or(PagingState::start()); + let (result, paging_state_response) = if is_prepared { + session_inner + .execute_single_page( + Statement::clone(statement.as_ref()), + params.as_ref(), + paging_state, + ) + .await + } else { + session_inner + .get_session() + .query_single_page( + Statement::clone(statement.as_ref()), + params.as_ref(), + paging_state, + ) + .await + }?; + Ok(PagingResult { + result: QueryResultWrapper::from_query(result)?, + paging_state: match paging_state_response { + PagingStateResponse::HasMorePages { state } => { + Some(PagingStateWrapper { inner: state }) + } + PagingStateResponse::NoMorePages => None, + }, + }) + }) + }) } } @@ -114,15 +107,21 @@ impl QueryExecutor { impl SessionWrapper { /// Creates session based on the provided session options. #[napi(ts_return_type = "Promise")] - pub async fn create_session(options: SessionOptions) -> JsResult { - with_custom_error_async(async || { - let cache_size = options.cache_size.unwrap_or(DEFAULT_CACHE_SIZE) as usize; - let builder = configure_session_builder(options)?; - let session = builder.build().await?; - let session: CachingSession = CachingSession::from(session, cache_size); - ConvertedResult::Ok(SessionWrapper { inner: session }) + pub fn create_session( + env: Env, + options: SessionOptions, + ) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + let cache_size = options.cache_size.unwrap_or(DEFAULT_CACHE_SIZE) as usize; + let builder = configure_session_builder(options)?; + let session = builder.build().await?; + let session: CachingSession = CachingSession::from(session, cache_size); + Ok(SessionWrapper { + inner: Arc::new(session), + }) + }) }) - .await } /// Returns the name of the current keyspace @@ -143,43 +142,43 @@ impl SessionWrapper { /// -- each value must be tuple of its ComplexType and the value itself. /// If the provided types will not be correct, this query will fail. #[napi(ts_return_type = "Promise")] - pub async fn query_unpaged_encoded( + pub fn query_unpaged_encoded( &self, + env: Env, query: String, params: Vec, options: &QueryOptionsWrapper, - ) -> JsResult { - with_custom_error_async(async || { - let statement: Statement = - self.apply_statement_options(query.into(), &options.options)?; - let query_result = self - .inner - .get_session() - .query_unpaged(statement, params) - .await?; - QueryResultWrapper::from_query(query_result) + ) -> JsResult> { + with_custom_error_sync(|| { + let statement = self.apply_statement_options(query.into(), &options.options)?; + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + let query_result = inner.get_session().query_unpaged(statement, params).await?; + QueryResultWrapper::from_query(query_result) + }) }) - .await } /// Prepares a statement through rust driver for a given session /// Return expected types for the prepared statement #[napi(ts_return_type = "Promise>")] - pub async fn prepare_statement( + pub fn prepare_statement( &self, + env: Env, statement: String, - ) -> JsResult>> { - with_custom_error_async(async || { - let statement: Statement = statement.into(); - let w = PreparedStatementWrapper { - prepared: self - .inner - .add_prepared_statement(&statement) // TODO: change for add_prepared_statement_to_owned after it is made public - .await?, - }; - ConvertedResult::Ok(w.get_expected_types()) + ) -> JsResult>>> { + with_custom_error_sync(|| { + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + let statement: Statement = statement.into(); + let w = PreparedStatementWrapper { + prepared: inner + .add_prepared_statement(&statement) // TODO: change for add_prepared_statement_to_owned after it is made public + .await?, + }; + Ok(w.get_expected_types()) + }) }) - .await } /// Execute a given prepared statement against the database with provided parameters. @@ -193,33 +192,39 @@ impl SessionWrapper { /// Currently `execute_unpaged` from rust driver is used, so no paging is done /// and there is no support for any query options #[napi(ts_return_type = "Promise")] - pub async fn execute_prepared_unpaged_encoded( + pub fn execute_prepared_unpaged_encoded( &self, + env: Env, query: String, params: Vec, options: &QueryOptionsWrapper, - ) -> JsResult { - with_custom_error_async(async || { + ) -> JsResult> { + with_custom_error_sync(|| { let query = self.apply_statement_options(query.into(), &options.options)?; - QueryResultWrapper::from_query(self.inner.execute_unpaged(query, params).await?) + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + QueryResultWrapper::from_query(inner.execute_unpaged(query, params).await?) + }) }) - .await } /// Executes all statements in the provided batch. Those statements can be either prepared or unprepared. /// /// Returns a wrapper of the result provided by the rust driver #[napi(ts_return_type = "Promise")] - pub async fn batch_encoded( + pub fn batch_encoded( &self, + env: Env, batch: &BatchWrapper, params: Vec>, - ) -> JsResult { - with_custom_error_async(async || { - let res = self.inner.batch(&batch.inner, params).await?; - QueryResultWrapper::from_query(res) + ) -> JsResult> { + with_custom_error_sync(|| { + let batch = batch.inner.clone(); + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + QueryResultWrapper::from_query(inner.batch(&batch, params).await?) + }) }) - .await } /// Query a single page of a prepared statement @@ -228,27 +233,42 @@ impl SessionWrapper { /// For the following pages you need to provide page state /// received from the previous page #[napi(ts_return_type = "Promise")] - pub async fn query_single_page_encoded( + pub fn query_single_page_encoded( &self, + env: Env, query: String, params: Vec, options: &QueryOptionsWrapper, paging_state: Option<&PagingStateWrapper>, - ) -> JsResult { - with_custom_error_async(async || { + ) -> JsResult> { + with_custom_error_sync(|| { let statement = Arc::new(self.apply_statement_options(query.into(), &options.options)?); - let params = Arc::new(params); - - let executor = QueryExecutor::new(statement, params, false); - - let res = executor - .fetch_next_page_internal(self, paging_state) - .await?; - - ConvertedResult::Ok(res.with_executor(executor)) + let paging_state_inner = paging_state.map(|p| p.inner.clone()); + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + let paging_state = paging_state_inner.unwrap_or(PagingState::start()); + let (result, paging_state_response) = inner + .get_session() + .query_single_page( + Statement::clone(statement.as_ref()), + params.as_ref(), + paging_state, + ) + .await?; + let paging_result = PagingResult { + result: QueryResultWrapper::from_query(result)?, + paging_state: match paging_state_response { + PagingStateResponse::HasMorePages { state } => { + Some(PagingStateWrapper { inner: state }) + } + PagingStateResponse::NoMorePages => None, + }, + }; + let executor = QueryExecutor::new(statement, params, false); + Ok(paging_result.with_executor(executor)) + }) }) - .await } /// Execute a single page of a prepared statement @@ -257,27 +277,41 @@ impl SessionWrapper { /// For the following pages you need to provide page state /// received from the previous page #[napi(ts_return_type = "Promise")] - pub async fn execute_single_page_encoded( + pub fn execute_single_page_encoded( &self, + env: Env, query: String, params: Vec, options: &QueryOptionsWrapper, paging_state: Option<&PagingStateWrapper>, - ) -> JsResult { - with_custom_error_async(async || { + ) -> JsResult> { + with_custom_error_sync(|| { let statement = Arc::new(self.apply_statement_options(query.into(), &options.options)?); - let params = Arc::new(params); - - let executor = QueryExecutor::new(statement, params, true); - - let res = executor - .fetch_next_page_internal(self, paging_state) - .await?; - - ConvertedResult::Ok(res.with_executor(executor)) + let paging_state_inner = paging_state.map(|p| p.inner.clone()); + let inner = Arc::clone(&self.inner); + submit_future(&env, async move { + let paging_state = paging_state_inner.unwrap_or(PagingState::start()); + let (result, paging_state_response) = inner + .execute_single_page( + Statement::clone(statement.as_ref()), + params.as_ref(), + paging_state, + ) + .await?; + let paging_result = PagingResult { + result: QueryResultWrapper::from_query(result)?, + paging_state: match paging_state_response { + PagingStateResponse::HasMorePages { state } => { + Some(PagingStateWrapper { inner: state }) + } + PagingStateResponse::NoMorePages => None, + }, + }; + let executor = QueryExecutor::new(statement, params, true); + Ok(paging_result.with_executor(executor)) + }) }) - .await } /// Creates object representing batch of statements. diff --git a/src/tests/casync_tests.rs b/src/tests/casync_tests.rs new file mode 100644 index 000000000..bc5f5cbd0 --- /dev/null +++ b/src/tests/casync_tests.rs @@ -0,0 +1,145 @@ +use std::time::Duration; + +use napi::bindgen_prelude::*; + +use crate::casync::{JsPromise, submit_future}; +use crate::errors::{ConvertedError, JsResult, with_custom_error_sync}; + +// --------------------------------------------------------------------------- +// Resolve paths +// --------------------------------------------------------------------------- + +/// Resolves with 42 on the very first poll (no Pending). +/// Tests the synchronous-completion fast path. +#[napi] +pub fn tests_casync_resolve_immediate(env: Env) -> JsResult> { + with_custom_error_sync(|| submit_future(&env, async move { Ok::(42) })) +} + +/// Resolves with `millis` after sleeping for `millis` milliseconds. +/// The sleep causes the future to return Pending on the first poll; the Tokio +/// reactor fires the waker from its worker thread when the timer expires, +/// exercising the cross-thread waker → TSFN → poll_woken path. +#[napi] +pub fn tests_casync_resolve_delayed(env: Env, millis: u32) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + tokio::time::sleep(Duration::from_millis(millis as u64)).await; + Ok::(millis as i32) + }) + }) +} + +/// Resolves with a String value. +/// Tests a different ToNapiValue type so that type erasure in BoxFuture does +/// not silently confuse return types. +#[napi] +pub fn tests_casync_resolve_string(env: Env) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + Ok::("hello from async".to_string()) + }) + }) +} + +/// Resolves with a bool. +#[napi] +pub fn tests_casync_resolve_bool(env: Env, value: bool) -> JsResult> { + with_custom_error_sync(|| submit_future(&env, async move { Ok::(value) })) +} + +// --------------------------------------------------------------------------- +// Reject paths +// --------------------------------------------------------------------------- + +/// Rejects with a ConvertedError produced from a real scylla error. +/// The JS side can assert `.message` and `.name` on the rejection value. +#[napi] +pub fn tests_casync_reject(env: Env) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + Err::(scylla::errors::BadKeyspaceName::Empty.into()) + }) + }) +} + +/// Rejects after a delay, exercising the waker path on the error branch. +#[napi] +pub fn tests_casync_reject_delayed(env: Env, millis: u32) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + tokio::time::sleep(Duration::from_millis(millis as u64)).await; + Err::(scylla::errors::BadKeyspaceName::Empty.into()) + }) + }) +} + +/// Rejects with a ConvertedError whose message contains an interior null byte. +/// This exercises the CString::new fallback in reject_with_reason — the error +/// is produced by a type whose Display output contains '\0'. Because the normal +/// ConvertedError::to_napi_value path uses napi-rs string APIs (not CString), +/// the null byte only matters when that path itself fails, causing reject_with_reason +/// to be called. We trigger that by making T::to_napi_value fail: the future +/// succeeds (Ok variant), but the value cannot be converted, so the settle +/// callback falls through to reject_with_reason. +/// +/// More practically this test validates that a ConvertedError with a null byte +/// does NOT crash the process — the promise is simply rejected with a fallback +/// message. +#[napi] +pub fn tests_casync_reject_null_byte(env: Env) -> JsResult> { + /// An error whose Display contains an interior null byte. + struct NullByteError; + + impl std::fmt::Display for NullByteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // The \0 makes CString::new fail when reject_with_reason is called. + write!(f, "error with\0null byte") + } + } + + impl std::fmt::Debug for NullByteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NullByteError") + } + } + + impl std::error::Error for NullByteError {} + + with_custom_error_sync(|| { + submit_future(&env, async move { + Err::(NullByteError.into()) + }) + }) +} + +// --------------------------------------------------------------------------- +// Waker path +// --------------------------------------------------------------------------- + +/// Submits a future that is woken multiple times before its second poll. +/// Uses tokio::sync::Notify: a spawned task calls notify_one() twice in quick +/// succession. The first notification wakes the future; the second fires while +/// the waker may still be queued, exercising the coalesced-wake path in +/// WakerBridge::signal (the signaled AtomicBool prevents duplicate TSFN calls). +/// The promise must still resolve exactly once with the correct value. +#[napi] +pub fn tests_casync_multi_wake(env: Env) -> JsResult> { + with_custom_error_sync(|| { + submit_future(&env, async move { + // This future is polled inside rt.enter(), so tokio::spawn is valid here. + let notify = std::sync::Arc::new(tokio::sync::Notify::new()); + let notify2 = std::sync::Arc::clone(¬ify); + + tokio::spawn(async move { + // Fire two notifications back-to-back. The waker may fire twice + // before poll_woken runs, which tests the coalescing in WakerBridge. + notify2.notify_one(); + notify2.notify_one(); + }); + + notify.notified().await; + Ok::(99) + }) + }) +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 6817428c7..7d573dd5b 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,3 +1,4 @@ +pub mod casync_tests; pub mod js_results_tests; pub mod option_tests; pub mod socket_addr_tests; diff --git a/test/unit/casync-tests.js b/test/unit/casync-tests.js new file mode 100644 index 000000000..e542a9083 --- /dev/null +++ b/test/unit/casync-tests.js @@ -0,0 +1,151 @@ +"use strict"; + +const { assert } = require("chai"); +const rust = require("../../index"); +const helper = require("../test-helper"); + +// init_poll_bridge is called by lib/client.js at module load time. +// Require it here to ensure the bridge is ready before any test runs. +require("../../lib/client"); + +describe("casync bridge", function () { + // --------------------------------------------------------------------------- + // Resolve paths + // --------------------------------------------------------------------------- + + describe("resolve", function () { + it("should resolve immediately with a numeric value", async function () { + const result = await rust.testsCasyncResolveImmediate(); + assert.strictEqual(result, 42); + }); + + it("should resolve after a delay", async function () { + const result = await rust.testsCasyncResolveDelayed(50); + assert.strictEqual(result, 50); + }); + + it("should resolve with a string value", async function () { + const result = await rust.testsCasyncResolveString(); + assert.strictEqual(result, "hello from async"); + }); + + it("should resolve with a boolean value", async function () { + assert.strictEqual(await rust.testsCasyncResolveBool(true), true); + assert.strictEqual(await rust.testsCasyncResolveBool(false), false); + }); + }); + + // --------------------------------------------------------------------------- + // Reject paths + // --------------------------------------------------------------------------- + + describe("reject", function () { + it("should reject with the correct error message and name", async function () { + try { + await rust.testsCasyncReject(); + assert.fail("Promise should have been rejected"); + } catch (e) { + helper.assertInstanceOf(e, Error); + assert.strictEqual(e.message, "Keyspace name is empty"); + assert.strictEqual(e.name, "BadKeyspaceName"); + } + }); + + it("should reject with the correct error after a delay", async function () { + try { + await rust.testsCasyncRejectDelayed(30); + assert.fail("Promise should have been rejected"); + } catch (e) { + helper.assertInstanceOf(e, Error); + assert.strictEqual(e.message, "Keyspace name is empty"); + assert.strictEqual(e.name, "BadKeyspaceName"); + } + }); + + it("should reject cleanly even when the error message contains a null byte", async function () { + // The promise must reject (not crash) when ConvertedError::msg has \0. + try { + await rust.testsCasyncRejectNullByte(); + assert.fail("Promise should have been rejected"); + } catch (e) { + helper.assertInstanceOf(e, Error); + // The message may be truncated or replaced — the important thing + // is that the process did not crash and the promise was rejected. + } + }); + }); + + // --------------------------------------------------------------------------- + // Concurrency + // --------------------------------------------------------------------------- + + describe("concurrency", function () { + it("should resolve many concurrent futures", async function () { + const N = 50; + const results = await Promise.all( + Array.from({ length: N }, () => + rust.testsCasyncResolveImmediate(), + ), + ); + assert.strictEqual(results.length, N); + results.forEach((v) => assert.strictEqual(v, 42)); + }); + + it("should correctly resolve a mix of delayed and immediate futures", async function () { + const [delayed, immediate] = await Promise.all([ + rust.testsCasyncResolveDelayed(20), + rust.testsCasyncResolveImmediate(), + ]); + assert.strictEqual(delayed, 20); + assert.strictEqual(immediate, 42); + }); + + it("should handle a mix of resolving and rejecting futures", async function () { + const N = 20; + const promises = Array.from({ length: N }, (_, i) => + i % 2 === 0 + ? rust.testsCasyncResolveImmediate().then((v) => ({ + ok: true, + value: v, + })) + : rust.testsCasyncReject().then( + () => assert.fail("Should not resolve"), + (e) => ({ ok: false, error: e }), + ), + ); + + const results = await Promise.all(promises); + results.forEach((r, i) => { + if (i % 2 === 0) { + assert.isTrue(r.ok); + assert.strictEqual(r.value, 42); + } else { + assert.isFalse(r.ok); + assert.strictEqual(r.error.name, "BadKeyspaceName"); + } + }); + }); + }); + + // --------------------------------------------------------------------------- + // Waker correctness + // --------------------------------------------------------------------------- + + describe("waker", function () { + it("should resolve exactly once even when the waker fires multiple times", async function () { + // The future notifies twice before being polled — the coalescing + // AtomicBool in WakerBridge must prevent double-resolution. + const result = await rust.testsCasyncMultiWake(); + assert.strictEqual(result, 99); + }); + + it("should resolve multiple multi-wake futures concurrently", async function () { + const results = await Promise.all([ + rust.testsCasyncMultiWake(), + rust.testsCasyncMultiWake(), + rust.testsCasyncMultiWake(), + ]); + results.forEach((v) => assert.strictEqual(v, 99)); + }); + }); +});