Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ anyhow = { workspace = true }
bytes = {workspace = true }
chrono = { workspace = true }
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
futures = { workspace = true }
moka = { version = "0.12", features = ["sync"] }
native-tls = "0.2"
postgres-native-tls = "0.5"
Expand All @@ -20,6 +21,7 @@ spin-factor-otel = { path = "../factor-otel" }
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-resource-table = { path = "../table" }
spin-wasi-async = { path = "../wasi-async" }
spin-world = { path = "../world" }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] }
Expand Down
67 changes: 67 additions & 0 deletions crates/factor-outbound-pg/src/allowed_hosts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::sync::Arc;

use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
use spin_world::spin::postgres4_2_0::postgres::{self as v4};

/// Encapsulates checking of a PostgreSQL address/connection string against
/// an allow-list.
///
/// This is broken out as a distinct object to allow it to be synchronously retrieved
/// within a P3 Accessor block and then asynchronously queried outside the block.
#[derive(Clone)]
pub(crate) struct AllowedHostChecker {
allowed_hosts: Arc<OutboundAllowedHosts>,
}

impl AllowedHostChecker {
pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self {
Self {
allowed_hosts: Arc::new(allowed_hosts),
}
}
#[allow(clippy::result_large_err)]
pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
fn conn_failed(message: impl Into<String>) -> v4::Error {
v4::Error::ConnectionFailed(message.into())
}
fn err_other(err: anyhow::Error) -> v4::Error {
v4::Error::Other(err.to_string())
}

let config = address
.parse::<tokio_postgres::Config>()
.map_err(|e| conn_failed(e.to_string()))?;

for (i, host) in config.get_hosts().iter().enumerate() {
match host {
tokio_postgres::config::Host::Tcp(address) => {
let ports = config.get_ports();
// The port we use is either:
// * The port at the same index as the host
// * The first port if there is only one port
let port =
ports
.get(i)
.or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
let url = format!("{address}{port_str}");
if !self
.allowed_hosts
.check_url(&url, "postgres")
.await
.map_err(err_other)?
{
return Err(conn_failed(format!(
"address postgres://{url} is not permitted"
)));
}
}
#[cfg(unix)]
tokio_postgres::config::Host::Unix(_) => {
return Err(conn_failed("Unix sockets are not supported on WebAssembly"));
}
}
}
Ok(())
}
}
97 changes: 81 additions & 16 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::sync::Arc;

use anyhow::{Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::spin::postgres4_0_0::postgres::{
use spin_world::spin::postgres4_2_0::postgres::{
self as v4, Column, DbValue, ParameterValue, RowSet,
};
use tokio_postgres::types::ToSql;
use tokio_postgres::{config::SslMode, NoTls, Row};

use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
use crate::types::{
as_sql_parameter_refs, convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters,
};

/// Max connections in a given address' connection pool
const CONNECTION_POOL_SIZE: usize = 64;
Expand Down Expand Up @@ -40,7 +44,7 @@ impl Default for PooledTokioClientFactory {

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;
type Client = Arc<deadpool_postgres::Object>;

async fn get_client(&self, address: &str) -> Result<Self::Client> {
let pool = self
Expand All @@ -49,7 +53,7 @@ impl ClientFactory for PooledTokioClientFactory {
.map_err(ArcError)
.context("establishing PostgreSQL connection pool")?;

Ok(pool.get().await?)
Ok(Arc::new(pool.get().await?))
}
}

Expand Down Expand Up @@ -85,7 +89,7 @@ fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
}

#[async_trait]
pub trait Client: Send + Sync + 'static {
pub trait Client: Clone + Send + Sync + 'static {
async fn execute(
&self,
statement: String,
Expand All @@ -97,6 +101,18 @@ pub trait Client: Send + Sync + 'static {
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v4::Error>;

async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<
(
tokio::sync::oneshot::Receiver<Vec<v4::Column>>,
tokio::sync::mpsc::Receiver<Result<v4::Row, v4::Error>>,
),
v4::Error,
>;
}

/// Extract weak-typed error data for WIT purposes
Expand Down Expand Up @@ -142,7 +158,7 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
}

#[async_trait]
impl Client for deadpool_postgres::Object {
impl Client for Arc<deadpool_postgres::Object> {
async fn execute(
&self,
statement: String,
Expand Down Expand Up @@ -170,16 +186,8 @@ impl Client for deadpool_postgres::Object {
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v4::Error> {
let params = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;

let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();
let params = to_sql_parameters(params)?;
let params_refs = as_sql_parameter_refs(&params);

let results = self
.as_ref()
Expand All @@ -203,6 +211,63 @@ impl Client for deadpool_postgres::Object {

Ok(RowSet { columns, rows })
}

async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<
(
tokio::sync::oneshot::Receiver<Vec<v4::Column>>,
tokio::sync::mpsc::Receiver<Result<v4::Row, v4::Error>>,
),
v4::Error,
> {
let params = to_sql_parameters(params)?;
let params_refs = as_sql_parameter_refs(&params);

let stm = self
.as_ref()
.query_raw(&statement, params_refs)
.await
.map_err(query_failed)?;

let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(1000);
let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
let mut cols_tx_opt = Some(cols_tx);

let mut stm = Box::pin(stm);

tokio::spawn(async move {
use futures::StreamExt;
loop {
let Some(row) = stm.next().await else {
break;
};
// TODO: figure out how to deal with errors here - I think there is like a FutureReader<Error> pattern?
let row = match row {
Ok(r) => r,
Err(e) => {
let err = query_failed(e);
rows_tx.send(Err(err)).await.unwrap();
break;
}
};
if let Some(cols_tx) = cols_tx_opt.take() {
cols_tx.send(infer_columns(&row)).unwrap();
}
match convert_row(&row) {
Ok(row) => rows_tx.send(Ok(row)).await.unwrap(),
Err(e) => {
let err = v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}")));
rows_tx.send(Err(err)).await.unwrap();
}
}
}
});

Ok((cols_rx, rows_rx))
}
}

fn infer_columns(row: &Row) -> Vec<Column> {
Expand Down
Loading