From b22699013c4a798d302c900e7163dd64258cef59 Mon Sep 17 00:00:00 2001 From: Petr Portnov Date: Sun, 16 Mar 2025 22:11:34 +0300 Subject: [PATCH] wip: migrate to thiserror --- Cargo.lock | 21 +++ Cargo.toml | 2 +- Makefile | 2 +- src/admin.rs | 92 +++++------ src/auth.rs | 20 +++ src/client.rs | 385 +++++++++++++++++++------------------------- src/cmd_args.rs | 13 +- src/daemon/lib.rs | 2 +- src/errors.rs | 298 +++++++++++++++++++++++++++------- src/jwt_auth.rs | 75 ++++----- src/lib.rs | 6 +- src/main.rs | 2 - src/messages.rs | 347 +++++++++++++++++---------------------- src/scram_client.rs | 79 ++++----- src/scram_server.rs | 2 +- src/server.rs | 292 +++++++++++++++------------------ src/stats/socket.rs | 42 ++--- 17 files changed, 849 insertions(+), 831 deletions(-) create mode 100644 src/auth.rs diff --git a/Cargo.lock b/Cargo.lock index 64f66e94..a1417c0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -977,6 +977,7 @@ dependencies = [ "socket2", "stringprep", "syslog", + "thiserror", "tikv-jemallocator", "tokio", "tokio-native-tls", @@ -1476,6 +1477,26 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" diff --git a/Cargo.toml b/Cargo.toml index 7e00d19f..a7617bd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ serde-toml-merge = { version = "0.3.8"} jwt = { version = "0.16.0", features = ["openssl"] } openssl = { version = "0.10.71"} iota = { version = "0.2.3" } - +thiserror = "2.0" [replace] 'deadpool:0.10.0' = { path = 'patches/deadpool' } diff --git a/Makefile b/Makefile index 509e6ccd..fd21eff6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ .DEFAULT_GOAL := build +.PHONY: build install test build: cargo build --release @@ -9,4 +10,3 @@ install: build test: cargo test - ./tests/tests.sh \ No newline at end of file diff --git a/src/admin.rs b/src/admin.rs index e717f253..9a6ed209 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -5,18 +5,18 @@ use log::{debug, error, info}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use std::collections::HashMap; +use std::marker::Unpin; /// Admin database. use std::sync::atomic::Ordering; +use tokio::io::AsyncWrite; use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; -use crate::errors::Error; +use crate::errors::{Error, ProtocolSyncError, ServerError}; use crate::messages::*; use crate::pool::get_all_pools; use crate::pool::ClientServerMap; use crate::stats::client::{CLIENT_STATE_ACTIVE, CLIENT_STATE_IDLE}; -#[cfg(target_os = "linux")] -use crate::stats::get_socket_states_count; use crate::stats::server::{SERVER_STATE_ACTIVE, SERVER_STATE_IDLE}; use crate::stats::{ get_client_stats, get_server_stats, CANCEL_CONNECTION_COUNTER, PLAIN_CONNECTION_COUNTER, @@ -42,15 +42,15 @@ pub async fn handle_admin( client_server_map: ClientServerMap, ) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { - let code = query.get_u8() as char; - - if code != 'Q' { - return Err(Error::ProtocolSyncError(format!( - "Invalid code, expected 'Q' but got '{}'", - code - ))); + let code = query.get_u8(); + if code != b'Q' { + return Err(ProtocolSyncError::InvalidCode { + expected: b'Q', + actual: code, + } + .into()); } let len = query.get_i32() as usize; @@ -110,7 +110,7 @@ where /// Column-oriented statistics. async fn show_lists(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let client_stats = get_client_stats(); let server_stats = get_server_stats(); @@ -206,13 +206,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show PgDoorman version. async fn show_version(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let mut res = BytesMut::new(); @@ -224,13 +224,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show utilization of connection pools for each pool. async fn show_pools(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let pool_lookup = PoolStats::construct_pool_lookup(); let mut res = BytesMut::new(); @@ -245,13 +245,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show extended utilization of connection pools for each pool. async fn show_pools_extended(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let pool_lookup = PoolStats::construct_pool_lookup(); let mut res = BytesMut::new(); @@ -268,13 +268,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show all available options. async fn show_help(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let mut res = BytesMut::new(); @@ -307,13 +307,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show databases. async fn show_databases(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { // Columns let columns = vec![ @@ -361,14 +361,14 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Ignore any SET commands the client sends. /// This is common initialization done by ORMs. async fn ignore_set(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { custom_protocol_response_ok(stream, "SET").await } @@ -376,7 +376,7 @@ where /// Reload the configuration file without restarting the process. async fn reload(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { info!("Reloading config"); @@ -393,13 +393,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Shows current configuration. async fn show_config(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let config = &get_config(); let config: HashMap = config.into(); @@ -439,13 +439,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show stats. async fn show_stats(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let pool_lookup = PoolStats::construct_pool_lookup(); let mut res = BytesMut::new(); @@ -461,13 +461,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show currently connected clients async fn show_clients(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let columns = vec![ ("client_id", DataType::Text), @@ -517,12 +517,12 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } async fn show_connections(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let columns = vec![ ("total", DataType::Numeric), @@ -556,12 +556,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } + /// Show currently connected servers async fn show_servers(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let columns = vec![ ("server_id", DataType::Text), @@ -627,13 +628,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Send response packets for shutdown. async fn shutdown(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let mut res = BytesMut::new(); @@ -655,13 +656,13 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } /// Show Users. async fn show_users(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { let mut res = BytesMut::new(); @@ -684,20 +685,19 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } #[cfg(target_os = "linux")] async fn show_sockets(stream: &mut T) -> Result<(), Error> where - T: tokio::io::AsyncWrite + std::marker::Unpin, + T: AsyncWrite + Unpin, { + use crate::stats::get_socket_states_count; + let mut res = BytesMut::new(); - let sockets_info = match get_socket_states_count(std::process::id()) { - Ok(info) => info, - Err(_) => return Err(Error::ServerError), - }; + let sockets_info = get_socket_states_count(std::process::id()).map_err(ServerError::from)?; res.put(row_description(&vec![ // tcp @@ -747,5 +747,5 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, &res).await + Ok(write_all_half(stream, &res).await?) } diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 00000000..be960c20 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,20 @@ +use std::fmt::{self, Display}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthMethod { + Sasl, + ClearPassword, + Jwt, + Md5, +} + +impl Display for AuthMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::Sasl => "SASL", + Self::ClearPassword => "clear password", + Self::Jwt => "JWT", + Self::Md5 => "MD5-encrypted password", + }) + } +} diff --git a/src/client.rs b/src/client.rs index 21062570..eeb7225b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,20 @@ -use crate::errors::{ClientIdentifier, Error}; +use crate::errors::{ + ClientBadStartupError, ClientError, ClientGeneralError, ClientIdentifier, Error, + HbaForbiddenError, ProtocolSyncError, TlsError, +}; /// Handle clients by pretending to be a PostgreSQL server. use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; use std::collections::{HashMap, VecDeque}; use std::ffi::CStr; +use std::marker::Unpin; use std::ops::DerefMut; use std::str; use std::sync::atomic::Ordering; use std::sync::{atomic::AtomicUsize, Arc}; use std::time::Instant; -use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; +use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; @@ -310,9 +314,9 @@ pub async fn client_entrypoint( // Client probably disconnected rejecting our plain text connection. Ok((ClientConnectionType::Tls, _)) - | Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError( - "Bad postgres client (plain)".into(), - )), + | Ok((ClientConnectionType::CancelQuery, _)) => { + Err(ProtocolSyncError::BadClient { tls: false }.into()) + } Err(err) => Err(err), } @@ -399,25 +403,21 @@ pub async fn client_entrypoint( /// Handle the first message the client sends. async fn get_startup(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error> where - S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { // Get startup message length. - let len = match stream.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::ClientBadStartup), - }; + let len = stream.read_i32().await.map_err(ClientBadStartupError::Io)?; // Get the rest of the message. let mut startup = vec![0u8; len as usize - 4]; - match stream.read_exact(&mut startup).await { - Ok(_) => (), - Err(_) => return Err(Error::ClientBadStartup), - }; + stream + .read_exact(&mut startup) + .await + .map_err(ClientBadStartupError::Io)?; let mut bytes = BytesMut::from(&startup[..]); - let code = bytes.get_i32(); - match code { + match bytes.get_i32() { // Client is requesting SSL (TLS). SSL_REQUEST_CODE => Ok((ClientConnectionType::Tls, bytes)), @@ -437,10 +437,7 @@ where // Something else, probably something is wrong, and it's not our fault, // e.g. badly implemented Postgres client. - _ => Err(Error::ProtocolSyncError(format!( - "Unexpected startup code: {}", - code - ))), + code => Err(ProtocolSyncError::UnexpectedStartupCode(code).into()), } } @@ -469,56 +466,48 @@ pub async fn startup_tls( } }; - let mut stream = match tls_acceptor.accept(stream).await { - Ok(stream) => stream, - - // TLS negotiation failed. - Err(err) => { - error!("TLS negotiation failed: {:?}", err); - return Err(Error::TlsError); - } - }; + let mut stream = tls_acceptor.accept(stream).await.map_err(TlsError::from)?; // TLS negotiation successful. // Continue with regular startup using encrypted connection. - match get_startup::>(&mut stream).await { - // Got good startup message, proceeding like normal except we - // are encrypted now. - Ok((ClientConnectionType::Startup, bytes)) => { - let (read, write) = split(stream); - - Client::startup( - read, - write, - addr, - bytes, - client_server_map, - shutdown, - admin_only, - true, - ) - .await - } - - Ok((ClientConnectionType::CancelQuery, bytes)) => { - CANCEL_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); - let (read, write) = split(stream); - Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await - } + Ok( + match get_startup::>(&mut stream).await? { + // Got good startup message, proceeding like normal except we + // are encrypted now. + (ClientConnectionType::Startup, bytes) => { + let (read, write) = split(stream); + + Client::startup( + read, + write, + addr, + bytes, + client_server_map, + shutdown, + admin_only, + true, + ) + .await? + } - // Bad Postgres client. - Ok((ClientConnectionType::Tls, _)) => { - Err(Error::ProtocolSyncError("Bad postgres client (tls)".into())) - } + (ClientConnectionType::CancelQuery, bytes) => { + CANCEL_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); + let (read, write) = split(stream); + Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await? + } - Err(err) => Err(err), - } + // Bad Postgres client. + (ClientConnectionType::Tls, _) => { + return Err(ProtocolSyncError::BadClient { tls: true }.into()) + } + }, + ) } impl Client where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncRead + Unpin, + T: AsyncWrite + Unpin, { pub fn is_admin(&self) -> bool { self.admin @@ -540,14 +529,7 @@ where let parameters = parse_startup(bytes.clone())?; // This parameter is mandatory by the protocol. - let username = match parameters.get("user") { - Some(user) => user, - None => { - return Err(Error::ClientError( - "Missing user parameter on client startup".into(), - )) - } - }; + let username = parameters.get("user").ok_or(ClientError::NoUserParam)?; let pool_name = parameters.get("database").unwrap_or(username); @@ -583,11 +565,11 @@ where if !addr_in_hba(addr.ip()) { error_response_terminal(&mut write, "hba forbidden for this ip address", "28000") .await?; - return Err(Error::HbaForbiddenError(format!( - "hba forbidden client: {} from address: {:?}", - client_identifier, - addr.ip() - ))); + return Err(HbaForbiddenError { + client: client_identifier, + address: addr, + } + .into()); } // Generate random backend ID and secret key @@ -611,12 +593,14 @@ where ); if password_hash != password_response { - let error = Error::ClientGeneralError("Invalid password".into(), client_identifier); + let error = ClientGeneralError::InvalidPassword { + id: client_identifier, + }; - warn!("{}", error); + warn!("{error}"); wrong_password(&mut write, username).await?; - return Err(error); + return Err(error.into()); } (false, generate_server_parameters_for_admin()) @@ -637,10 +621,11 @@ where ) .await?; - return Err(Error::ClientGeneralError( - "Invalid pool name".into(), - client_identifier, - )); + return Err(ClientGeneralError::InvalidPoolName { + id: client_identifier, + pool_name: pool_name.clone(), + } + .into()); } }; let pool_password = pool.settings.user.password.clone(); @@ -910,7 +895,7 @@ where Ok(message) => message, Err(err) => return self.process_error(err).await, }; - if message[0] as char == 'X' { + if message[0] == b'X' { self.stats.disconnect(); return Ok(()); } @@ -1027,7 +1012,7 @@ where current_pool.address.stats.error(); self.stats.checkout_error(); - if message[0] as char == 'S' { + if message[0] == b'S' { self.reset_buffered_state(); } @@ -1076,7 +1061,7 @@ where if current_pool.settings.sync_server_parameters { server.sync_parameters(&self.server_parameters).await?; } - server.set_flush_wait_code(' '); + server.set_flush_wait_code(b' '); let mut initial_message = Some(message); @@ -1112,11 +1097,11 @@ where // Safe to unwrap because we know this message has a certain length and has the code // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.first().unwrap() as char; + let code = *message.first().unwrap(); match code { // Query - 'Q' => { + b'Q' => { self.send_and_receive_loop(Some(&message), server).await?; self.stats.query(); server.stats.query( @@ -1141,7 +1126,7 @@ where } // Terminate - 'X' => { + b'X' => { // принудительно закрываем чтобы не допустить длинную транзакцию server.checkin_cleanup().await?; self.stats.disconnect(); @@ -1151,31 +1136,31 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. - 'P' => { + b'P' => { self.buffer_parse(message, current_pool)?; } // Bind - 'B' => { + b'B' => { self.buffer_bind(message).await?; } // Describe // Command a client can issue to describe a previously prepared named statement. - 'D' => { + b'D' => { self.buffer_describe(message).await?; } // Execute // Execute a prepared statement prepared in `P` and bound in `B`. - 'E' => { + b'E' => { self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_execute(message)); } // Close // Close the prepared statement. - 'C' => { + b'C' => { let close: Close = (&message).try_into()?; self.extended_protocol_data_buffer @@ -1184,7 +1169,7 @@ where // Sync // Frontend (client) is asking for the query result now. - 'S' | 'H' => { + b'S' | b'H' => { // Prepared statements can arrive like this // 1. Without named describe // Client: Parse, with name, query and params @@ -1208,13 +1193,13 @@ where // RowDescription // ReadyForQuery // Iterate over our extended protocol data that we've buffered - let mut async_wait_code = ' '; + let mut async_wait_code = b' '; while let Some(protocol_data) = self.extended_protocol_data_buffer.pop_front() { match protocol_data { ExtendedProtocolData::Parse { data, metadata } => { - async_wait_code = '1'; + async_wait_code = b'1'; debug!("Have parse in extended buffer"); let (parse, hash) = match metadata { Some(metadata) => metadata, @@ -1259,7 +1244,7 @@ where } } ExtendedProtocolData::Bind { data, metadata } => { - async_wait_code = '2'; + async_wait_code = b'2'; // This is using a prepared statement if let Some(client_given_name) = metadata { self.ensure_prepared_statement_is_on_server( @@ -1273,7 +1258,7 @@ where self.buffer.put(&data[..]); } ExtendedProtocolData::Describe { data, metadata } => { - async_wait_code = 'T'; + async_wait_code = b'T'; // This is using a prepared statement if let Some(client_given_name) = metadata { self.ensure_prepared_statement_is_on_server( @@ -1287,7 +1272,7 @@ where self.buffer.put(&data[..]); } ExtendedProtocolData::Execute { data } => { - async_wait_code = 'C'; + async_wait_code = b'C'; self.buffer.put(&data[..]) } ExtendedProtocolData::Close { data, close } => { @@ -1311,11 +1296,11 @@ where // Add the sync message self.buffer.put(&message[..]); - if code == 'H' { + if code == b'H' { server.set_flush_wait_code(async_wait_code); debug!("Client requested flush, going async"); } else { - server.set_flush_wait_code(' ') + server.set_flush_wait_code(b' ') } self.send_and_receive_loop(None, server).await?; @@ -1340,7 +1325,7 @@ where self.client_last_messages_in_tx .put(&self.response_message_queue_buffer[..]); self.client_last_messages_in_tx = set_messages_right_place( - self.client_last_messages_in_tx.to_vec(), + &self.client_last_messages_in_tx, )?; self.response_message_queue_buffer.clear(); } @@ -1364,7 +1349,7 @@ where .as_str(), false, ); - return Err(err); + return Err(err.into()); } self.response_message_queue_buffer.clear(); @@ -1372,7 +1357,7 @@ where } // CopyData - 'd' => { + b'd' => { self.buffer.put(&message[..]); // Want to limit buffer size @@ -1385,7 +1370,7 @@ where // CopyDone or CopyFail // Copy is done, successfully or not. - 'c' | 'f' => { + b'c' | b'f' => { // We may already have some copy data in the buffer, add this message to buffer self.buffer.put(&message[..]); @@ -1411,7 +1396,7 @@ where .as_str(), false, ); - return Err(err); + return Err(err.into()); } }; @@ -1479,36 +1464,27 @@ where pool: &ConnectionPool, server: &mut Server, ) -> Result<(), Error> { - match self.prepared_statements.get(&client_name) { - Some((parse, hash)) => { - debug!("Prepared statement `{}` found in cache", client_name); - // In this case we want to send the parse message to the server - // since pgcat is initiating the prepared statement on this specific server - match self - .register_parse_to_server_cache(true, hash, parse, pool, server) - .await - { - Ok(_) => (), - Err(err) => match err { - Error::PreparedStatementError => { - debug!("Removed {} from client cache", client_name); - self.prepared_statements.remove(&client_name); - } + let Some((parse, hash)) = self.prepared_statements.get(&client_name) else { + return Err(ClientError::PreparedStatementNotFound(client_name).into()); + }; - _ => { - return Err(err); - } - }, + debug!("Prepared statement {client_name:?} found in cache"); + // In this case we want to send the parse message to the server + // since pgcat is initiating the prepared statement on this specific server + if let Err(e) = self + .register_parse_to_server_cache(true, hash, parse, pool, server) + .await + { + match e { + Error::NoPreparedStatement => { + debug!("Removed {client_name:?} from client cache"); + self.prepared_statements.remove(&client_name); + } + e => { + return Err(e); } } - - None => { - return Err(Error::ClientError(format!( - "prepared statement `{}` not found", - client_name - ))) - } - }; + } Ok(()) } @@ -1554,14 +1530,8 @@ where let hash = parse.get_hash(); // Add the statement to the cache or check if we already have it - let new_parse = match pool.register_parse_to_cache(hash, &parse) { - Some(parse) => parse, - None => { - return Err(Error::ClientError(format!( - "Could not store Prepared statement `{}`", - client_given_name - ))) - } + let Some(new_parse) = pool.register_parse_to_cache(hash, &parse) else { + return Err(ClientError::PreparesStatementStore(client_given_name).into()); }; debug!( @@ -1575,7 +1545,7 @@ where self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_parse( new_parse.as_ref().try_into()?, - Some((new_parse.clone(), hash)), + Some((new_parse, hash)), )); Ok(()) @@ -1584,7 +1554,7 @@ where /// Rewrite the Bind (F) message to use the prepared statement name /// saved in the client cache. async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> { - // Avoid parsing if prepared statements not enabled + // Avoid parsing if prepared statements are not enabled if !self.prepared_statements_enabled { debug!("Anonymous bind message"); self.extended_protocol_data_buffer @@ -1594,43 +1564,32 @@ where let client_given_name = Bind::get_name(&message)?; - match self.prepared_statements.get(&client_given_name) { - Some((rewritten_parse, _)) => { - let message = Bind::rename(message, &rewritten_parse.name)?; + let Some((rewritten_parse, _)) = self.prepared_statements.get(&client_given_name) else { + debug!("Got bind for unknown prepared statement {client_given_name:?}"); - debug!( - "Rewrote bind `{}` to `{}`", - client_given_name, rewritten_parse.name - ); + error_response( + &mut self.write, + &format!("prepared statement {client_given_name:?} does not exist"), + "58000", + ) + .await?; - self.extended_protocol_data_buffer.push_back( - ExtendedProtocolData::create_new_bind(message, Some(client_given_name)), - ); + return Err(ClientError::PreparedStatementNotFound(client_given_name).into()); + }; + let message = Bind::rename(message, &rewritten_parse.name)?; - Ok(()) - } - None => { - debug!( - "Got bind for unknown prepared statement {:?}", - client_given_name - ); + debug!( + "Rewrote bind {client_given_name:?} to {:?}", + rewritten_parse.name + ); - error_response( - &mut self.write, - &format!( - "prepared statement \"{}\" does not exist", - client_given_name - ), - "58000", - ) - .await?; + self.extended_protocol_data_buffer + .push_back(ExtendedProtocolData::create_new_bind( + message, + Some(client_given_name), + )); - Err(Error::ClientError(format!( - "Prepared statement `{}` doesn't exist", - client_given_name - ))) - } - } + Ok(()) } /// Rewrite the Describe (F) message to use the prepared statement name @@ -1655,45 +1614,33 @@ where } let client_given_name = describe.statement_name.clone(); + let Some((rewritten_parse, _)) = self.prepared_statements.get(&client_given_name) else { + debug!("Got describe for unknown prepared statement {describe:?}"); - match self.prepared_statements.get(&client_given_name) { - Some((rewritten_parse, _)) => { - let describe = describe.rename(&rewritten_parse.name); - - debug!( - "Rewrote describe `{}` to `{}`", - client_given_name, describe.statement_name - ); + error_response( + &mut self.write, + &format!("prepared statement {client_given_name:?} does not exist"), + "58000", + ) + .await?; - self.extended_protocol_data_buffer.push_back( - ExtendedProtocolData::create_new_describe( - describe.try_into()?, - Some(client_given_name), - ), - ); + return Err(ClientError::PreparedStatementNotFound(client_given_name).into()); + }; - Ok(()) - } + let describe = describe.rename(&rewritten_parse.name); - None => { - debug!("Got describe for unknown prepared statement {:?}", describe); + debug!( + "Rewrote describe {client_given_name:?} to {:?}", + describe.statement_name + ); - error_response( - &mut self.write, - &format!( - "prepared statement \"{}\" does not exist", - client_given_name - ), - "58000", - ) - .await?; + self.extended_protocol_data_buffer + .push_back(ExtendedProtocolData::create_new_describe( + describe.try_into()?, + Some(client_given_name), + )); - Err(Error::ClientError(format!( - "Prepared statement `{}` doesn't exist", - client_given_name - ))) - } - } + Ok(()) } fn reset_buffered_state(&mut self) { @@ -1724,13 +1671,13 @@ where ) .await?; - Err(Error::ClientError(format!( - "Invalid pool name {{ username: {}, pool_name: {}, application_name: {}, virtual pool id: {} }}", - self.pool_name, - self.username, - self.server_parameters.get_application_name(), - virtual_pool_id - ))) + Err(ClientError::InvalidPoolName { + username: self.username.clone(), + pool_name: self.pool_name.clone(), + application_name: self.server_parameters.get_application_name().clone(), + virtual_pool_id, + } + .into()) } } } @@ -1779,20 +1726,20 @@ where if !self.response_message_queue_buffer.is_empty() { response.put(&self.response_message_queue_buffer[..]); - response = set_messages_right_place(response.to_vec())?; + response = set_messages_right_place(&response)?; self.response_message_queue_buffer.clear(); } self.stats.active_write(); match write_all_flush(&mut self.write, &response).await { Ok(_) => self.stats.active_idle(), - Err(err_write) => { + Err(err) => { server.wait_available().await; server.mark_bad( - format!("flush to client {} {:?}", self.addr, err_write).as_str(), + format!("flush to client {} {:?}", self.addr, err).as_str(), true, ); - return Err(err_write); + return Err(err.into()); } }; @@ -1814,7 +1761,7 @@ where .await?; Err(err) } - Error::CurrentMemoryUsage => { + Error::MemoryLimitReached => { error_response( &mut self.write, format!("could not read message, temporary out of memory - {}", err).as_str(), diff --git a/src/cmd_args.rs b/src/cmd_args.rs index c74fd9a5..408777c8 100644 --- a/src/cmd_args.rs +++ b/src/cmd_args.rs @@ -11,7 +11,7 @@ pub struct Args { #[arg(short, long, default_value_t = tracing::Level::INFO, env)] pub log_level: Level, - #[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)] + #[clap(short = 'F', long, value_enum, default_value_t = LogFormat::Text, env)] pub log_format: LogFormat, #[arg( @@ -37,3 +37,14 @@ pub enum LogFormat { Structured, Debug, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cmd_parses() { + use clap::CommandFactory; + Args::command().debug_assert(); + } +} diff --git a/src/daemon/lib.rs b/src/daemon/lib.rs index d3ad1254..31527f30 100644 --- a/src/daemon/lib.rs +++ b/src/daemon/lib.rs @@ -332,8 +332,8 @@ impl Daemonize { } fn execute_child(self) -> Result { + set_current_dir(&self.directory).map_err(|_| ErrorKind::ChangeDirectory(errno()))?; unsafe { - set_current_dir(&self.directory).map_err(|_| ErrorKind::ChangeDirectory(errno()))?; set_sid()?; libc::umask(self.umask.inner); diff --git a/src/errors.rs b/src/errors.rs index 79231262..887298f5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,43 +1,245 @@ //! Errors. +use std::{ffi::NulError, io, net::SocketAddr}; + +use md5::digest::{InvalidLength as InvalidMd5Length, MacError}; +use openssl::error::ErrorStack; + +use crate::{auth::AuthMethod, stats::socket::SocketInfoError}; + /// Various errors. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, thiserror::Error)] pub enum Error { + #[error("socket error ocurred: {0}")] SocketError(String), + #[error(transparent)] + Socket(#[from] SocketError), + #[error("error reading {0} from {1}")] ClientSocketError(String, ClientIdentifier), - ClientGeneralError(String, ClientIdentifier), - ClientBadStartup, - ProtocolSyncError(String), - BadQuery(String), - ServerError, - ServerMessageParserError(String), + #[error(transparent)] + ClientGeneral(#[from] ClientGeneralError), + #[error(transparent)] + ClientBadStartup(#[from] ClientBadStartupError), + #[error(transparent)] + ProtocolSync(#[from] ProtocolSyncError), + #[error(transparent)] + Server(#[from] ServerError), + #[error(transparent)] + ServerMessageParse(#[from] ServerMessageParseError), + #[error("Error reading {0} on server startup {1}")] ServerStartupError(String, ServerIdentifier), - ServerAuthError(String, ServerIdentifier), + #[error(transparent)] + ServerAuth(#[from] ServerAuthError), + #[error("TODO")] BadConfig(String), - AllServersDown, - QueryWaitTimeout, - ClientError(String), - TlsError, + #[error(transparent)] + Client(#[from] ClientError), + #[error(transparent)] + Tls(#[from] TlsError), + #[error("TODO")] StatementTimeout, - DNSCachedError(String), + #[error("shutting down")] ShuttingDown, - ParseBytesError(String), + #[error(transparent)] + ParseBytes(#[from] ParseBytesError), + #[error("TODO")] AuthError(String), - UnsupportedStatement, - QueryError(String), + #[error(transparent)] + QueryError(#[from] NulError), + #[error("TODO")] ScramClientError(String), + #[error("TODO")] ScramServerError(String), - HbaForbiddenError(String), - PreparedStatementError, + // the error is boxed since it is huge + #[error(transparent)] + HbaForbidden(#[from] Box), + #[error("prepated statement not found")] + NoPreparedStatement, + #[error("max message size")] MaxMessageSize, - CurrentMemoryUsage, - JWTPubKey(String), - JWTPrivKey(String), - JWTValidate(String), + #[error("memory limit reached")] + MemoryLimitReached, + #[error(transparent)] + JwtPubKey(#[from] JwtPubKeyError), + #[error(transparent)] + JwtValidate(#[from] JwtValidateError), + #[error("proxy timeout")] ProxyTimeout, } -#[derive(Clone, PartialEq, Debug)] +#[derive(Debug, thiserror::Error)] +pub enum SocketError { + #[error("failed to flush socket")] + Flush(#[source] io::Error), + #[error("failed to write to socket")] + Write(#[source] io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtocolSyncError { + #[error("unexpected startup code {0}")] + UnexpectedStartupCode(i32), + #[error("SCRAM")] + Scram, + #[error("bad Postges client ({})", if *tls { "TLS" } else { "plain" })] + BadClient { tls: bool }, + #[error("invalid code, expected {expected} but got {actual}")] + InvalidCode { expected: u8, actual: u8 }, + #[error("unprocessed message code {0} from server backend while startup")] + UnprocessedCode(u8), + #[error("server {server} unknown transaction state {transaction_state}")] + UnknownTransactionState { + // TODO: something smarter + server: String, + transaction_state: u8, + }, +} + +#[derive(Debug, thiserror::Error)] +pub enum ClientBadStartupError { + #[error(transparent)] + Io(#[from] io::Error), + #[error("no parameters were specified")] + NoParams, + #[error("numbers of parameter keys and values don't match")] + UnevenParams, + #[error("user parameter is not specified")] + UserUnspecified, +} + +#[derive(Debug, thiserror::Error)] +pub enum ClientGeneralError { + #[error("invalid pool name {pool_name:?} for {id}")] + InvalidPoolName { + id: ClientIdentifier, + pool_name: String, + }, + #[error("invalid password for {id}")] + InvalidPassword { id: ClientIdentifier }, +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerAuthError { + #[error("invalid authentication code {code} for {id}")] + InvalidAuthCode { id: ServerIdentifier, code: i32 }, + #[error("unsupported authentication method {method} for {id}")] + UnsupportedMethod { + id: ServerIdentifier, + method: AuthMethod, + }, + #[error(transparent)] + JwtPrivKey(#[from] JwtPrivKeyError), + #[error("authentication method {method} failed")] + Io { + method: AuthMethod, + #[source] + error: io::Error, + }, +} + +#[derive(Debug, thiserror::Error)] +pub enum JwtPrivKeyError { + #[error(transparent)] + OpenSsl(#[from] ErrorStack), + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Jwt(#[from] jwt::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum JwtValidateError { + #[error("no expiration")] + NoExpiration, + #[error("expiration")] + Expiration, + #[error("not before")] + NotBefore, + #[error(transparent)] + Jwt(#[from] jwt::Error), +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct TlsError(#[from] native_tls::Error); + +#[derive(Debug, thiserror::Error)] +pub enum ServerError { + #[error(transparent)] + InvalidMd5Length(#[from] InvalidMd5Length), + #[error(transparent)] + Mac(#[from] MacError), + #[error("unsupported SCRAM version {0:?}")] + UnsupportedScramVersion(String), + #[error(transparent)] + SocketInfo(#[from] SocketInfoError), + #[error("error message is empty")] + EmptyErrorMessage, + #[error("internal server error")] + Internal, +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerMessageParseError { + #[error("failed to read i32 value from server message")] + InvalidI32, + #[error("message `len` is less than 4")] + LenSmallerThan4(usize), + #[error("cursor {cursor} exceeds message length {message}")] + CursorOverflow { cursor: usize, message: usize }, + #[error( + "message length {message} at cursor {cursor} exceeds received message length {received}" + )] + LenOverlow { + received: usize, + cursor: usize, + message: usize, + }, +} + +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("missing user parameter on client startup")] + NoUserParam, + #[error("Invalid pool name {{ username: {username}, pool_name: {pool_name}, application_name: {application_name}, virtual pool id: {virtual_pool_id} }}")] + InvalidPoolName { + username: String, + pool_name: String, + application_name: String, + virtual_pool_id: u16, + }, + #[error("prepared statement {0:?} does not exist")] + PreparedStatementNotFound(String), + #[error("failed to store prepated statemtn {0:?}")] + PreparesStatementStore(String), +} + +#[derive(Debug, thiserror::Error)] +#[error("hba forbidden client {client} from address: {address}")] +pub struct HbaForbiddenError { + pub client: ClientIdentifier, + pub address: SocketAddr, +} + +#[derive(Debug, thiserror::Error)] +pub enum JwtPubKeyError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + OpenSsl(#[from] ErrorStack), + #[error("key is not loaded")] + KeyNotLoaded, +} + +#[derive(Debug, thiserror::Error)] +pub enum ParseBytesError { + #[error(transparent)] + Io(#[from] io::Error), + #[error("string is not nul-terminated")] + NoNul, +} + +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ClientIdentifier { pub addr: String, pub application_name: String, @@ -63,15 +265,20 @@ impl ClientIdentifier { impl std::fmt::Display for ClientIdentifier { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let Self { + addr, + application_name, + username, + pool_name, + } = self; write!( f, - "{{ {}@{}/{}?application_name={} }}", - self.username, self.addr, self.pool_name, self.application_name + "{{ {username}@{addr}/{pool_name}?application_name={application_name} }}", ) } } -#[derive(Clone, PartialEq, Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ServerIdentifier { pub username: String, pub database: String, @@ -88,42 +295,13 @@ impl ServerIdentifier { impl std::fmt::Display for ServerIdentifier { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{{ username: {}, database: {} }}", - self.username, self.database - ) - } -} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match &self { - &Error::ClientSocketError(error, client_identifier) => write!( - f, - "Error reading {} from client {}", - error, client_identifier - ), - &Error::ClientGeneralError(error, client_identifier) => { - write!(f, "{} {}", error, client_identifier) - } - &Error::ServerStartupError(error, server_identifier) => write!( - f, - "Error reading {} on server startup {}", - error, server_identifier, - ), - &Error::ServerAuthError(error, server_identifier) => { - write!(f, "{} for {}", error, server_identifier,) - } - - // The rest can use Debug. - err => write!(f, "{:?}", err), - } + let Self { username, database } = self; + write!(f, "{{ username: {username}, database: {database} }}") } } -impl From for Error { - fn from(err: std::ffi::NulError) -> Self { - Error::QueryError(err.to_string()) +impl From for Error { + fn from(value: HbaForbiddenError) -> Self { + Self::from(Box::new(value)) } } diff --git a/src/jwt_auth.rs b/src/jwt_auth.rs index 0d72f832..cdfce593 100644 --- a/src/jwt_auth.rs +++ b/src/jwt_auth.rs @@ -1,4 +1,4 @@ -use crate::errors::Error; +use crate::errors::{Error, JwtPrivKeyError, JwtPubKeyError, JwtValidateError}; use jwt::{Header, PKeyWithDigest, RegisteredClaims, SignWithKey, Token, VerifyWithKey}; use once_cell::sync::Lazy; use openssl::hash::MessageDigest; @@ -38,63 +38,46 @@ pub fn new_claims(username: String, duration: Duration) -> PreferredUsernameClai } impl PreferredUsernameClaims { - fn validate(&self) -> Result<(), Error> { + fn validate(&self) -> Result<(), JwtValidateError> { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); if let Some(val) = self.default_claims.not_before { if now < val { - return Err(Error::JWTValidate("not before".to_string())); + return Err(JwtValidateError::NotBefore); } } - if let Some(val) = self.default_claims.expiration { - if now > val { - return Err(Error::JWTValidate("expiration".to_string())); - } - } else { - return Err(Error::JWTValidate("empty expiration".to_string())); + + let Some(expiration) = self.default_claims.expiration else { + return Err(JwtValidateError::NoExpiration); + }; + if now > expiration { + return Err(JwtValidateError::Expiration); } + Ok(()) } } pub async fn sign_with_jwt_priv_key( claims: PreferredUsernameClaims, - key_filename: String, -) -> Result { - let priv_key_data = match fs::read_to_string(key_filename.clone()) { - Ok(data) => data, - Err(err) => return Err(Error::JWTPrivKey(err.to_string())), - }; - let priv_key_rsa = match Rsa::private_key_from_pem(priv_key_data.as_bytes()) { - Ok(rsa) => rsa, - Err(err) => return Err(Error::JWTPrivKey(err.to_string())), - }; - let priv_key = match PKey::from_rsa(priv_key_rsa) { - Ok(data) => data, - Err(err) => return Err(Error::JWTPrivKey(err.to_string())), - }; + key_filename: &str, +) -> Result { + let priv_key_data = fs::read_to_string(key_filename)?; + let priv_key_rsa = Rsa::private_key_from_pem(priv_key_data.as_bytes())?; + let priv_key = PKey::from_rsa(priv_key_rsa)?; let rs256_priv_key = PKeyWithDigest { digest: MessageDigest::sha256(), key: priv_key, }; - let data = match claims.sign_with_key(&rs256_priv_key) { - Ok(data) => data, - Err(err) => return Err(Error::JWTPrivKey(err.to_string())), - }; - Ok(data) + + Ok(claims.sign_with_key(&rs256_priv_key)?) } -pub async fn load_jwt_pub_key(key_filename: String) -> Result<(), Error> { - let pub_key_data = match fs::read_to_string(key_filename.clone()) { - Ok(data) => data, - Err(err) => return Err(Error::JWTPubKey(err.to_string())), - }; - let pub_key = match PKey::public_key_from_pem(pub_key_data.as_ref()) { - Ok(key) => key, - Err(err) => return Err(Error::JWTPubKey(err.to_string())), - }; +pub async fn load_jwt_pub_key(key_filename: String) -> Result<(), JwtPubKeyError> { + let pub_key_data = fs::read_to_string(key_filename.clone())?; + let pub_key = PKey::public_key_from_pem(pub_key_data.as_ref())?; let rs256_public_key = PKeyWithDigest { digest: MessageDigest::sha256(), key: pub_key, @@ -109,15 +92,13 @@ pub async fn get_user_name_from_jwt( input_token: String, ) -> Result { let read_guard = KEYS.read().await; - let pub_key = match read_guard.get(&key_filename) { - Some(key) => key, - None => return Err(Error::JWTPubKey("key is not loaded".to_string())), - }; + let pub_key = read_guard + .get(&key_filename) + .ok_or(JwtPubKeyError::KeyNotLoaded)?; + let token: Token = - match VerifyWithKey::verify_with_key(input_token.as_str(), pub_key) { - Ok(token) => token, - Err(err) => return Err(Error::JWTValidate(err.to_string())), - }; + VerifyWithKey::verify_with_key(input_token.as_str(), pub_key) + .map_err(JwtValidateError::from)?; let (_, claim) = token.into(); claim.validate()?; Ok(claim.username) @@ -174,9 +155,7 @@ mod tests { .unwrap() .as_secs(); claims.default_claims.expiration = Some(now + 2); - let token = match sign_with_jwt_priv_key(claims, "./tests/data/jwt/private.pem".to_string()) - .await - { + let token = match sign_with_jwt_priv_key(claims, "./tests/data/jwt/private.pem").await { Ok(token) => token, Err(err) => panic!("{:?}", err), }; diff --git a/src/lib.rs b/src/lib.rs index f0309996..0e038fbe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod admin; +pub mod auth; pub mod client; pub mod cmd_args; pub mod config; @@ -33,8 +34,5 @@ pub fn format_duration(duration: &chrono::Duration) -> String { let days = duration.num_days().to_string(); - format!( - "{}d {}:{}:{}.{}", - days, hours, minutes, seconds, milliseconds - ) + format!("{days}d {hours}:{minutes}:{seconds}.{milliseconds}") } diff --git a/src/main.rs b/src/main.rs index 5e1203f0..2332ab3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,8 +46,6 @@ use tokio::signal::windows as win_signal; use tokio::sync::broadcast; use tokio::{runtime::Builder, sync::mpsc}; -extern crate exitcode; - use pg_doorman::config::{get_config, reload_config, VERSION}; use pg_doorman::core_affinity; use pg_doorman::daemon; diff --git a/src/messages.rs b/src/messages.rs index a44173ea..1603bc82 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -4,12 +4,15 @@ use bytes::{Buf, BufMut, BytesMut}; use log::error; use md5::{Digest, Md5}; use socket2::{SockRef, TcpKeepalive}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpStream, UnixStream}; use crate::client::PREPARED_STATEMENT_COUNTER; use crate::config::get_config; -use crate::errors::Error; +use crate::errors::{ + ClientBadStartupError, Error, ParseBytesError, ProtocolSyncError, ServerMessageParseError, + SocketError, +}; use crate::constants::{AUTHENTICATION_CLEAR_PASSWORD, MESSAGE_TERMINATOR, SCRAM_SHA_256}; use crate::errors::Error::ProxyTimeout; @@ -20,6 +23,7 @@ use std::ffi::CString; use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; use std::io::{BufRead, Cursor}; +use std::marker::Unpin; use std::mem; use std::str::FromStr; use std::sync::atomic::{AtomicI64, Ordering}; @@ -62,7 +66,7 @@ impl From<&DataType> for i32 { /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut S) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut auth_ok = BytesMut::with_capacity(9); @@ -76,7 +80,7 @@ where /// Generate md5 password challenge. pub async fn md5_challenge(stream: &mut S) -> Result<[u8; 4], Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { // let mut rng = rand::thread_rng(); let salt: [u8; 4] = [ @@ -98,7 +102,7 @@ where pub async fn plain_password_challenge(stream: &mut S) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut res = BytesMut::new(); res.put_u8(b'R'); @@ -112,7 +116,7 @@ where /// Generate scram password challenge. pub async fn scram_start_challenge(stream: &mut S) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut res = BytesMut::new(); res.put_u8(b'R'); @@ -127,7 +131,7 @@ where pub async fn scram_server_response(stream: &mut S, code: i32, data: &str) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut res = BytesMut::new(); res.put_u8(b'R'); @@ -139,7 +143,7 @@ where pub async fn read_password(stream: &mut S) -> Result, Error> where - S: tokio::io::AsyncRead + std::marker::Unpin, + S: AsyncRead + Unpin, { let code = match stream.read_u8().await { Ok(p) => p, @@ -149,11 +153,12 @@ where )) } }; - if code as char != 'p' { - return Err(Error::ProtocolSyncError(format!( - "Expected p, got {}", - code as char - ))); + if code != b'p' { + return Err(ProtocolSyncError::InvalidCode { + expected: b'p', + actual: code, + } + .into()); }; let len = match stream.read_i32().await { Ok(len) => len, @@ -175,7 +180,7 @@ pub async fn backend_key_data( secret_key: i32, ) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut key_data = BytesMut::from(&b"K"[..]); key_data.put_i32(12); @@ -199,7 +204,7 @@ pub fn simple_query(query: &str) -> BytesMut { /// Tell the client we're ready for another query. pub async fn send_ready_for_query(stream: &mut S) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { write_all(stream, ready_for_query(false)).await } @@ -208,7 +213,7 @@ where /// This tells the server which user we are and what database we want. pub async fn startup(stream: &mut S, user: String, database: &str) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut bytes = BytesMut::with_capacity(25); @@ -261,15 +266,15 @@ pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> { } /// Parse the params the server sends as a key/value format. -pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { - let mut result = HashMap::new(); +pub fn parse_params(mut bytes: BytesMut) -> Result, ClientBadStartupError> { + // TODO: don't create temporary buffer and aggregate directly into map. let mut buf = Vec::new(); let mut tmp = String::new(); while bytes.has_remaining() { let mut c = bytes.get_u8(); - // Null-terminated C-strings. + // Nul-terminated C-strings. while c != 0 { tmp.push(c as char); c = bytes.get_u8(); @@ -281,32 +286,29 @@ pub fn parse_params(mut bytes: BytesMut) -> Result, Erro } } - // Expect pairs of name and value - // and at least one pair to be present. - if buf.len() % 2 != 0 || buf.len() < 2 { - return Err(Error::ClientBadStartup); + if buf.is_empty() { + return Err(ClientBadStartupError::NoParams); } - let mut i = 0; - while i < buf.len() { - let name = buf[i].clone(); - let value = buf[i + 1].clone(); - let _ = result.insert(name, value); - i += 2; + let chunks = buf.chunks_exact(2); + if !chunks.remainder().is_empty() { + return Err(ClientBadStartupError::UnevenParams); } - Ok(result) + Ok(chunks + .map(|pair| (pair[0].clone(), pair[1].clone())) + .collect()) } /// Parse StartupMessage parameters. /// e.g. user, database, application_name, etc. -pub fn parse_startup(bytes: BytesMut) -> Result, Error> { +pub fn parse_startup(bytes: BytesMut) -> Result, ClientBadStartupError> { let result = parse_params(bytes)?; // Minimum required parameters // I want to have the user at the very minimum, according to the protocol spec. if !result.contains_key("user") { - return Err(Error::ClientBadStartup); + return Err(ClientBadStartupError::UserUnspecified); } Ok(result) @@ -350,7 +352,7 @@ pub async fn md5_password( salt: &[u8], ) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let password = md5_hash_password(user, password, salt); @@ -365,7 +367,7 @@ where pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let password = md5_hash_second_pass(hash, salt); let mut message = BytesMut::with_capacity(password.len() as usize + 5); @@ -381,7 +383,7 @@ where /// This tells the client we're ready for the next query. pub async fn custom_protocol_response_ok(stream: &mut S, message: &str) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut res = BytesMut::with_capacity(25); @@ -399,7 +401,7 @@ where pub async fn error_response(stream: &mut S, message: &str, code: &str) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { error_response_terminal(stream, message, code).await?; send_ready_for_query(stream).await @@ -411,7 +413,7 @@ pub async fn error_response_terminal( code: &str, ) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut error = BytesMut::new(); @@ -441,12 +443,12 @@ where res.put_i32(error.len() as i32 + 4); res.put(error); - write_all_flush(stream, &res).await + Ok(write_all_flush(stream, &res).await?) } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { let mut error = BytesMut::new(); @@ -464,7 +466,7 @@ where // The short error message. error.put_u8(b'M'); - error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes()); + error.put_slice(format!("password authentication failed for user {user:?}\0").as_bytes()); // No more fields follow. error.put_u8(0); @@ -489,7 +491,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { for (name, data_type) in columns { // Column name - row_desc.put_slice(format!("{}\0", name).as_bytes()); + row_desc.put_slice(format!("{name}\0").as_bytes()); // Doesn't belong to any table row_desc.put_i32(0); @@ -559,7 +561,7 @@ pub fn data_row_nullable(row: &Vec>) -> BytesMut { data_row.put_i32(column.len() as i32); data_row.put_slice(column); } else { - data_row.put_i32(-1_i32); + data_row.put_i32(-1); } } @@ -572,7 +574,7 @@ pub fn data_row_nullable(row: &Vec>) -> BytesMut { /// Create a CommandComplete message. pub fn command_complete(command: &str) -> BytesMut { - let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); + let cmd = BytesMut::from(format!("{command}\0").as_bytes()); let mut res = BytesMut::new(); res.put_u8(b'C'); res.put_i32(cmd.len() as i32 + 4); @@ -586,8 +588,8 @@ pub fn notify(message: &str, details: String) -> BytesMut { notify_cmd.put_slice("SNOTICE\0".as_bytes()); notify_cmd.put_slice("C00000\0".as_bytes()); - notify_cmd.put_slice(format!("M{}\0", message).as_bytes()); - notify_cmd.put_slice(format!("D{}\0", details).as_bytes()); + notify_cmd.put_slice(format!("M{message}\0").as_bytes()); + notify_cmd.put_slice(format!("D{details}\0").as_bytes()); // this extra byte says that is the end of the package notify_cmd.put_u8(0); @@ -652,11 +654,7 @@ pub fn ready_for_query(in_transaction: bool) -> BytesMut { bytes.put_u8(b'Z'); bytes.put_i32(5); - if in_transaction { - bytes.put_u8(b'T'); - } else { - bytes.put_u8(b'I'); - } + bytes.put_u8(if in_transaction { b'T' } else { b'I' }); bytes } @@ -664,7 +662,7 @@ pub fn ready_for_query(in_transaction: bool) -> BytesMut { /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { match stream.write_all(&buf).await { Ok(_) => Ok(()), @@ -676,42 +674,26 @@ where } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> +pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), SocketError> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { - match stream.write_all(buf).await { - Ok(_) => Ok(()), - Err(err) => Err(Error::SocketError(format!( - "Error writing to socket: {:?}", - err - ))), - } + stream.write_all(buf).await.map_err(SocketError::Write) } -pub async fn write_all_flush(stream: &mut S, buf: &[u8]) -> Result<(), Error> +pub async fn write_all_flush(stream: &mut S, buf: &[u8]) -> Result<(), SocketError> where - S: tokio::io::AsyncWrite + std::marker::Unpin, + S: AsyncWrite + Unpin, { - match stream.write_all(buf).await { - Ok(_) => match stream.flush().await { - Ok(_) => Ok(()), - Err(err) => Err(Error::SocketError(format!( - "Error flushing socket: {:?}", - err - ))), - }, - Err(err) => Err(Error::SocketError(format!( - "Error writing to socket: {:?}", - err - ))), - } + stream.write_all(buf).await.map_err(SocketError::Write)?; + stream.flush().await.map_err(SocketError::Flush)?; + Ok(()) } /// Read header. pub async fn read_message_header(stream: &mut S) -> Result<(u8, i32), Error> where - S: tokio::io::AsyncRead + std::marker::Unpin, + S: AsyncRead + Unpin, { let code = match stream.read_u8().await { Ok(code) => code, @@ -737,7 +719,7 @@ where /// Read a message data from the socket. pub async fn read_message_data(stream: &mut S, code: u8, len: i32) -> Result where - S: tokio::io::AsyncRead + std::marker::Unpin, + S: AsyncRead + Unpin, { let mut bytes = BytesMut::with_capacity(len as usize + 1); @@ -775,7 +757,7 @@ where /// Read a complete message from the socket. pub async fn read_message(stream: &mut S, max_memory_usage: u64) -> Result where - S: tokio::io::AsyncRead + std::marker::Unpin, + S: AsyncRead + Unpin, { let (code, len) = read_message_header(stream).await?; if len > MAX_MESSAGE_SIZE { @@ -790,8 +772,7 @@ where }; let current_memory = CURRENT_MEMORY.load(Ordering::Relaxed); if current_memory > max_memory_usage as i64 { - error!("reached memory limit while processing code '{}' message len '{}' current memory usage: '{}' maximum memory usage: '{}'", - code as char, len, current_memory, max_memory_usage); + error!("reached memory limit while processing code '{code:#x}' message len '{len}' current memory usage: '{current_memory}' maximum memory usage: '{max_memory_usage}'"); proxy_copy_data_with_timeout( Duration::from_millis(get_config().general.proxy_copy_data_timeout), stream, @@ -799,7 +780,7 @@ where len as usize - mem::size_of::(), ) .await?; - return Err(Error::CurrentMemoryUsage); + return Err(Error::MemoryLimitReached); } CURRENT_MEMORY.fetch_add(len as i64, Ordering::Relaxed); let bytes = read_message_data(stream, code, len).await?; @@ -814,22 +795,18 @@ pub async fn proxy_copy_data_with_timeout( len: usize, ) -> Result where - R: tokio::io::AsyncRead + std::marker::Unpin, - W: tokio::io::AsyncWrite + std::marker::Unpin, + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, { - match timeout(duration, proxy_copy_data(read, write, len)).await { - Ok(res) => match res { - Ok(len) => Ok(len), - Err(err) => Err(err), - }, - Err(_) => Err(ProxyTimeout), - } + timeout(duration, proxy_copy_data(read, write, len)) + .await + .map_err(|_| ProxyTimeout)? } pub async fn proxy_copy_data(read: &mut R, write: &mut W, len: usize) -> Result where - R: tokio::io::AsyncRead + std::marker::Unpin, - W: tokio::io::AsyncWrite + std::marker::Unpin, + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, { const MAX_BUFFER_CHUNK: usize = 4096; // гарантия того что вызовы read из // буфферизированного stream 8kb будет быстрым. @@ -910,23 +887,14 @@ pub fn configure_unix_socket(stream: &UnixStream) { let sock_ref = SockRef::from(stream); let conf = get_config(); - match sock_ref.set_linger(Some(Duration::from_secs(conf.general.tcp_so_linger))) { - Ok(_) => {} - Err(err) => error!("Could not configure unix_so_linger for socket: {}", err), + if let Err(e) = sock_ref.set_linger(Some(Duration::from_secs(conf.general.tcp_so_linger))) { + error!("Could not configure unix_so_linger for socket: {e}") } - match sock_ref.set_send_buffer_size(conf.general.unix_socket_buffer_size) { - Ok(_) => {} - Err(err) => error!( - "Could not configure set_send_buffer_size for socket: {}", - err - ), + if let Err(e) = sock_ref.set_send_buffer_size(conf.general.unix_socket_buffer_size) { + error!("Could not configure set_send_buffer_size for socket: {e}",) } - match sock_ref.set_recv_buffer_size(conf.general.unix_socket_buffer_size) { - Ok(_) => {} - Err(err) => error!( - "Could not configure set_recv_buffer_size for socket: {}", - err - ), + if let Err(e) = sock_ref.set_recv_buffer_size(conf.general.unix_socket_buffer_size) { + error!("Could not configure set_recv_buffer_size for socket: {e}",) } } @@ -934,61 +902,51 @@ pub fn configure_tcp_socket(stream: &TcpStream) { let sock_ref = SockRef::from(stream); let conf = get_config(); - match sock_ref.set_linger(Some(Duration::from_secs(conf.general.tcp_so_linger))) { - Ok(_) => {} - Err(err) => error!("Could not configure tcp_so_linger for socket: {}", err), - } + if let Err(e) = sock_ref.set_linger(Some(Duration::from_secs(conf.general.tcp_so_linger))) { + error!("Could not configure tcp_so_linger for socket: {e}"); + }; - match sock_ref.set_nodelay(conf.general.tcp_no_delay) { - Ok(_) => {} - Err(err) => error!("Could not configure no delay for socket: {}", err), + if let Err(e) = sock_ref.set_nodelay(conf.general.tcp_no_delay) { + error!("Could not configure no delay for socket: {}", e) } - match sock_ref.set_keepalive(true) { - Ok(_) => { - match sock_ref.set_tcp_keepalive( - &TcpKeepalive::new() - .with_interval(Duration::from_secs(conf.general.tcp_keepalives_interval)) - .with_retries(conf.general.tcp_keepalives_count) - .with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)), - ) { - Ok(_) => (), - Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err), - } - } - Err(err) => error!("Could not configure socket: {}", err), + if let Err(err) = sock_ref.set_keepalive(true) { + error!("Could not configure socket: {err}") + } else if let Err(err) = sock_ref.set_tcp_keepalive( + &TcpKeepalive::new() + .with_interval(Duration::from_secs(conf.general.tcp_keepalives_interval)) + .with_retries(conf.general.tcp_keepalives_count) + .with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)), + ) { + error!("Could not configure tcp_keepalive for socket: {err}") } } pub trait BytesMutReader { - fn read_string(&mut self) -> Result; + fn read_string(&mut self) -> Result; } impl BytesMutReader for Cursor<&BytesMut> { /// Should only be used when reading strings from the message protocol. /// Can be used to read multiple strings from the same message which are separated by the null byte - fn read_string(&mut self) -> Result { + fn read_string(&mut self) -> Result { let mut buf = vec![]; - match self.read_until(b'\0', &mut buf) { - Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), - Err(err) => Err(Error::ParseBytesError(err.to_string())), - } + self.read_until(b'\0', &mut buf)?; + Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()) } } impl BytesMutReader for BytesMut { /// Should only be used when reading strings from the message protocol. /// Can be used to read multiple strings from the same message which are separated by the null byte - fn read_string(&mut self) -> Result { - let null_index = self.iter().position(|&byte| byte == b'\0'); + fn read_string(&mut self) -> Result { + let index = self + .iter() + .position(|&byte| byte == b'\0') + .ok_or(ParseBytesError::NoNul)?; - match null_index { - Some(index) => { - let string_bytes = self.split_to(index + 1); - Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) - } - None => Err(Error::ParseBytesError("Could not read string".to_string())), - } + let string_bytes = self.split_to(index + 1); + Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) } } @@ -1127,7 +1085,7 @@ impl Parse { } /// Gets the name of the prepared statement from the buffer - pub fn get_name(buf: &BytesMut) -> Result { + pub fn get_name(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); // Skip the code and length cursor.advance(mem::size_of::() + mem::size_of::()); @@ -1286,7 +1244,7 @@ impl TryFrom for BytesMut { impl Bind { /// Gets the name of the prepared statement from the buffer - pub fn get_name(buf: &BytesMut) -> Result { + pub fn get_name(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); // Skip the code and length cursor.advance(mem::size_of::() + mem::size_of::()); @@ -1581,12 +1539,7 @@ impl PgErrorMsg { let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() { Ok(c) => c, - Err(err) => { - return Err(Error::ServerMessageParserError(format!( - "could not parse server message field. err {:?}", - err - ))) - } + Err(infallible) => match infallible {}, }; match &msg_part[0] { @@ -1648,7 +1601,9 @@ impl PgErrorMsg { } } -pub fn set_messages_right_place(in_msg: Vec) -> Result { +// TODO: this can be rewritten using parser such as `peg` +// for ease of maintenance and potentially better performance +pub fn set_messages_right_place(in_msg: &[u8]) -> Result { let in_msg_len = in_msg.len(); let mut cursor = 0; let mut count_parse_complete = 0; @@ -1658,101 +1613,93 @@ pub fn set_messages_right_place(in_msg: Vec) -> Result { // count parse message. loop { if cursor > in_msg_len { - return Err(Error::ServerMessageParserError( - "Cursor is more than total message size".to_string(), - )); + return Err(ServerMessageParseError::CursorOverflow { + cursor, + message: in_msg_len, + }); } if cursor == in_msg_len { break; } - match in_msg[cursor] as char { - '1' => count_parse_complete += 1, - '3' => count_stmt_close += 1, + match in_msg[cursor] { + b'1' => count_parse_complete += 1, + b'3' => count_stmt_close += 1, _ => (), } cursor += 1; if cursor + 4 > in_msg_len { - return Err(Error::ServerMessageParserError( - "Can't read i32 from server message".to_string(), - )); + return Err(ServerMessageParseError::InvalidI32); } - let len_ref = match <[u8; 4]>::try_from(&in_msg[cursor..cursor + 4]) { + let len_bytes = match <[u8; 4]>::try_from(&in_msg[cursor..cursor + 4]) { Ok(len_ref) => len_ref, - _ => { - return Err(Error::ServerMessageParserError( - "Can't convert i32 from server message".to_string(), - )) - } + Err(_) => return Err(ServerMessageParseError::InvalidI32), }; - let mut len = i32::from_be_bytes(len_ref) as usize; - if len < 4 { - return Err(Error::ServerMessageParserError( - "Message len less than 4".to_string(), - )); - } - len -= 4; + let len = i32::from_be_bytes(len_bytes) as usize; + let len = len + .checked_sub(4) + .ok_or(ServerMessageParseError::LenSmallerThan4(len))?; cursor += 4; if cursor + len > in_msg_len { - return Err(Error::ServerMessageParserError( - "Message len more than server message size".to_string(), - )); + return Err(ServerMessageParseError::LenOverlow { + received: in_msg_len, + cursor, + message: len, + }); } cursor += len; } if count_stmt_close == 0 && count_parse_complete == 0 { - result.put(&in_msg[..]); + result.put(in_msg); return Ok(result); } cursor = 0; - let mut prev_msg: char = ' '; + let mut prev_msg = b' '; loop { if cursor == in_msg_len { return Ok(result); } - match in_msg[cursor] as char { - '1' => { - if count_parse_complete == 0 || prev_msg == '1' { + match in_msg[cursor] { + b'1' => { + if count_parse_complete == 0 || prev_msg == b'1' { // ParseComplete: ignore. cursor += 5; continue; } count_parse_complete -= 1; } - '2' | 't' => { - if (prev_msg != '1') && (prev_msg != '2') && count_parse_complete > 0 { + b'2' | b't' => { + if prev_msg != b'1' && prev_msg != b'2' && count_parse_complete > 0 { // BindComplete, just add before ParseComplete. result.put(parse_complete()); count_parse_complete -= 1; } } - '3' => { + b'3' => { if count_stmt_close == 1 { cursor += 5; continue; } } - 'Z' => { + b'Z' => { if count_stmt_close == 1 { result.put(close_complete()) } } _ => {} }; - prev_msg = in_msg[cursor] as char; + prev_msg = in_msg[cursor]; cursor += 1; // code let len_ref = match <[u8; 4]>::try_from(&in_msg[cursor..cursor + 4]) { Ok(len_ref) => len_ref, - _ => { - return Err(Error::ServerMessageParserError( - "Can't convert i32 from server message".to_string(), - )) - } + _ => return Err(ServerMessageParseError::InvalidI32), }; - let mut len = i32::from_be_bytes(len_ref) as usize; - len -= 4; + let len = i32::from_be_bytes(len_ref) as usize; + let len = len + .checked_sub(4) + .ok_or(ServerMessageParseError::LenSmallerThan4(len))?; cursor += 4; result.put(&in_msg[cursor - 5..cursor + len]); cursor += len; @@ -1790,23 +1737,17 @@ mod tests { let mut in_msg = parse_complete(); assert_eq!( parse_complete().len(), - set_messages_right_place(in_msg.to_vec()) - .expect("parsing") - .len() + set_messages_right_place(&in_msg).expect("parsing").len() ); in_msg.put(flush()); assert_eq!( parse_complete().len() + flush().len(), - set_messages_right_place(in_msg.to_vec()) - .expect("parsing") - .len() + set_messages_right_place(&in_msg).expect("parsing").len() ); in_msg.put(ready_for_query(true)); assert_eq!( parse_complete().len() + flush().len() + ready_for_query(true).len(), - set_messages_right_place(in_msg.to_vec()) - .expect("parsing") - .len() + set_messages_right_place(&in_msg).expect("parsing").len() ); } @@ -1861,7 +1802,7 @@ mod tests { in_msg.put(row_description()); // t in_msg.put(command_complete("2")); // C in_msg.put(ready_for_query(false)); // Z - let out_msg = set_messages_right_place(in_msg.to_vec()).expect("parse"); + let out_msg = set_messages_right_place(&in_msg).expect("parse"); println!("112tC2tCZ"); assert_eq!(show_headers(out_msg), "12tC12tCZ".to_string()); } diff --git a/src/scram_client.rs b/src/scram_client.rs index 111dd5e1..c7c1dc48 100644 --- a/src/scram_client.rs +++ b/src/scram_client.rs @@ -12,7 +12,7 @@ use sha2::{Digest, Sha256}; use std::fmt::Write; use crate::constants::*; -use crate::errors::Error; +use crate::errors::{Error, ProtocolSyncError, ServerError}; /// Normalize a password string. Postgres /// passwords don't have to be UTF-8. @@ -79,12 +79,12 @@ impl ScramSha256 { let server_message = Message::parse(message)?; if !server_message.nonce.starts_with(&self.nonce) { - return Err(Error::ProtocolSyncError("SCRAM".to_string())); + return Err(ProtocolSyncError::Scram.into()); } let salt = match general_purpose::STANDARD.decode(&server_message.salt) { Ok(salt) => salt, - Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), + Err(_) => return Err(ProtocolSyncError::Scram.into()), }; let salted_password = Self::hi( @@ -96,10 +96,8 @@ impl ScramSha256 { // Save for verification of final server message. self.salted_password = salted_password; - let mut hmac = match Hmac::::new_from_slice(&salted_password) { - Ok(hmac) => hmac, - Err(_) => return Err(Error::ServerError), - }; + let mut hmac = + Hmac::::new_from_slice(&salted_password).map_err(ServerError::from)?; hmac.update(b"Client Key"); @@ -117,14 +115,12 @@ impl ScramSha256 { self.message.clear(); // Start writing the client reply. - match write!( + write!( &mut self.message, "c={},r={}", cbind_input, server_message.nonce - ) { - Ok(_) => (), - Err(_) => return Err(Error::ServerError), - }; + ) + .map_err(|_| ServerError::Internal)?; let auth_message = format!( "n=,r={},{},{}", @@ -133,10 +129,7 @@ impl ScramSha256 { String::from_utf8_lossy(&self.message[..]) ); - let mut hmac = match Hmac::::new_from_slice(&stored_key) { - Ok(hmac) => hmac, - Err(_) => return Err(Error::ServerError), - }; + let mut hmac = Hmac::::new_from_slice(&stored_key).map_err(ServerError::from)?; hmac.update(auth_message.as_bytes()); // Save the auth message for server final message verification. @@ -150,14 +143,12 @@ impl ScramSha256 { *proof ^= signature; } - match write!( + write!( &mut self.message, ",p={}", general_purpose::STANDARD.encode(&*client_proof) - ) { - Ok(_) => (), - Err(_) => return Err(Error::ServerError), - }; + ) + .map_err(|_| ServerError::Internal)?; Ok(self.message.clone()) } @@ -166,28 +157,20 @@ impl ScramSha256 { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { let final_message = FinalMessage::parse(message)?; - let verifier = match general_purpose::STANDARD.decode(final_message.value) { - Ok(verifier) => verifier, - Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), - }; + let verifier = general_purpose::STANDARD + .decode(final_message.value) + .map_err(|_| ProtocolSyncError::Scram)?; - let mut hmac = match Hmac::::new_from_slice(&self.salted_password) { - Ok(hmac) => hmac, - Err(_) => return Err(Error::ServerError), - }; + let mut hmac = + Hmac::::new_from_slice(&self.salted_password).map_err(ServerError::from)?; hmac.update(b"Server Key"); let server_key = hmac.finalize().into_bytes(); - let mut hmac = match Hmac::::new_from_slice(&server_key) { - Ok(hmac) => hmac, - Err(_) => return Err(Error::ServerError), - }; + let mut hmac = Hmac::::new_from_slice(&server_key).map_err(ServerError::from)?; hmac.update(self.auth_message.as_bytes()); - match hmac.verify_slice(&verifier) { - Ok(_) => Ok(()), - Err(_) => Err(Error::ServerError), - } + hmac.verify_slice(&verifier).map_err(ServerError::from)?; + Ok(()) } /// Hash the password with the salt i-times. @@ -224,20 +207,18 @@ struct Message { impl Message { /// Parse the server SASL challenge. fn parse(message: &BytesMut) -> Result { - let parts = String::from_utf8_lossy(&message[..]) - .split(',') - .map(|s| s.to_string()) - .collect::>(); + let message = String::from_utf8_lossy(&message[..]); + let parts = message.split(',').collect::>(); - if parts.len() != 3 { - return Err(Error::ProtocolSyncError("SCRAM".to_string())); - } + let [nonce, salt, iterations] = parts[..] else { + return Err(ProtocolSyncError::Scram.into()); + }; - let nonce = str::replace(&parts[0], "r=", ""); - let salt = str::replace(&parts[1], "s=", ""); - let iterations = match str::replace(&parts[2], "i=", "").parse::() { + let nonce = str::replace(nonce, "r=", ""); + let salt = str::replace(salt, "s=", ""); + let iterations = match str::replace(iterations, "i=", "").parse::() { Ok(iterations) => iterations, - Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), + Err(_) => return Err(ProtocolSyncError::Scram.into()), }; Ok(Message { @@ -257,7 +238,7 @@ impl FinalMessage { /// Parse the server final validation message. pub fn parse(message: &BytesMut) -> Result { if !message.starts_with(b"v=") || message.len() < 4 { - return Err(Error::ProtocolSyncError("SCRAM".to_string())); + return Err(ProtocolSyncError::Scram.into()); } Ok(FinalMessage { diff --git a/src/scram_server.rs b/src/scram_server.rs index ed6b1df6..bc877a05 100644 --- a/src/scram_server.rs +++ b/src/scram_server.rs @@ -265,7 +265,7 @@ pub fn prepare_server_first_response( let key = rng.gen::<[u8; 18]>(); // bytes 18 -> base64 24 (( 4*(18/3) )) let nonce = client_nonce.to_owned() + &*general_purpose::STANDARD.encode(key); - let server_first_bare = format!("r={},s={},i={}", nonce, server_salt, server_iteration,); + let server_first_bare = format!("r={},s={},i={}", nonce, server_salt, server_iteration); ServerFirstMessage { nonce, client_first_bare: client_first_bare.to_string(), diff --git a/src/server.rs b/src/server.rs index afbc1f89..771e5033 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,6 +5,7 @@ use log::{error, info, warn}; use lru::LruCache; use once_cell::sync::Lazy; use std::collections::{HashMap, HashSet, VecDeque}; +use std::marker::Unpin; use std::mem; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; @@ -13,9 +14,10 @@ use std::time::{Duration, SystemTime}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::{TcpStream, UnixStream}; +use crate::auth::AuthMethod; use crate::config::{get_config, Address, User}; use crate::constants::*; -use crate::errors::{Error, ServerIdentifier}; +use crate::errors::{Error, ProtocolSyncError, ServerAuthError, ServerError, ServerIdentifier}; use crate::messages::BytesMutReader; use crate::messages::*; use crate::pool::{ClientServerMap, CANCELED_PIDS}; @@ -312,7 +314,7 @@ pub struct Server { /// Is the server in copy-in or copy-out modes in_copy_mode: bool, - flush_wait_code: char, + flush_wait_code: u8, /// Is the server broken? We'll remote it from the pool if so. bad: bool, @@ -404,13 +406,13 @@ impl Server { mut client_server_parameters: Option<&mut ServerParameters>, ) -> Result where - C: tokio::io::AsyncWrite + std::marker::Unpin, + C: AsyncWrite + Unpin, { loop { self.stats.wait_reading(); let (code_u8, message_len) = read_message_header(&mut self.stream).await?; // if message server is too big. - if message_len > self.max_message_size && code_u8 as char == 'D' { + if message_len > self.max_message_size && code_u8 == b'D' { // send current buffer + header. self.buffer.put_u8(code_u8); self.buffer.put_i32(message_len); @@ -467,27 +469,25 @@ impl Server { // Buffer the message we'll forward to the client later. self.buffer.put(&message[..]); - let code = message.get_u8() as char; + let code = message.get_u8(); let _len = message.get_i32(); match code { // ReadyForQuery - 'Z' => { - let transaction_state = message.get_u8() as char; - - match transaction_state { + b'Z' => { + match message.get_u8() { // In transaction. - 'T' => { + b'T' => { self.in_transaction = true; } // Idle, transaction over. - 'I' => { + b'I' => { self.in_transaction = false; } // Some error occurred, the transaction was rolled back. - 'E' => { + b'E' => { if let Ok(msg) = PgErrorMsg::parse(&message) { error!( "Server error (in tx) {} (severity: {} code: {} message: {})", @@ -498,13 +498,13 @@ impl Server { } // Something totally unexpected, this is not a Postgres server we know. - _ => { - let err = Error::ProtocolSyncError(format!( - "Server {}: unknown transaction state: {}", - self, transaction_state - )); + transaction_state => { + let err = ProtocolSyncError::UnknownTransactionState { + server: self.to_string(), + transaction_state, + }; self.mark_bad(err.to_string().as_str(), true); - return Err(err); + return Err(err.into()); } }; @@ -514,7 +514,7 @@ impl Server { } // ErrorResponse - 'E' => { + b'E' => { if let Ok(msg) = PgErrorMsg::parse(&message) { error!( "Server {}: {} ({}) - {}", @@ -537,7 +537,7 @@ impl Server { } // CommandComplete - 'C' => { + b'C' => { if self.in_copy_mode { self.in_copy_mode = false; } @@ -545,13 +545,13 @@ impl Server { if message.len() == 4 && message.to_vec().eq(COMMAND_COMPLETE_BY_SET) { self.cleanup_state.needs_cleanup_set = true; } - if self.flush_wait_code == 'C' { + if self.flush_wait_code == b'C' { self.data_available = false; break; } } - 'S' => { + b'S' => { let key = message.read_string().unwrap(); let value = message.read_string().unwrap(); @@ -569,7 +569,7 @@ impl Server { } // DataRow - 'D' => { + b'D' => { // More data is available after this message, this is not the end of the reply. self.data_available = true; @@ -580,20 +580,20 @@ impl Server { } // CopyInResponse: copy is starting from client to server. - 'G' => { + b'G' => { self.in_copy_mode = true; break; } // CopyOutResponse: copy is starting from the server to the client. - 'H' => { + b'H' => { self.in_copy_mode = true; self.data_available = true; break; } // CopyData - 'd' => { + b'd' => { // Don't flush yet, buffer until we reach limit if self.buffer.len() >= 8196 { break; @@ -602,11 +602,11 @@ impl Server { // CopyDone // Buffer until ReadyForQuery shows up, so don't exit the loop yet. - 'c' => (), + b'c' => (), // NoData // https://www.postgresql.org/docs/current/protocol-flow.html - 'n' => { + b'n' => { if self.is_async() { self.data_available = false; self.set_flush_wait_code(code); @@ -693,7 +693,7 @@ impl Server { #[inline(always)] pub fn is_async(&self) -> bool { - self.flush_wait_code != ' ' + self.flush_wait_code != b' ' } pub async fn send_and_flush(&mut self, messages: &BytesMut) -> Result<(), Error> { self.stats.data_sent(messages.len()); @@ -710,7 +710,7 @@ impl Server { self.stats.wait_idle(); error!("Terminating server {} because of: {:?}", self, err); self.mark_bad("flush to server error", true); - Err(err) + Err(err.into()) } } } @@ -796,7 +796,7 @@ impl Server { /// Switch to async mode, flushing messages as soon /// as we receive them without buffering or waiting for "ReadyForQuery". #[inline(always)] - pub fn set_flush_wait_code(&mut self, wait: char) { + pub fn set_flush_wait_code(&mut self, wait: u8) { self.flush_wait_code = wait } @@ -872,7 +872,7 @@ impl Server { // If it's not there, something went bad, I'm guessing bad syntax or permissions error // on the server. if !self.has_prepared_statement(&parse.name) { - Err(Error::PreparedStatementError) + Err(Error::NoPreparedStatement) } else { Ok(()) } @@ -955,7 +955,7 @@ impl Server { bytes.put_i32(process_id); bytes.put_i32(secret_key); - write_all_flush(&mut stream, &bytes).await + Ok(write_all_flush(&mut stream, &bytes).await?) } // Marks a connection as needing cleanup at checkin @@ -1014,7 +1014,7 @@ impl Server { loop { let code = match stream.read_u8().await { - Ok(code) => code as char, + Ok(code) => code, Err(err) => { return Err(Error::ServerStartupError( format!( @@ -1041,7 +1041,7 @@ impl Server { match code { // Authentication - 'R' => { + b'R' => { // Determine which kind of authentication is required, if any. let auth_code = match stream.read_i32().await { Ok(auth_code) => auth_code, @@ -1058,10 +1058,11 @@ impl Server { SASL => { match scram_client_auth { None => { - return Err(Error::ServerAuthError( - "server wants sasl auth, but it is not configured".into(), - server_identifier, - )); + return Err(ServerAuthError::UnsupportedMethod { + id: server_identifier, + method: AuthMethod::Sasl, + } + .into()); } Some(_) => { let sasl_len = (len - 8) as usize; @@ -1105,7 +1106,10 @@ impl Server { write_all_flush(&mut stream, &res).await?; } else { error!("Unsupported SCRAM version: {}", sasl_type); - return Err(Error::ServerError); + return Err(ServerError::UnsupportedScramVersion( + sasl_type.to_string(), + ) + .into()); } } } @@ -1159,134 +1163,103 @@ impl Server { } /* SASL end */ AUTHENTICATION_CLEAR_PASSWORD => { - if user.server_username.is_none() || user.server_password.is_none() { + let (Some(server_username), Some(server_password)) = + (&user.server_username, &user.server_password) + else { error!( "authentication on server {}@{} with clear auth is not configured", server_identifier.username, server_identifier.database, ); - return Err(Error::ServerAuthError( - "server wants clear password authentication, but auth for this server is not configured".into(), - server_identifier, - )); - } - let server_password = - as Clone>::clone(&user.server_password) - .unwrap() - .clone(); - let server_username = - as Clone>::clone(&user.server_username) - .unwrap() - .clone(); - if server_password.starts_with(JWT_PRIV_KEY_PASSWORD_PREFIX) { - // generate password - let claims = new_claims(server_username, Duration::from_secs(120)); - let token = match sign_with_jwt_priv_key( - claims, - server_password - .strip_prefix(JWT_PRIV_KEY_PASSWORD_PREFIX) - .unwrap() - .to_string(), - ) + return Err(ServerAuthError::UnsupportedMethod { + id: server_identifier, + method: AuthMethod::ClearPassword, + } + .into()); + }; + + let Some(jwt_priv_key) = + server_password.strip_prefix(JWT_PRIV_KEY_PASSWORD_PREFIX) + else { + return Err(ServerAuthError::UnsupportedMethod { + id: server_identifier, + method: AuthMethod::ClearPassword, + } + .into()); + }; + + // generate password + let claims = + new_claims(server_username.clone(), Duration::from_secs(120)); + let token = sign_with_jwt_priv_key(claims, jwt_priv_key) .await - { - Ok(token) => token, - Err(err) => { - return Err(Error::ServerAuthError( - err.to_string(), - server_identifier, - )) - } - }; - let mut password_response = BytesMut::new(); - password_response.put_u8(b'p'); - password_response.put_i32(token.len() as i32 + 4 + 1); - password_response.put_slice(token.as_bytes()); - password_response.put_u8(b'\0'); - match stream.try_write(&password_response) { - Ok(_) => (), - Err(err) => { - return Err(Error::ServerAuthError( - format!( - "jwt authentication on the server failed: {:?}", - err - ), - server_identifier, - )); - } + .map_err(ServerAuthError::JwtPrivKey)?; + + let mut password_response = BytesMut::new(); + password_response.put_u8(b'p'); + password_response.put_i32(token.len() as i32 + 4 + 1); + password_response.put_slice(token.as_bytes()); + password_response.put_u8(b'\0'); + stream.try_write(&password_response).map_err(|e| { + ServerAuthError::Io { + method: AuthMethod::Jwt, + error: e, } - } else { - return Err(Error::ServerAuthError( - "plain password is not supported".into(), - server_identifier, - )); - } + })?; } MD5_ENCRYPTED_PASSWORD => { - if user.server_username.is_none() || user.server_password.is_none() { + let (Some(server_username), Some(server_password)) = + (&user.server_username, &user.server_password) + else { error!( "authentication for server {}@{} with md5 auth is not configured", server_identifier.username, server_identifier.database, ); - return Err(Error::ServerAuthError( - "server wants md5 authentication, but auth for this server is not configured".into(), - server_identifier, - )); - } else { - let server_username = - as Clone>::clone(&user.server_username) - .unwrap() - .clone(); - let server_password = - as Clone>::clone(&user.server_password) - .unwrap() - .clone(); - let mut salt = BytesMut::with_capacity(4); - match stream.read_buf(&mut salt).await { - Ok(_) => (), - Err(err) => { - return Err(Error::ServerAuthError( - format!("md5 authentication on the server: {:?}", err), - server_identifier, - )); - } + return Err(ServerAuthError::UnsupportedMethod { + id: server_identifier, + method: AuthMethod::Md5, } - let password_hash = md5_hash_password( - server_username.as_str(), - server_password.as_str(), - salt.as_mut(), - ); - let mut password_response = BytesMut::new(); - password_response.put_u8(b'p'); - password_response.put_i32(password_hash.len() as i32 + 4); - password_response.put_slice(&password_hash); - match stream.try_write(&password_response) { - Ok(_) => (), - Err(err) => { - return Err(Error::ServerAuthError( - format!( - "md5 authentication on the server failed: {:?}", - err - ), - server_identifier, - )); - } + .into()); + }; + + let mut salt = BytesMut::with_capacity(4); + stream + .read_buf(&mut salt) + .await + .map_err(|e| ServerAuthError::Io { + method: AuthMethod::Md5, + error: e, + })?; + let password_hash = md5_hash_password( + server_username.as_str(), + server_password.as_str(), + salt.as_mut(), + ); + let mut password_response = BytesMut::new(); + password_response.put_u8(b'p'); + password_response.put_i32(password_hash.len() as i32 + 4); + password_response.put_slice(&password_hash); + stream.try_write(&password_response).map_err(|e| { + ServerAuthError::Io { + method: AuthMethod::Md5, + error: e, } - } + })?; } - _ => { + code => { error!("this type of authentication on the server {}@{} is not supported, auth code: {}", server_identifier.username, server_identifier.database, auth_code); - return Err(Error::ServerAuthError( - "authentication on the server is not supported".into(), - server_identifier, - )); + return Err(ServerAuthError::InvalidAuthCode { + id: server_identifier, + code, + } + .into()); } } } // ErrorResponse - 'E' => { + b'E' => { let error_code = match stream.read_u8().await { Ok(error_code) => error_code, Err(_) => { @@ -1299,7 +1272,7 @@ impl Server { match error_code { // No error message is present in the message. - MESSAGE_TERMINATOR => (), + MESSAGE_TERMINATOR => return Err(ServerError::EmptyErrorMessage.into()), // An error message will be present. _ => { @@ -1338,12 +1311,10 @@ impl Server { }; } }; - - return Err(Error::ServerError); } // Notice - 'N' => { + b'N' => { let mut bytes = BytesMut::with_capacity(len as usize - 4); bytes.resize(len as usize - mem::size_of::(), b'0'); match stream.read_exact(&mut bytes[..]).await { @@ -1364,7 +1335,7 @@ impl Server { } // ParameterStatus - 'S' => { + b'S' => { let mut bytes = BytesMut::with_capacity(len as usize - 4); bytes.resize(len as usize - mem::size_of::(), b'0'); @@ -1388,7 +1359,7 @@ impl Server { } // BackendKeyData - 'K' => { + b'K' => { // The frontend must save these values if it wishes to be able to issue CancelRequest messages later. // See: . process_id = match stream.read_i32().await { @@ -1413,7 +1384,7 @@ impl Server { } // ReadyForQuery - 'Z' => { + b'Z' => { let mut idle = vec![0u8; len as usize - 4]; match stream.read_exact(&mut idle).await { @@ -1437,7 +1408,7 @@ impl Server { in_copy_mode: false, data_available: false, bad: false, - flush_wait_code: ' ', + flush_wait_code: b' ', cleanup_state: CleanupState::new(), client_server_map, connected_at: chrono::offset::Utc::now().naive_utc(), @@ -1462,15 +1433,12 @@ impl Server { // We have an unexpected message from the server during this exchange. // Means we implemented the protocol wrong or we're not talking to a Postgres server. - _ => { + code => { error!( "An unprocessed message code from server backend while startup: {}", code ); - return Err(Error::ProtocolSyncError(format!( - "An unprocessed message code from server backend while startup: {}", - code - ))); + return Err(ProtocolSyncError::UnprocessedCode(code).into()); } }; } @@ -1558,7 +1526,7 @@ async fn create_tcp_stream_inner( ssl_request(&mut stream).await?; let response = match stream.read_u8().await { - Ok(response) => response as char, + Ok(response) => response, Err(err) => { return Err(Error::SocketError(format!( "Server socket error: {:?}", @@ -1569,17 +1537,17 @@ async fn create_tcp_stream_inner( match response { // Server supports TLS - 'S' => { + b'S' => { error!("Connection to server via tls is not supported"); return Err(Error::SocketError("Server TLS is unsupported".to_string())); } // Server does not support TLS - 'N' => StreamInner::TCPPlain { stream }, + b'N' => StreamInner::TCPPlain { stream }, // Something else? m => { - return Err(Error::SocketError(format!("Unknown message: {}", { m }))); + return Err(Error::SocketError(format!("Unknown message: {:#x}", m))); } } } else { diff --git a/src/stats/socket.rs b/src/stats/socket.rs index f2ea7f95..768e97d0 100644 --- a/src/stats/socket.rs +++ b/src/stats/socket.rs @@ -12,11 +12,14 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; use std::path::Path; use std::{fs, mem, ptr, slice}; -#[derive(Debug)] -pub enum SocketInfoErr { - Io(std::io::Error), - Nix(nix::errno::Errno), - Convert(std::num::TryFromIntError), +#[derive(Debug, thiserror::Error)] +pub enum SocketInfoError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Nix(#[from] nix::errno::Errno), + #[error(transparent)] + Convert(#[from] std::num::TryFromIntError), } const FD_DIR: &str = "fd"; @@ -166,16 +169,6 @@ impl Display for SocketAddr { } } -impl Display for SocketInfoErr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - SocketInfoErr::Io(io_error) => write!(f, "{}", io_error), - SocketInfoErr::Nix(n_error) => write!(f, "{}", n_error), - SocketInfoErr::Convert(int_error) => write!(f, "{}", int_error), - } - } -} - impl Display for TcpStateCount { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut str_buf: Vec = Vec::new(); @@ -278,24 +271,7 @@ impl Display for SocketStateCount { } } -impl From for SocketInfoErr { - fn from(err: nix::errno::Errno) -> Self { - SocketInfoErr::Nix(err) - } -} -impl From for SocketInfoErr { - fn from(err: std::io::Error) -> Self { - SocketInfoErr::Io(err) - } -} - -impl From for SocketInfoErr { - fn from(err: std::num::TryFromIntError) -> Self { - SocketInfoErr::Convert(err) - } -} - -pub fn get_socket_states_count(pid: u32) -> Result { +pub fn get_socket_states_count(pid: u32) -> Result { let mut result: SocketStateCount = SocketStateCount { ..Default::default() };