From 0902d1781cf02aa30d40362233e8786cbc5ff38b Mon Sep 17 00:00:00 2001 From: itowlson Date: Fri, 14 Nov 2025 13:21:53 +1300 Subject: [PATCH] Async PostgreSQL API Signed-off-by: itowlson --- Cargo.lock | 11 ++ crates/factor-outbound-pg/Cargo.toml | 2 + .../factor-outbound-pg/src/allowed_hosts.rs | 67 +++++++ crates/factor-outbound-pg/src/client.rs | 97 ++++++++-- crates/factor-outbound-pg/src/host.rs | 126 ++++++++----- crates/factor-outbound-pg/src/lib.rs | 31 ++-- crates/factor-outbound-pg/src/types.rs | 25 ++- .../factor-outbound-pg/src/types/convert.rs | 2 +- .../factor-outbound-pg/src/types/interval.rs | 2 +- .../factor-outbound-pg/tests/factor_test.rs | 23 ++- crates/wasi-async/Cargo.toml | 10 + crates/wasi-async/src/future.rs | 30 +++ crates/wasi-async/src/lib.rs | 2 + crates/wasi-async/src/stream.rs | 45 +++++ crates/world/src/conversions.rs | 6 +- crates/world/src/lib.rs | 1 + wit/deps/spin-postgres@4.2.0/postgres.wit | 171 ++++++++++++++++++ wit/world.wit | 2 +- 18 files changed, 571 insertions(+), 82 deletions(-) create mode 100644 crates/factor-outbound-pg/src/allowed_hosts.rs create mode 100644 crates/wasi-async/Cargo.toml create mode 100644 crates/wasi-async/src/future.rs create mode 100644 crates/wasi-async/src/lib.rs create mode 100644 crates/wasi-async/src/stream.rs create mode 100644 wit/deps/spin-postgres@4.2.0/postgres.wit diff --git a/Cargo.lock b/Cargo.lock index 0ccc6f3d74..7506dfcb46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8736,6 +8736,7 @@ dependencies = [ "bytes", "chrono", "deadpool-postgres", + "futures", "moka", "native-tls", "postgres-native-tls", @@ -8749,6 +8750,7 @@ dependencies = [ "spin-factors", "spin-factors-test", "spin-resource-table", + "spin-wasi-async", "spin-world", "tokio", "tokio-postgres", @@ -9431,6 +9433,15 @@ dependencies = [ "vaultrs", ] +[[package]] +name = "spin-wasi-async" +version = "3.7.0-pre0" +dependencies = [ + "anyhow", + "spin-core", + "tokio", +] + [[package]] name = "spin-world" version = "3.7.0-pre0" diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 6436d55268..c42d5fd63c 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -9,6 +9,7 @@ anyhow = { workspace = true } bytes = {workspace = true } chrono = { workspace = true } deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] } +futures = { workspace = true } moka = { version = "0.12", features = ["sync"] } native-tls = "0.2" postgres-native-tls = "0.5" @@ -20,6 +21,7 @@ spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-resource-table = { path = "../table" } +spin-wasi-async = { path = "../wasi-async" } spin-world = { path = "../world" } tokio = { workspace = true, features = ["rt-multi-thread"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] } diff --git a/crates/factor-outbound-pg/src/allowed_hosts.rs b/crates/factor-outbound-pg/src/allowed_hosts.rs new file mode 100644 index 0000000000..7612185b62 --- /dev/null +++ b/crates/factor-outbound-pg/src/allowed_hosts.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; + +/// Encapsulates checking of a PostgreSQL address/connection string against +/// an allow-list. 
+///
+/// This is broken out as a distinct object to allow it to be synchronously retrieved
+/// within a P3 Accessor block and then asynchronously queried outside the block.
+#[derive(Clone)]
+pub(crate) struct AllowedHostChecker {
+    allowed_hosts: Arc<OutboundAllowedHosts>,
+}
+
+impl AllowedHostChecker {
+    pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self {
+        Self {
+            allowed_hosts: Arc::new(allowed_hosts),
+        }
+    }
+    #[allow(clippy::result_large_err)]
+    pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
+        fn conn_failed(message: impl Into<String>) -> v4::Error {
+            v4::Error::ConnectionFailed(message.into())
+        }
+        fn err_other(err: anyhow::Error) -> v4::Error {
+            v4::Error::Other(err.to_string())
+        }
+
+        let config = address
+            .parse::<tokio_postgres::Config>()
+            .map_err(|e| conn_failed(e.to_string()))?;
+
+        for (i, host) in config.get_hosts().iter().enumerate() {
+            match host {
+                tokio_postgres::config::Host::Tcp(address) => {
+                    let ports = config.get_ports();
+                    // The port we use is either:
+                    // * The port at the same index as the host
+                    // * The first port if there is only one port
+                    let port =
+                        ports
+                            .get(i)
+                            .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
+                    let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
+                    let url = format!("{address}{port_str}");
+                    if !self
+                        .allowed_hosts
+                        .check_url(&url, "postgres")
+                        .await
+                        .map_err(err_other)?
+                    {
+                        return Err(conn_failed(format!(
+                            "address postgres://{url} is not permitted"
+                        )));
+                    }
+                }
+                #[cfg(unix)]
+                tokio_postgres::config::Host::Unix(_) => {
+                    return Err(conn_failed("Unix sockets are not supported on WebAssembly"));
+                }
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs
index 6a8bab7da9..c9cff04b95 100644
--- a/crates/factor-outbound-pg/src/client.rs
+++ b/crates/factor-outbound-pg/src/client.rs
@@ -1,14 +1,18 @@
+use std::sync::Arc;
+
 use anyhow::{Context, Result};
 use native_tls::TlsConnector;
 use postgres_native_tls::MakeTlsConnector;
 use spin_world::async_trait;
-use spin_world::spin::postgres4_0_0::postgres::{
+use spin_world::spin::postgres4_2_0::postgres::{
     self as v4, Column, DbValue, ParameterValue, RowSet,
 };
 use tokio_postgres::types::ToSql;
 use tokio_postgres::{config::SslMode, NoTls, Row};
 
-use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
+use crate::types::{
+    as_sql_parameter_refs, convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters,
+};
 
 /// Max connections in a given address' connection pool
 const CONNECTION_POOL_SIZE: usize = 64;
@@ -40,7 +44,7 @@ impl Default for PooledTokioClientFactory {
 
 #[async_trait]
 impl ClientFactory for PooledTokioClientFactory {
-    type Client = deadpool_postgres::Object;
+    type Client = Arc<deadpool_postgres::Object>;
 
     async fn get_client(&self, address: &str) -> Result<Self::Client> {
         let pool = self
@@ -49,7 +53,7 @@ impl ClientFactory for PooledTokioClientFactory {
             .map_err(ArcError)
             .context("establishing PostgreSQL connection pool")?;
 
-        Ok(pool.get().await?)
+        Ok(Arc::new(pool.get().await?))
     }
 }
 
@@ -85,7 +89,7 @@ fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
 }
 
 #[async_trait]
-pub trait Client: Send + Sync + 'static {
+pub trait Client: Clone + Send + Sync + 'static {
     async fn execute(
         &self,
         statement: String,
@@ -97,6 +101,18 @@ pub trait Client: Send + Sync + 'static {
         statement: String,
         params: Vec<ParameterValue>,
     ) -> Result<RowSet, v4::Error>;
+
+    async fn query_async(
+        &self,
+        statement: String,
+        params: Vec<ParameterValue>,
+    ) -> Result<
+        (
+            tokio::sync::oneshot::Receiver<Vec<Column>>,
+            tokio::sync::mpsc::Receiver<Result<Vec<DbValue>, v4::Error>>,
+        ),
+        v4::Error,
+    >;
 }
 
 /// Extract weak-typed error data for WIT purposes
@@ -142,7 +158,7 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
 }
 
 #[async_trait]
-impl Client for deadpool_postgres::Object {
+impl Client for Arc<deadpool_postgres::Object> {
     async fn execute(
         &self,
         statement: String,
@@ -170,16 +186,8 @@ impl Client for deadpool_postgres::Object {
         statement: String,
         params: Vec<ParameterValue>,
     ) -> Result<RowSet, v4::Error> {
-        let params = params
-            .iter()
-            .map(to_sql_parameter)
-            .collect::<Result<Vec<_>>>()
-            .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;
-
-        let params_refs: Vec<&(dyn ToSql + Sync)> = params
-            .iter()
-            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
-            .collect();
+        let params = to_sql_parameters(params)?;
+        let params_refs = as_sql_parameter_refs(&params);
 
         let results = self
             .as_ref()
@@ -203,6 +211,63 @@ impl Client for Arc<deadpool_postgres::Object> {
 
         Ok(RowSet { columns, rows })
     }
+
+    async fn query_async(
+        &self,
+        statement: String,
+        params: Vec<ParameterValue>,
+    ) -> Result<
+        (
+            tokio::sync::oneshot::Receiver<Vec<Column>>,
+            tokio::sync::mpsc::Receiver<Result<Vec<DbValue>, v4::Error>>,
+        ),
+        v4::Error,
+    > {
+        let params = to_sql_parameters(params)?;
+        let params_refs = as_sql_parameter_refs(&params);
+
+        let stm = self
+            .as_ref()
+            .query_raw(&statement, params_refs)
+            .await
+            .map_err(query_failed)?;
+
+        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000);
+        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
+        let mut cols_tx_opt = Some(cols_tx);
+
+        let mut stm = Box::pin(stm);
+
+        tokio::spawn(async move {
+            use futures::StreamExt;
+            loop {
+                let Some(row) = stm.next().await else {
+                    break;
+                };
+                // TODO: decide how to surface errors here - perhaps via something like the FutureReader pattern?
+                let row = match row {
+                    Ok(r) => r,
+                    Err(e) => {
+                        let err = query_failed(e);
+                        rows_tx.send(Err(err)).await.unwrap();
+                        break;
+                    }
+                };
+                if let Some(cols_tx) = cols_tx_opt.take() {
+                    cols_tx.send(infer_columns(&row)).unwrap();
+                }
+                match convert_row(&row) {
+                    Ok(row) => rows_tx.send(Ok(row)).await.unwrap(),
+                    Err(e) => {
+                        let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}")));
+                        rows_tx.send(Err(err)).await.unwrap();
+                    }
+                }
+            }
+        });
+
+        Ok((cols_rx, rows_rx))
+    }
 }
 
 fn infer_columns(row: &Row) -> Vec<Column> {
diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs
index 1c642120cd..8740bb5b5f 100644
--- a/crates/factor-outbound-pg/src/host.rs
+++ b/crates/factor-outbound-pg/src/host.rs
@@ -1,7 +1,8 @@
 use anyhow::Result;
-use spin_core::wasmtime::component::Resource;
+use spin_core::wasmtime;
+use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
 use spin_world::spin::postgres3_0_0::postgres::{self as v3};
-use spin_world::spin::postgres4_0_0::postgres::{self as v4};
+use spin_world::spin::postgres4_2_0::postgres::{self as v4};
 use spin_world::v1::postgres as v1;
 use spin_world::v1::rdbms_types as v1_types;
 use spin_world::v2::postgres::{self as v2};
@@ -10,9 +11,14 @@ use tracing::field::Empty;
 use tracing::instrument;
 use tracing::Level;
 
+use crate::allowed_hosts::AllowedHostChecker;
 use crate::client::{Client, ClientFactory};
 use crate::InstanceState;
 
+// Type aliases to keep the signatures below readable (and Clippy's type-complexity lint quiet)
+pub type RowStream = StreamReader<Result<Vec<v4::DbValue>, v4::Error>>;
+pub type ColumnsFuture = FutureReader<Vec<v4::Column>>;
+
 impl<C: ClientFactory> InstanceState<C> {
     async fn open_connection(
         &mut self,
@@ -38,50 +44,15 @@ impl<C: ClientFactory> InstanceState<C> {
             .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
     }
 
+    fn allowed_host_checker(&self) -> AllowedHostChecker {
+        self.allowed_host_checker.clone()
+    }
+
     #[allow(clippy::result_large_err)]
     async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
-        fn conn_failed(message: impl Into<String>) -> v4::Error {
-            v4::Error::ConnectionFailed(message.into())
-        }
-        fn err_other(err: anyhow::Error) -> v4::Error {
-            v4::Error::Other(err.to_string())
-        }
-
-        let config = address
-            .parse::<tokio_postgres::Config>()
-            .map_err(|e| conn_failed(e.to_string()))?;
-
-        for (i, host) in config.get_hosts().iter().enumerate() {
-            match host {
-                tokio_postgres::config::Host::Tcp(address) => {
-                    let ports = config.get_ports();
-                    // The port we use is either:
-                    // * The port at the same index as the host
-                    // * The first port if there is only one port
-                    let port =
-                        ports
-                            .get(i)
-                            .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
-                    let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
-                    let url = format!("{address}{port_str}");
-                    if !self
-                        .allowed_hosts
-                        .check_url(&url, "postgres")
-                        .await
-                        .map_err(err_other)?
-                    {
-                        return Err(conn_failed(format!(
-                            "address postgres://{url} is not permitted"
-                        )));
-                    }
-                }
-                #[cfg(unix)]
-                tokio_postgres::config::Host::Unix(_) => {
-                    return Err(conn_failed("Unix sockets are not supported on WebAssembly"));
-                }
-            }
-        }
-        Ok(())
+        self.allowed_host_checker
+            .ensure_address_allowed(address)
+            .await
     }
 }
 
@@ -182,6 +153,73 @@ impl<C: ClientFactory> v4::HostConnection for InstanceState<C> {
     }
 }
 
+impl<C: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
+    for crate::PgFactorData<C>
+{
+    #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
+    async fn open_async<T: 'static>(
+        accessor: &Accessor<T, Self>,
+        address: String,
+    ) -> Result<Resource<v4::Connection>, v4::Error> {
+        spin_factor_outbound_networking::record_address_fields(&address);
+
+        // A merry dance: grab the checker synchronously so the async allow-list check can run outside the accessor
+        let allowed_host_checker = accessor.with(|mut access| {
+            let host = access.get();
+            host.allowed_host_checker()
+        });
+
+        allowed_host_checker
+            .ensure_address_allowed(&address)
+            .await?;
+
+        let cf = accessor.with(|mut access| {
+            let host = access.get();
+            host.client_factory.clone()
+        });
+        let client = cf
+            .get_client(&address)
+            .await
+            .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?;
+        let rsrc = accessor.with(|mut access| {
+            let host = access.get();
+            host.connections
+                .push(client)
+                .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
+                .map(wasmtime::component::Resource::new_own)
+        });
+        rsrc
+    }
+
+    #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
+    async fn query_async<T: 'static>(
+        accessor: &Accessor<T, Self>,
+        connection: Resource<v4::Connection>,
+        statement: String,
+        params: Vec<v4::ParameterValue>,
+    ) -> Result<(ColumnsFuture, RowStream), v4::Error> {
+        use wasmtime::AsContextMut;
+
+        let client = accessor.with(|mut access| {
+            let host = access.get();
+            host.connections.get(connection.rep()).unwrap().clone()
+        });
+
+        let (col_rx, row_rx) = client.query_async(statement, params).await?;
+
+        let row_producer = spin_wasi_async::stream::producer(row_rx);
+        let col_producer = spin_wasi_async::future::producer(col_rx);
+
+        let (fr, sr) = accessor.with(|mut access| {
+            let fr = FutureReader::new(access.as_context_mut(), col_producer);
+            let sr = StreamReader::new(access.as_context_mut(), row_producer);
+            (fr, sr)
+        });
+
+        Ok((fr, sr))
+    }
+}
+
 impl<C: ClientFactory> v2_types::Host for InstanceState<C> {
     fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
         Ok(error)
diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs
index 4a8baaf170..9f1e6d236b 100644
--- a/crates/factor-outbound-pg/src/lib.rs
+++ b/crates/factor-outbound-pg/src/lib.rs
@@ -1,15 +1,14 @@
+mod allowed_hosts;
 pub mod client;
 mod host;
 mod types;
 
+use allowed_hosts::AllowedHostChecker;
 use client::ClientFactory;
 use spin_factor_otel::OtelFactorState;
-use spin_factor_outbound_networking::{
-    config::allowed_hosts::OutboundAllowedHosts, OutboundNetworkingFactor,
-};
+use spin_factor_outbound_networking::OutboundNetworkingFactor;
 use spin_factors::{
-    anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors,
-    SelfInstanceBuilder,
+    anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
 };
 use std::sync::Arc;
 
@@ -23,13 +22,13 @@ impl<C: ClientFactory> Factor for OutboundPgFactor<C> {
     type InstanceBuilder = InstanceState<C>;
 
     fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
-        ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, FactorData<Self>>)?;
-        ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, FactorData<Self>>)?;
+        ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, PgFactorData<C>>)?;
+        ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, PgFactorData<C>>)?;
         ctx.link_bindings(
-            spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, FactorData<Self>>,
+            spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, PgFactorData<C>>,
         )?;
         ctx.link_bindings(
-            spin_world::spin::postgres4_0_0::postgres::add_to_linker::<_, FactorData<Self>>,
+            spin_world::spin::postgres4_2_0::postgres::add_to_linker::<_, PgFactorData<C>>,
         )?;
         Ok(())
     }
@@ -51,7 +50,7 @@ impl<C: ClientFactory> Factor for OutboundPgFactor<C> {
         let otel = OtelFactorState::from_prepare_context(&mut ctx)?;
 
         Ok(InstanceState {
-            allowed_hosts,
+            allowed_host_checker: AllowedHostChecker::new(allowed_hosts),
             client_factory: ctx.app_state().clone(),
             connections: Default::default(),
             otel,
@@ -74,10 +73,20 @@ impl<C: ClientFactory> OutboundPgFactor<C> {
 }
 
 pub struct InstanceState<C: ClientFactory> {
-    allowed_hosts: OutboundAllowedHosts,
+    allowed_host_checker: AllowedHostChecker,
     client_factory: Arc<C>,
     connections: spin_resource_table::Table<C::Client>,
     otel: OtelFactorState,
 }
 
 impl<C: ClientFactory> SelfInstanceBuilder for InstanceState<C> {}
+
+pub struct PgFactorData<C: ClientFactory>(OutboundPgFactor<C>);
+
+impl<C: ClientFactory> spin_core::wasmtime::component::HasData for PgFactorData<C> {
+    type Data<'a> = &'a mut InstanceState<C>;
+}
+
+impl<C: ClientFactory> spin_core::wasmtime::component::HasData for InstanceState<C> {
+    type Data<'a> = &'a mut InstanceState<C>;
+}
diff --git a/crates/factor-outbound-pg/src/types.rs b/crates/factor-outbound-pg/src/types.rs
index c3b586892f..b146cb27ba 100644
--- a/crates/factor-outbound-pg/src/types.rs
+++ b/crates/factor-outbound-pg/src/types.rs
@@ -1,4 +1,5 @@
-use spin_world::spin::postgres4_0_0::postgres::{DbDataType, DbValue, ParameterValue};
+use anyhow::Result;
+use spin_world::spin::postgres4_2_0::postgres::{self as v4, DbDataType, DbValue, ParameterValue};
 use tokio_postgres::types::{FromSql, Type};
 use tokio_postgres::{types::ToSql, Row};
 
@@ -162,3 +163,25 @@ pub fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<Box<dyn ToSql + Send + Sync>> {
         ParameterValue::DbNull => Ok(Box::new(PgNull)),
     }
 }
+
+// The logic for "vector of ParameterValue to vector of &dyn ToSql" is
+// used in multiple places, but needs to be broken into two functions
+// because the return value of the first (the Vec) needs to be kept
+// around to provide an owner for the refs.
+#[allow(clippy::result_large_err)]
+pub fn to_sql_parameters(
+    params: Vec<ParameterValue>,
+) -> Result<Vec<Box<dyn ToSql + Send + Sync>>, v4::Error> {
+    params
+        .iter()
+        .map(to_sql_parameter)
+        .collect::<Result<Vec<_>>>()
+        .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))
+}
+
+pub fn as_sql_parameter_refs(params: &[Box<dyn ToSql + Send + Sync>]) -> Vec<&(dyn ToSql + Sync)> {
+    params
+        .iter()
+        .map(|b| b.as_ref() as &(dyn ToSql + Sync))
+        .collect()
+}
diff --git a/crates/factor-outbound-pg/src/types/convert.rs b/crates/factor-outbound-pg/src/types/convert.rs
index ce806245af..303cc53b13 100644
--- a/crates/factor-outbound-pg/src/types/convert.rs
+++ b/crates/factor-outbound-pg/src/types/convert.rs
@@ -2,7 +2,7 @@
 //! the tokio_postgres driver.
 
 use anyhow::{anyhow, Context};
-use spin_world::spin::postgres4_0_0::postgres::{self as v4};
+use spin_world::spin::postgres4_2_0::postgres::{self as v4};
 
 use super::decimal::RangeableDecimal;
 
diff --git a/crates/factor-outbound-pg/src/types/interval.rs b/crates/factor-outbound-pg/src/types/interval.rs
index f494070a47..a87bdbde05 100644
--- a/crates/factor-outbound-pg/src/types/interval.rs
+++ b/crates/factor-outbound-pg/src/types/interval.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use spin_world::spin::postgres4_0_0::postgres::{self as v4};
+use spin_world::spin::postgres4_2_0::postgres::{self as v4};
 use tokio_postgres::types::{FromSql, ToSql, Type};
 
 #[derive(Debug)]
diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs
index 364e62a7f4..663d7cc0a4 100644
--- a/crates/factor-outbound-pg/tests/factor_test.rs
+++ b/crates/factor-outbound-pg/tests/factor_test.rs
@@ -7,10 +7,10 @@ use spin_factor_variables::VariablesFactor;
 use spin_factors::{anyhow, RuntimeFactors};
 use spin_factors_test::{toml, TestEnvironment};
 use spin_world::async_trait;
-use spin_world::spin::postgres4_0_0::postgres::Error as PgError;
-use spin_world::spin::postgres4_0_0::postgres::HostConnection;
-use spin_world::spin::postgres4_0_0::postgres::{self as v2};
-use spin_world::spin::postgres4_0_0::postgres::{ParameterValue, RowSet};
+use spin_world::spin::postgres4_2_0::postgres::Error as PgError;
+use spin_world::spin::postgres4_2_0::postgres::HostConnection;
+use spin_world::spin::postgres4_2_0::postgres::{self as v2};
+use spin_world::spin::postgres4_2_0::postgres::{ParameterValue, RowSet};
 
 #[derive(RuntimeFactors)]
 struct TestFactors {
@@ -107,6 +107,7 @@ async fn exercise_query() -> anyhow::Result<()> {
 // TODO: We can expand this mock to track calls and simulate return values
 #[derive(Default)]
 pub struct MockClientFactory {}
+#[derive(Clone)]
 pub struct MockClient {}
 
 #[async_trait]
@@ -137,4 +138,18 @@ impl Client for MockClient {
             rows: vec![],
         })
     }
+
+    async fn query_async(
+        &self,
+        _statement: String,
+        _params: Vec<ParameterValue>,
+    ) -> Result<
+        (
+            tokio::sync::oneshot::Receiver<Vec<v2::Column>>,
+            tokio::sync::mpsc::Receiver<Result<Vec<v2::DbValue>, v2::Error>>,
+        ),
+        v2::Error,
+    > {
+        panic!("not implemented");
+    }
 }
diff --git a/crates/wasi-async/Cargo.toml b/crates/wasi-async/Cargo.toml
new file mode 100644
index 0000000000..d352b7c57d
--- /dev/null
+++ b/crates/wasi-async/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "spin-wasi-async"
+version.workspace = true
+authors.workspace = true
+edition.workspace = true
+
+[dependencies]
+anyhow = { workspace = true }
+spin-core = { path = "../core" }
+tokio = { workspace = true }
diff --git a/crates/wasi-async/src/future.rs b/crates/wasi-async/src/future.rs
new file mode 100644
index 0000000000..b761d6a8ef
--- /dev/null
+++ b/crates/wasi-async/src/future.rs
@@ -0,0 +1,30 @@
+use spin_core::wasmtime;
+
+pub fn producer<T>(rx: tokio::sync::oneshot::Receiver<T>) -> FutureProducer<T> {
+    FutureProducer { rx }
+}
+
+pub struct FutureProducer<T> {
+    rx: tokio::sync::oneshot::Receiver<T>,
+}
+
+impl<T: Send + 'static, D> wasmtime::component::FutureProducer<D> for FutureProducer<T> {
+    type Item = T;
+
+    fn poll_produce(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        _store: wasmtime::StoreContextMut<D>,
+        _finish: bool,
+    ) -> std::task::Poll<wasmtime::Result<Option<Self::Item>>> {
+        use std::future::Future;
+        use std::task::Poll;
+
+        let pinned_rx = std::pin::Pin::new(&mut self.get_mut().rx);
+        match pinned_rx.poll(cx) {
+            Poll::Ready(Err(e)) => Poll::Ready(Err(anyhow::anyhow!("{e:#}"))),
+            Poll::Ready(Ok(cols)) => Poll::Ready(Ok(Some(cols))),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
diff --git a/crates/wasi-async/src/lib.rs b/crates/wasi-async/src/lib.rs
new file mode 100644
index 0000000000..fab6801fab
--- /dev/null
+++ b/crates/wasi-async/src/lib.rs
@@ -0,0 +1,2 @@
+pub mod future;
+pub mod stream;
diff --git a/crates/wasi-async/src/stream.rs b/crates/wasi-async/src/stream.rs
new file mode 100644
index 0000000000..a7bb86f019
--- /dev/null
+++ b/crates/wasi-async/src/stream.rs
@@ -0,0 +1,45 @@
+use spin_core::wasmtime;
+
+pub fn producer<T>(rx: tokio::sync::mpsc::Receiver<T>) -> StreamProducer<T> {
+    StreamProducer { rx }
+}
+
+pub struct StreamProducer<T> {
+    rx: tokio::sync::mpsc::Receiver<T>,
+}
+
+impl<T: Send + 'static, D> wasmtime::component::StreamProducer<D> for StreamProducer<T> {
+    type Item = T;
+
+    type Buffer = Option<T>;
+
+    fn poll_produce<'a>(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        store: wasmtime::StoreContextMut<'a, D>,
+        mut destination: wasmtime::component::Destination<'a, Self::Item, Self::Buffer>,
+        finish: bool,
+    ) -> std::task::Poll<wasmtime::Result<wasmtime::component::StreamResult>> {
+        use std::task::Poll;
+        use wasmtime::component::StreamResult;
+
+        if finish {
+            return Poll::Ready(Ok(StreamResult::Cancelled));
+        }
+
+        let remaining = destination.remaining(store);
+        if remaining.is_some_and(|r| r == 0) {
+            return Poll::Ready(Ok(StreamResult::Completed));
+        }
+
+        let recv = self.get_mut().rx.poll_recv(cx);
+        match recv {
+            Poll::Ready(None) => Poll::Ready(Ok(StreamResult::Dropped)),
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(Some(row)) => {
+                destination.set_buffer(Some(row));
+                Poll::Ready(Ok(StreamResult::Completed))
+            }
+        }
+    }
+}
diff --git a/crates/world/src/conversions.rs b/crates/world/src/conversions.rs
index 8177b623b2..53bc787725 100644
--- a/crates/world/src/conversions.rs
+++ b/crates/world/src/conversions.rs
@@ -3,7 +3,7 @@ use super::*;
 mod rdbms_types {
     use super::*;
     use spin::postgres3_0_0::postgres as pg3;
-    use spin::postgres4_0_0::postgres as pg4;
+    use spin::postgres4_2_0::postgres as pg4;
 
     impl From<v2::rdbms_types::Column> for v1::rdbms_types::Column {
         fn from(value: v2::rdbms_types::Column) -> Self {
@@ -15,7 +15,7 @@ mod rdbms_types {
     }
 
     impl From<pg4::Column> for v1::rdbms_types::Column {
-        fn from(value: spin::postgres4_0_0::postgres::Column) -> Self {
+        fn from(value: pg4::Column) -> Self {
             v1::rdbms_types::Column {
                 name: value.name,
                 data_type: value.data_type.into(),
@@ -422,7 +422,7 @@ mod rdbms_types {
 mod postgres {
     use super::*;
     use spin::postgres3_0_0::postgres as pg3;
-    use spin::postgres4_0_0::postgres as pg4;
+    use spin::postgres4_2_0::postgres as pg4;
 
     impl From<pg4::RowSet> for v1::postgres::RowSet {
         fn from(value: pg4::RowSet) -> v1::postgres::RowSet {
diff --git a/crates/world/src/lib.rs b/crates/world/src/lib.rs
index ac2bea9531..1898d81e2d 100644
--- a/crates/world/src/lib.rs
+++ b/crates/world/src/lib.rs
@@ -36,6 +36,7 @@ wasmtime::component::bindgen!({
         "fermyon:spin/variables@2.0.0.error" => v2::variables::Error,
         "spin:postgres/postgres@3.0.0.error" => spin::postgres3_0_0::postgres::Error,
         "spin:postgres/postgres@4.0.0.error" => spin::postgres4_0_0::postgres::Error,
+        "spin:postgres/postgres@4.2.0.error" => spin::postgres4_2_0::postgres::Error,
         "spin:sqlite/sqlite.error" => spin::sqlite::sqlite::Error,
         "wasi:config/store@0.2.0-draft-2024-09-27.error" => wasi::config::store::Error,
         "wasi:keyvalue/store.error" => wasi::keyvalue::store::Error,
diff --git a/wit/deps/spin-postgres@4.2.0/postgres.wit b/wit/deps/spin-postgres@4.2.0/postgres.wit
new file mode 100644
index 0000000000..d962d08f58
--- /dev/null
+++ b/wit/deps/spin-postgres@4.2.0/postgres.wit
@@ -0,0 +1,171 @@
+package spin:postgres@4.2.0;
+
+interface postgres {
+  /// Errors related to interacting with a database.
+  variant error {
+    connection-failed(string),
+    bad-parameter(string),
+    query-failed(query-error),
+    value-conversion-failed(string),
+    other(string)
+  }
+
+  variant query-error {
+    /// An error occurred but we do not have structured info for it
+    text(string),
+    /// Postgres returned a structured database error
+    db-error(db-error),
+  }
+
+  record db-error {
+    /// Stringised version of the error. This is primarily to facilitate migration of older code.
+    as-text: string,
+    severity: string,
+    code: string,
+    message: string,
+    detail: option<string>,
+    /// Any error information provided by Postgres and not captured above.
+    extras: list<tuple<string, string>>,
+  }
+
+  /// Data types for a database column
+  variant db-data-type {
+    boolean,
+    int8,
+    int16,
+    int32,
+    int64,
+    floating32,
+    floating64,
+    str,
+    binary,
+    date,
+    time,
+    datetime,
+    timestamp,
+    uuid,
+    jsonb,
+    decimal,
+    range-int32,
+    range-int64,
+    range-decimal,
+    array-int32,
+    array-int64,
+    array-decimal,
+    array-str,
+    interval,
+    other(string),
+  }
+
+  /// Database values
+  variant db-value {
+    boolean(bool),
+    int8(s8),
+    int16(s16),
+    int32(s32),
+    int64(s64),
+    floating32(f32),
+    floating64(f64),
+    str(string),
+    binary(list<u8>),
+    date(tuple<s32, u8, u8>), // (year, month, day)
+    time(tuple<u8, u8, u8, u32>), // (hour, minute, second, nanosecond)
+    /// Date-time types are always treated as UTC (without timezone info).
+    /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple.
+    datetime(tuple<s32, u8, u8, u8, u8, u8, u32>),
+    /// Unix timestamp (seconds since epoch)
+    timestamp(s64),
+    uuid(string),
+    jsonb(list<u8>),
+    decimal(string), // base 10 string representation
+    range-int32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
+    range-int64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
+    range-decimal(tuple<option<tuple<string, range-bound-kind>>, option<tuple<string, range-bound-kind>>>),
+    array-int32(list<option<s32>>),
+    array-int64(list<option<s64>>),
+    array-decimal(list<option<string>>),
+    array-str(list<option<string>>),
+    interval(interval),
+    db-null,
+    unsupported(list<u8>),
+  }
+
+  /// Values used in parameterized queries
+  variant parameter-value {
+    boolean(bool),
+    int8(s8),
+    int16(s16),
+    int32(s32),
+    int64(s64),
+    floating32(f32),
+    floating64(f64),
+    str(string),
+    binary(list<u8>),
+    date(tuple<s32, u8, u8>), // (year, month, day)
+    time(tuple<u8, u8, u8, u32>), // (hour, minute, second, nanosecond)
+    /// Date-time types are always treated as UTC (without timezone info).
+    /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple.
+    datetime(tuple<s32, u8, u8, u8, u8, u8, u32>),
+    /// Unix timestamp (seconds since epoch)
+    timestamp(s64),
+    uuid(string),
+    jsonb(list<u8>),
+    decimal(string), // base 10
+    range-int32(tuple<option<tuple<s32, range-bound-kind>>, option<tuple<s32, range-bound-kind>>>),
+    range-int64(tuple<option<tuple<s64, range-bound-kind>>, option<tuple<s64, range-bound-kind>>>),
+    range-decimal(tuple<option<tuple<string, range-bound-kind>>, option<tuple<string, range-bound-kind>>>),
+    array-int32(list<option<s32>>),
+    array-int64(list<option<s64>>),
+    array-decimal(list<option<string>>),
+    array-str(list<option<string>>),
+    interval(interval),
+    db-null,
+  }
+
+  record interval {
+    micros: s64,
+    days: s32,
+    months: s32,
+  }
+
+  /// A database column
+  record column {
+    name: string,
+    data-type: db-data-type,
+  }
+
+  /// A database row
+  type row = list<db-value>;
+
+  /// A set of database rows
+  record row-set {
+    columns: list<column>,
+    rows: list<row>,
+  }
+
+  /// For range types, indicates if each bound is inclusive or exclusive
+  enum range-bound-kind {
+    inclusive,
+    exclusive,
+  }
+
+  /// A connection to a postgres database.
+  resource connection {
+    /// Open a connection to the Postgres instance at `address`.
+    open: static func(address: string) -> result<connection, error>;
+
+    /// As `open`, but as a component-model async function, so the caller is not blocked while the connection is established.
+    @since(version = 4.2.0)
+    open-async: static async func(address: string) -> result<connection, error>;
+
+    /// Query the database.
+    query: func(statement: string, params: list<parameter-value>) -> result<row-set, error>;
+
+    /// Query the database, receiving the column metadata as a future and the rows as a stream.
+    @since(version = 4.2.0)
+    query-async: async func(statement: string, params: list<parameter-value>) -> result<tuple<future<list<column>>, stream<result<row, error>>>, error>;
+
+    /// Execute a command against the database.
+    execute: func(statement: string, params: list<parameter-value>) -> result<u64, error>;
+  }
+}
diff --git a/wit/world.wit b/wit/world.wit
index ec7de2aef8..0922fac6bc 100644
--- a/wit/world.wit
+++ b/wit/world.wit
@@ -20,7 +20,7 @@ world platform {
   include fermyon:spin/platform@2.0.0;
   include wasi:keyvalue/imports@0.2.0-draft2;
   import spin:postgres/postgres@3.0.0;
-  import spin:postgres/postgres@4.0.0;
+  import spin:postgres/postgres@4.2.0;
   import spin:sqlite/sqlite@3.0.0;
   import wasi:config/store@0.2.0-draft-2024-09-27;
 }
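
A sketch of how a host-side caller (for example a test or tooling helper) might drain the channel pair returned by `Client::query_async` back into an in-memory `RowSet`. The helper name `collect_row_set` is illustrative and not part of this patch; it uses only the tokio channel types and the `spin_world` types that the trait signature already relies on.

```rust
use spin_world::spin::postgres4_2_0::postgres as v4;

/// Drains the (columns, rows) channel pair produced by `Client::query_async`
/// into a `RowSet`, mirroring what the synchronous `query` path returns.
async fn collect_row_set(
    cols_rx: tokio::sync::oneshot::Receiver<Vec<v4::Column>>,
    mut rows_rx: tokio::sync::mpsc::Receiver<Result<Vec<v4::DbValue>, v4::Error>>,
) -> Result<v4::RowSet, v4::Error> {
    // The background task sends the column list when it sees the first row;
    // if the sender is dropped without sending (e.g. an empty result set),
    // fall back to an empty column list.
    let columns = cols_rx.await.unwrap_or_default();

    // Receive converted rows until the sender side closes the channel,
    // propagating the first row-level error if one arrives.
    let mut rows = Vec::new();
    while let Some(row) = rows_rx.recv().await {
        rows.push(row?);
    }

    Ok(v4::RowSet { columns, rows })
}
```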
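The test's `MockClient::query_async` currently panics. One possible (untested) way to satisfy the trait with an empty result, using only the channel types from the trait signature, is sketched below; `mock_query_async` is a free-standing illustration rather than the trait method itself.

```rust
use spin_world::spin::postgres4_2_0::postgres::{self as v2};

// Mirrors the empty RowSet returned by the mock `query`: send an empty column
// list, then drop the row sender so the row stream ends immediately.
async fn mock_query_async() -> Result<
    (
        tokio::sync::oneshot::Receiver<Vec<v2::Column>>,
        tokio::sync::mpsc::Receiver<Result<Vec<v2::DbValue>, v2::Error>>,
    ),
    v2::Error,
> {
    let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
    let (_rows_tx, rows_rx) = tokio::sync::mpsc::channel(1);
    // Ignore the error if the receiving side has already gone away.
    cols_tx.send(Vec::new()).ok();
    // `_rows_tx` is dropped on return, so `rows_rx.recv()` yields `None` straight away.
    Ok((cols_rx, rows_rx))
}
```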