diff --git a/build.rs b/build.rs index 6ff03ad..c0ad45d 100644 --- a/build.rs +++ b/build.rs @@ -7,7 +7,6 @@ fn main() -> Result<(), Box> { .compile_well_known_types(true) .compile(&["proto/spiffe/workload/workload.proto"], &["proto"])?; - tonic_build::configure() - .compile(&["proto/examples/helloworld.proto"], &["proto/examples"])?; + tonic_build::configure().compile(&["proto/examples/helloworld.proto"], &["proto/examples"])?; Ok(()) } diff --git a/examples/spiffe-grpc/client.rs b/examples/spiffe-grpc/client.rs index a181b0c..160639c 100644 --- a/examples/spiffe-grpc/client.rs +++ b/examples/spiffe-grpc/client.rs @@ -1,9 +1,9 @@ +use hyper::Uri; use spiffe_rs::spiffeid; use spiffe_rs::spiffetls; use spiffe_rs::workloadapi; -use std::sync::Arc; -use hyper::Uri; use std::io; +use std::sync::Arc; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; use tonic::transport::{Channel, Endpoint}; @@ -24,7 +24,8 @@ async fn main() -> Result<(), Box> { let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?); let server_id = spiffeid::require_from_string("spiffe://example.org/server"); let authorizer = spiffetls::tlsconfig::authorize_id(server_id); - let mut tls_config = spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?; + let mut tls_config = + spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?; tls_config.alpn_protocols = vec![b"h2".to_vec()]; let connector = TlsConnector::from(Arc::new(tls_config)); @@ -33,16 +34,16 @@ async fn main() -> Result<(), Box> { .connect_with_connector(service_fn(move |uri: Uri| { let connector = connector.clone(); async move { - let authority = uri - .authority() - .map(|auth| auth.as_str()) - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing authority"))?; + let authority = uri.authority().map(|auth| auth.as_str()).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "missing authority") + })?; let stream = TcpStream::connect(authority).await?; let server_name = rustls::ServerName::try_from("example.org") .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; - connector.connect(server_name, stream).await.map_err(|err| { - io::Error::new(io::ErrorKind::Other, err) - }) + connector + .connect(server_name, stream) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) } })) .await?; diff --git a/examples/spiffe-grpc/server.rs b/examples/spiffe-grpc/server.rs index 50ea383..256109e 100644 --- a/examples/spiffe-grpc/server.rs +++ b/examples/spiffe-grpc/server.rs @@ -1,16 +1,16 @@ use spiffe_rs::spiffeid; use spiffe_rs::spiffetls; use spiffe_rs::workloadapi; +use std::pin::Pin; use std::sync::Arc; -use tokio::net::TcpListener; +use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::server::Connected; use tonic::transport::Server; use tonic::{Request, Response, Status}; -use tonic::transport::server::Connected; -use std::pin::Pin; -use std::task::{Context, Poll}; pub mod helloworld { tonic::include_proto!("helloworld"); @@ -24,7 +24,10 @@ struct GreeterService; #[tonic::async_trait] impl Greeter for GreeterService { - async fn say_hello(&self, request: Request) -> Result, Status> { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { let name = request.into_inner().name; let reply = HelloReply { message: format!("Hello {}", name), @@ -41,7 +44,8 @@ async fn main() -> Result<(), Box> { let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?); let client_id = spiffeid::require_from_string("spiffe://example.org/client"); let authorizer = spiffetls::tlsconfig::authorize_id(client_id); - let mut tls_config = spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?; + let mut tls_config = + spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?; tls_config.alpn_protocols = vec![b"h2".to_vec()]; let acceptor = TlsAcceptor::from(Arc::new(tls_config)); @@ -110,17 +114,11 @@ impl AsyncWrite for TlsIo { Pin::new(&mut self.0).poll_write(cx, data) } - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_flush(cx) } - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_shutdown(cx) } } diff --git a/examples/spiffe-http/client.rs b/examples/spiffe-http/client.rs index 35dff11..fec1e55 100644 --- a/examples/spiffe-http/client.rs +++ b/examples/spiffe-http/client.rs @@ -15,7 +15,8 @@ async fn main() -> Result<(), Box> { let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?); let server_id = spiffeid::require_from_string("spiffe://example.org/server"); let authorizer = spiffetls::tlsconfig::authorize_id(server_id); - let tls_config = spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?; + let tls_config = + spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?; let connector = TlsConnector::from(Arc::new(tls_config)); let stream = TcpStream::connect("127.0.0.1:8443").await?; diff --git a/examples/spiffe-http/server.rs b/examples/spiffe-http/server.rs index b42dedf..9a45227 100644 --- a/examples/spiffe-http/server.rs +++ b/examples/spiffe-http/server.rs @@ -15,7 +15,8 @@ async fn main() -> Result<(), Box> { let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?); let client_id = spiffeid::require_from_string("spiffe://example.org/client"); let authorizer = spiffetls::tlsconfig::authorize_id(client_id); - let tls_config = spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?; + let tls_config = + spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?; let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let listener = TcpListener::bind("127.0.0.1:8443").await?; diff --git a/examples/spiffe-jwt-using-proxy/server.rs b/examples/spiffe-jwt-using-proxy/server.rs index ed2f6bd..b1a0b67 100644 --- a/examples/spiffe-jwt-using-proxy/server.rs +++ b/examples/spiffe-jwt-using-proxy/server.rs @@ -36,7 +36,8 @@ async fn main() -> Result<(), Box> { return; } }; - let service = service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone())); + let service = + service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone())); if let Err(err) = hyper::server::conn::Http::new() .serve_connection(tls, service) .await diff --git a/examples/spiffe-jwt/server.rs b/examples/spiffe-jwt/server.rs index 6ad88bd..acdac39 100644 --- a/examples/spiffe-jwt/server.rs +++ b/examples/spiffe-jwt/server.rs @@ -36,7 +36,8 @@ async fn main() -> Result<(), Box> { return; } }; - let service = service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone())); + let service = + service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone())); if let Err(err) = hyper::server::conn::Http::new() .serve_connection(tls, service) .await diff --git a/examples/spiffe-tls/client.rs b/examples/spiffe-tls/client.rs index 810ae73..2251822 100644 --- a/examples/spiffe-tls/client.rs +++ b/examples/spiffe-tls/client.rs @@ -11,7 +11,8 @@ async fn main() -> Result<(), Box> { let server_id = spiffeid::require_from_string("spiffe://example.org/server"); let authorizer = spiffetls::tlsconfig::authorize_id(server_id); let server_name = rustls::ServerName::try_from("example.org")?; - let mut stream = spiffetls::dial(&ctx, "127.0.0.1:55555", server_name, authorizer, Vec::new()).await?; + let mut stream = + spiffetls::dial(&ctx, "127.0.0.1:55555", server_name, authorizer, Vec::new()).await?; stream.write_all(b"Hello server")?; stream.flush()?; diff --git a/src/bundle/jwtbundle/mod.rs b/src/bundle/jwtbundle/mod.rs index e3b82d6..1282019 100644 --- a/src/bundle/jwtbundle/mod.rs +++ b/src/bundle/jwtbundle/mod.rs @@ -66,8 +66,8 @@ impl Bundle { /// Loads a JWT bundle from a JSON file (JWKS). pub fn load(trust_domain: TrustDomain, path: &str) -> Result { - let bytes = - fs::read(path).map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?; + let bytes = fs::read(path) + .map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?; Bundle::parse(trust_domain, &bytes) } @@ -82,15 +82,15 @@ impl Bundle { /// Parses a JWT bundle from JSON bytes (JWKS). pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result { - let jwks: JwkDocument = - serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?; + let jwks: JwkDocument = serde_json::from_slice(bytes) + .map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?; let bundle = Bundle::new(trust_domain); let keys = jwks.keys.unwrap_or_default(); for (idx, key) in keys.iter().enumerate() { let key_id = key.key_id().unwrap_or_default(); - let jwt_key = key - .to_jwt_key() - .map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?; + let jwt_key = key.to_jwt_key().map_err(|err| { + wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)) + })?; if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) { return Err(wrap_error(format!( "error adding authority {} of JWKS: {}", diff --git a/src/bundle/spiffebundle/mod.rs b/src/bundle/spiffebundle/mod.rs index a2b5c4a..a559789 100644 --- a/src/bundle/spiffebundle/mod.rs +++ b/src/bundle/spiffebundle/mod.rs @@ -5,13 +5,13 @@ use crate::internal::jwtutil; use crate::internal::x509util; use crate::spiffeid::TrustDomain; use base64::Engine; +use oid_registry::{OID_EC_P256, OID_KEY_TYPE_EC_PUBLIC_KEY, OID_NIST_EC_P384, OID_NIST_EC_P521}; use serde::Serialize; use std::collections::HashMap; use std::fs; use std::io::Read; use std::sync::RwLock; use std::time::Duration; -use oid_registry::{OID_EC_P256, OID_KEY_TYPE_EC_PUBLIC_KEY, OID_NIST_EC_P384, OID_NIST_EC_P521}; use x509_parser::prelude::X509Certificate; const X509_SVID_USE: &str = "x509-svid"; @@ -70,8 +70,8 @@ impl Bundle { /// Loads a SPIFFE bundle from a JSON file (JWKS). pub fn load(trust_domain: TrustDomain, path: &str) -> Result { - let bytes = - fs::read(path).map_err(|err| wrap_error(format!("unable to read SPIFFE bundle: {}", err)))?; + let bytes = fs::read(path) + .map_err(|err| wrap_error(format!("unable to read SPIFFE bundle: {}", err)))?; Bundle::parse(trust_domain, &bytes) } @@ -86,8 +86,8 @@ impl Bundle { /// Parses a SPIFFE bundle from JSON bytes (JWKS). pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result { - let jwks: JwkDocument = - serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?; + let jwks: JwkDocument = serde_json::from_slice(bytes) + .map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?; let bundle = Bundle::new(trust_domain); if let Some(hint) = jwks.spiffe_refresh_hint { bundle.set_refresh_hint(Duration::from_secs(hint as u64)); @@ -96,18 +96,18 @@ impl Bundle { bundle.set_sequence_number(seq); } - let keys = jwks.keys.ok_or_else(|| wrap_error("no authorities found"))?; + let keys = jwks + .keys + .ok_or_else(|| wrap_error("no authorities found"))?; for (idx, key) in keys.iter().enumerate() { match key.use_field.as_deref() { Some(X509_SVID_USE) => { - let cert = key - .x509_certificate_der() - .ok_or_else(|| { - wrap_error(format!( - "expected a single certificate in {} entry {}; got 0", - X509_SVID_USE, idx - )) - })?; + let cert = key.x509_certificate_der().ok_or_else(|| { + wrap_error(format!( + "expected a single certificate in {} entry {}; got 0", + X509_SVID_USE, idx + )) + })?; if let Some(count) = key.x5c.as_ref().map(|x| x.len()) { if count != 1 { return Err(wrap_error(format!( @@ -120,9 +120,9 @@ impl Bundle { } Some(JWT_SVID_USE) => { let key_id = key.key_id().unwrap_or_default(); - let jwt_key = key - .to_jwt_key() - .map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?; + let jwt_key = key.to_jwt_key().map_err(|err| { + wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)) + })?; if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) { return Err(wrap_error(format!( "error adding authority {} of JWKS: {}", @@ -379,7 +379,10 @@ impl Bundle { } /// Returns the X.509 bundle for the given trust domain if it matches. - pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_x509_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { if self.trust_domain != trust_domain { return Err(wrap_error(format!( "no X.509 bundle for trust domain \"{}\"", @@ -390,7 +393,10 @@ impl Bundle { } /// Returns the JWT bundle for the given trust domain if it matches. - pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_jwt_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { if self.trust_domain != trust_domain { return Err(wrap_error(format!( "no JWT bundle for trust domain \"{}\"", @@ -496,7 +502,10 @@ impl Set { } /// Returns the X.509 bundle for the given trust domain. - pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_x509_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { let guard = self .bundles .read() @@ -511,7 +520,10 @@ impl Set { } /// Returns the JWT bundle for the given trust domain. - pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_jwt_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { let guard = self .bundles .read() @@ -620,7 +632,9 @@ fn ec_public_key_parameters(cert: &X509Certificate<'_>) -> Result<(String, Vec std::result::Result>, Str let mut remaining = bytes; let mut certs = Vec::new(); while !remaining.is_empty() { - let (rest, _cert) = x509_parser::parse_x509_certificate(remaining) - .map_err(|err| err.to_string())?; + let (rest, _cert) = + x509_parser::parse_x509_certificate(remaining).map_err(|err| err.to_string())?; let consumed = remaining .len() .checked_sub(rest.len()) diff --git a/src/federation/mod.rs b/src/federation/mod.rs index 5d201d6..a0d6fde 100644 --- a/src/federation/mod.rs +++ b/src/federation/mod.rs @@ -37,6 +37,12 @@ pub trait FetchOption { } /// Sets the authentication method to SPIFFE-TLS. +/// +/// When this option is used, the HTTPS connection is verified using +/// SPIFFE-TLS: +/// - the server must present an X.509-SVID whose SPIFFE ID matches +/// `endpoint_id` +/// - the server certificate chain must verify against `bundle_source` pub fn with_spiffe_auth( bundle_source: Arc, endpoint_id: spiffeid::ID, @@ -63,12 +69,20 @@ pub fn with_web_pki_roots(roots: RootCertStore) -> impl FetchOption { "cannot use both SPIFFE and Web PKI authentication", )); } - options.auth_method = AuthMethod::WebPki { roots: roots.clone() }; + options.auth_method = AuthMethod::WebPki { + roots: roots.clone(), + }; Ok(()) }) } /// Fetches a SPIFFE bundle from the given URL. +/// +/// Notes on the HTTP implementation: +/// - Only `http://` and `https://` are supported. +/// - The request is a minimal HTTP/1.1 GET over a plain TCP stream. +/// - The response parser expects a `200` status and splits at `\r\n\r\n`. +/// Chunked transfer encoding and redirects are not supported. pub fn fetch_bundle( trust_domain: TrustDomain, url: &str, @@ -95,6 +109,15 @@ pub trait BundleWatcher: Send + Sync { } /// Watches a SPIFFE bundle at the given URL for updates. +/// +/// This repeatedly calls [`fetch_bundle`] and: +/// - calls [`BundleWatcher::on_update`] only when the bundle contents change +/// (as determined by `Bundle::equal`) +/// - passes fetch errors to [`BundleWatcher::on_error`] +/// - uses the bundle `refresh_hint` (if present) as input to +/// [`BundleWatcher::next_refresh`] +/// +/// The loop exits when `ctx` is cancelled. pub async fn watch_bundle( ctx: &Context, trust_domain: TrustDomain, @@ -208,7 +231,10 @@ impl Service> for BundleHandler { type Error = hyper::Error; type Future = std::future::Ready>; - fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll> { + fn poll_ready( + &mut self, + _cx: &mut TaskContext<'_>, + ) -> Poll> { Poll::Ready(Ok(())) } @@ -219,7 +245,10 @@ impl Service> for BundleHandler { .body(Body::from("method is not allowed")) .unwrap() } else { - match self.source.get_bundle_for_trust_domain(self.trust_domain.clone()) { + match self + .source + .get_bundle_for_trust_domain(self.trust_domain.clone()) + { Ok(bundle) => match bundle.marshal() { Ok(body) => Response::builder() .status(StatusCode::OK) @@ -293,9 +322,7 @@ fn fetch_url(url: &Url, options: &FetchOptions) -> Result> { .map_err(|err| wrap_error(format!("unable to create TLS connection: {}", err)))?; HttpStream::Tls(rustls::StreamOwned::new(conn, tcp)) } - "http" => HttpStream::Plain( - TcpStream::connect(&addr).map_err(|err| wrap_error(err))?, - ), + "http" => HttpStream::Plain(TcpStream::connect(&addr).map_err(|err| wrap_error(err))?), scheme => { return Err(wrap_error(format!("unsupported URL scheme: {}", scheme))); } @@ -316,7 +343,9 @@ fn fetch_url(url: &Url, options: &FetchOptions) -> Result> { stream.flush().map_err(|err| wrap_error(err))?; let mut response = Vec::new(); - stream.read_to_end(&mut response).map_err(|err| wrap_error(err))?; + stream + .read_to_end(&mut response) + .map_err(|err| wrap_error(err))?; parse_http_body(&response) } @@ -360,8 +389,12 @@ fn parse_http_body(response: &[u8]) -> Result> { .ok_or_else(|| wrap_error("invalid HTTP response"))?; let status_line = String::from_utf8_lossy(status_line).trim().to_string(); let mut parts = status_line.split_whitespace(); - let _proto = parts.next().ok_or_else(|| wrap_error("invalid HTTP response"))?; - let status = parts.next().ok_or_else(|| wrap_error("invalid HTTP response"))?; + let _proto = parts + .next() + .ok_or_else(|| wrap_error("invalid HTTP response"))?; + let status = parts + .next() + .ok_or_else(|| wrap_error("invalid HTTP response"))?; if status != "200" { return Err(wrap_error(format!("unexpected HTTP status {}", status))); } diff --git a/src/internal/jwk.rs b/src/internal/jwk.rs index 2ecd2e9..1239731 100644 --- a/src/internal/jwk.rs +++ b/src/internal/jwk.rs @@ -39,10 +39,7 @@ impl JwkKeyEntry { pub fn to_jwt_key(&self) -> Result { match self.kty.as_str() { "EC" => { - let crv = self - .crv - .as_ref() - .ok_or_else(|| "missing crv".to_string())?; + let crv = self.crv.as_ref().ok_or_else(|| "missing crv".to_string())?; let x = self .x .as_ref() diff --git a/src/lib.rs b/src/lib.rs index 3de902d..cc04b2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ -pub mod spiffeid; -pub mod logger; pub mod bundle; +pub mod federation; +pub mod logger; +pub mod spiffeid; +pub mod spiffetls; pub mod svid; pub mod workloadapi; -pub mod spiffetls; -pub mod federation; pub(crate) mod internal; diff --git a/src/spiffeid/errors.rs b/src/spiffeid/errors.rs index 65146e1..73d0a53 100644 --- a/src/spiffeid/errors.rs +++ b/src/spiffeid/errors.rs @@ -10,7 +10,9 @@ pub enum Error { #[error("trust domain characters are limited to lowercase letters, numbers, dots, dashes, and underscores")] BadTrustDomainChar, /// The path contains an invalid character. - #[error("path segment characters are limited to letters, numbers, dots, dashes, and underscores")] + #[error( + "path segment characters are limited to letters, numbers, dots, dashes, and underscores" + )] BadPathSegmentChar, /// The path contains a dot segment (`.` or `..`). #[error("path cannot contain dot segments")] diff --git a/src/spiffeid/id.rs b/src/spiffeid/id.rs index 7ae972c..ce3f6aa 100644 --- a/src/spiffeid/id.rs +++ b/src/spiffeid/id.rs @@ -146,7 +146,9 @@ impl ID { /// Returns the trust domain of the SPIFFE ID. pub fn trust_domain(&self) -> TrustDomain { if self.is_zero() { - return TrustDomain { name: String::new() }; + return TrustDomain { + name: String::new(), + }; } TrustDomain { name: self.id[SCHEME_PREFIX.len()..self.path_idx].to_string(), diff --git a/src/spiffeid/matcher.rs b/src/spiffeid/matcher.rs index 48caef2..c988f63 100644 --- a/src/spiffeid/matcher.rs +++ b/src/spiffeid/matcher.rs @@ -1,4 +1,4 @@ -use crate::spiffeid::{ID, TrustDomain}; +use crate::spiffeid::{TrustDomain, ID}; /// An error that occurred during SPIFFE ID matching. #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/spiffeid/mod.rs b/src/spiffeid/mod.rs index 94f2fd6..090bf44 100644 --- a/src/spiffeid/mod.rs +++ b/src/spiffeid/mod.rs @@ -7,7 +7,7 @@ mod require; mod trustdomain; pub use errors::{Error, Result}; -pub use id::{ID, SpiffeUrl}; +pub use id::{SpiffeUrl, ID}; pub use matcher::{match_any, match_id, match_member_of, match_one_of, Matcher, MatcherError}; pub use path::{format_path, join_path_segments, validate_path, validate_path_segment}; pub use require::{ diff --git a/src/spiffeid/require.rs b/src/spiffeid/require.rs index 10a595a..ce3423e 100644 --- a/src/spiffeid/require.rs +++ b/src/spiffeid/require.rs @@ -1,6 +1,6 @@ use crate::spiffeid::{ - format_path, join_path_segments, trust_domain_from_string, trust_domain_from_uri, ID, Result, - TrustDomain, + format_path, join_path_segments, trust_domain_from_string, trust_domain_from_uri, Result, + TrustDomain, ID, }; use url::Url; diff --git a/src/spiffeid/trustdomain.rs b/src/spiffeid/trustdomain.rs index e16e35e..1db9812 100644 --- a/src/spiffeid/trustdomain.rs +++ b/src/spiffeid/trustdomain.rs @@ -1,6 +1,6 @@ use crate::spiffeid::charset::is_backcompat_trust_domain_char; use crate::spiffeid::id::make_id; -use crate::spiffeid::{Error, ID, Result}; +use crate::spiffeid::{Error, Result, ID}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::cmp::Ordering; use url::Url; @@ -79,12 +79,15 @@ impl TrustDomain { /// Unmarshals the trust domain name from a byte slice. pub fn unmarshal_text(&mut self, text: &[u8]) -> Result<()> { if text.is_empty() { - *self = TrustDomain { name: String::new() }; + *self = TrustDomain { + name: String::new(), + }; return Ok(()); } - let parsed = trust_domain_from_string(std::str::from_utf8(text).map_err(|e| { - Error::Other(format!("invalid trust domain text: {}", e)) - })?)?; + let parsed = trust_domain_from_string( + std::str::from_utf8(text) + .map_err(|e| Error::Other(format!("invalid trust domain text: {}", e)))?, + )?; *self = parsed; Ok(()) } @@ -98,7 +101,9 @@ impl std::fmt::Display for TrustDomain { impl Default for TrustDomain { fn default() -> Self { - TrustDomain { name: String::new() } + TrustDomain { + name: String::new(), + } } } @@ -122,7 +127,9 @@ impl<'de> Deserialize<'de> for TrustDomain { { let s = String::deserialize(deserializer)?; if s.is_empty() { - Ok(TrustDomain { name: String::new() }) + Ok(TrustDomain { + name: String::new(), + }) } else { trust_domain_from_string(&s).map_err(serde::de::Error::custom) } diff --git a/src/spiffetls/dial.rs b/src/spiffetls/dial.rs index dabc4e0..2fe531a 100644 --- a/src/spiffetls/dial.rs +++ b/src/spiffetls/dial.rs @@ -52,7 +52,14 @@ pub async fn dial( authorizer: tlsconfig::Authorizer, options: Vec>, ) -> Result { - dial_with_mode(ctx, addr, server_name, crate::spiffetls::mtls_client(authorizer), options).await + dial_with_mode( + ctx, + addr, + server_name, + crate::spiffetls::mtls_client(authorizer), + options, + ) + .await } /// Dials a SPIFFE-TLS server with the given mode. @@ -85,12 +92,18 @@ pub async fn dial_with_mode( let tls_config = match m.mode { crate::spiffetls::mode::ClientMode::Tls => { - let bundle = m.bundle.ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; + let bundle = m + .bundle + .ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; tlsconfig::tls_client_config(bundle, m.authorizer.clone())? } crate::spiffetls::mode::ClientMode::Mtls => { - let svid = m.svid.ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; - let bundle = m.bundle.ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; + let svid = m + .svid + .ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; + let bundle = m + .bundle + .ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; tlsconfig::mtls_client_config_with_options( svid.as_ref(), bundle, @@ -99,7 +112,9 @@ pub async fn dial_with_mode( )? } crate::spiffetls::mode::ClientMode::MtlsWeb => { - let svid = m.svid.ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; + let svid = m + .svid + .ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; tlsconfig::mtls_web_client_config_with_options( svid.as_ref(), m.roots, @@ -110,8 +125,9 @@ pub async fn dial_with_mode( let tcp = TcpStream::connect(addr).map_err(|err| crate::spiffetls::wrap_error(err))?; let tls_config = apply_base_client_config(tls_config, config.base_client_config); - let conn = rustls::ClientConnection::new(Arc::new(tls_config), server_name) - .map_err(|err| crate::spiffetls::wrap_error(format!("unable to create client connection: {}", err)))?; + let conn = rustls::ClientConnection::new(Arc::new(tls_config), server_name).map_err(|err| { + crate::spiffetls::wrap_error(format!("unable to create client connection: {}", err)) + })?; Ok(ClientStream { inner: rustls::StreamOwned::new(conn, tcp), source, @@ -139,8 +155,9 @@ fn peer_id_from_certs(certs: Option<&[rustls::Certificate]>) -> Result { let cert = certs .first() .ok_or_else(|| crate::spiffetls::wrap_error("no peer certificates"))?; - let (_rest, parsed) = x509_parser::parse_x509_certificate(&cert.0) - .map_err(|err| crate::spiffetls::wrap_error(format!("invalid peer certificate: {}", err)))?; + let (_rest, parsed) = x509_parser::parse_x509_certificate(&cert.0).map_err(|err| { + crate::spiffetls::wrap_error(format!("invalid peer certificate: {}", err)) + })?; let san = parsed .subject_alternative_name() .map_err(|_| crate::spiffetls::wrap_error("invalid peer certificate: invalid URI SAN"))? diff --git a/src/spiffetls/listen.rs b/src/spiffetls/listen.rs index 3334456..bba6846 100644 --- a/src/spiffetls/listen.rs +++ b/src/spiffetls/listen.rs @@ -14,9 +14,13 @@ pub struct Listener { impl Listener { /// Accepts a new connection and performs the TLS handshake. pub fn accept(&self) -> Result { - let (sock, _addr) = self.inner.accept().map_err(|err| crate::spiffetls::wrap_error(err))?; - let conn = rustls::ServerConnection::new(self.config.clone()) - .map_err(|err| crate::spiffetls::wrap_error(format!("unable to create server connection: {}", err)))?; + let (sock, _addr) = self + .inner + .accept() + .map_err(|err| crate::spiffetls::wrap_error(err))?; + let conn = rustls::ServerConnection::new(self.config.clone()).map_err(|err| { + crate::spiffetls::wrap_error(format!("unable to create server connection: {}", err)) + })?; Ok(ServerStream { inner: rustls::StreamOwned::new(conn, sock), }) @@ -24,13 +28,18 @@ impl Listener { /// Returns the local address the listener is bound to. pub fn local_addr(&self) -> Result { - self.inner.local_addr().map_err(|err| crate::spiffetls::wrap_error(err)) + self.inner + .local_addr() + .map_err(|err| crate::spiffetls::wrap_error(err)) } /// Closes the listener and its underlying source if it was created by the listener. pub async fn close(self) -> Result<()> { if let Some(source) = self.source { - source.close().await.map_err(|err| crate::spiffetls::Error(err.to_string()))?; + source + .close() + .await + .map_err(|err| crate::spiffetls::Error(err.to_string()))?; } Ok(()) } @@ -71,7 +80,13 @@ pub async fn listen( authorizer: tlsconfig::Authorizer, options: Vec>, ) -> Result { - listen_with_mode(ctx, addr, crate::spiffetls::mtls_server(authorizer), options).await + listen_with_mode( + ctx, + addr, + crate::spiffetls::mtls_server(authorizer), + options, + ) + .await } /// Listens for SPIFFE-TLS connections with the given mode. @@ -103,12 +118,18 @@ pub async fn listen_with_mode( let tls_config = match m.mode { crate::spiffetls::mode::ServerMode::Tls => { - let svid = m.svid.ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; + let svid = m + .svid + .ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; tlsconfig::tls_server_config_with_options(svid.as_ref(), &config.tls_options)? } crate::spiffetls::mode::ServerMode::Mtls => { - let svid = m.svid.ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; - let bundle = m.bundle.ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; + let svid = m + .svid + .ok_or_else(|| crate::spiffetls::wrap_error("missing svid source"))?; + let bundle = m + .bundle + .ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; tlsconfig::mtls_server_config_with_options( svid.as_ref(), bundle, @@ -117,8 +138,12 @@ pub async fn listen_with_mode( )? } crate::spiffetls::mode::ServerMode::MtlsWeb => { - let bundle = m.bundle.ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; - let cert = m.web_cert.ok_or_else(|| crate::spiffetls::wrap_error("missing web cert"))?; + let bundle = m + .bundle + .ok_or_else(|| crate::spiffetls::wrap_error("missing bundle source"))?; + let cert = m + .web_cert + .ok_or_else(|| crate::spiffetls::wrap_error("missing web cert"))?; tlsconfig::mtls_web_server_config(cert, bundle, m.authorizer.clone())? } }; diff --git a/src/spiffetls/mode.rs b/src/spiffetls/mode.rs index c17b986..e06640f 100644 --- a/src/spiffetls/mode.rs +++ b/src/spiffetls/mode.rs @@ -1,3 +1,13 @@ +//! Dial/listen “modes” for `spiffetls` helpers. +//! +//! A mode is a small builder object that tells [`crate::spiffetls::dial`] / +//! [`crate::spiffetls::listen`] how to construct `rustls` configs: +//! +//! - Whether to do **TLS**, **SPIFFE mTLS**, or **web mTLS** +//! - Where to get the X.509-SVID and bundle information (from a Workload API +//! [`workloadapi::X509Source`], or from "raw" provided sources) +//! - Which SPIFFE IDs are allowed (via a [`tlsconfig::Authorizer`]) + use crate::bundle::x509bundle; use crate::spiffetls::tlsconfig; use crate::svid::x509svid; @@ -43,6 +53,9 @@ pub(crate) enum ServerMode { MtlsWeb, } +/// Dial a server using SPIFFE-TLS (server authenticated by SPIFFE ID). +/// +/// The client does not present an X.509-SVID. pub fn tls_client(authorizer: tlsconfig::Authorizer) -> DialMode { DialMode { mode: ClientMode::Tls, @@ -56,7 +69,11 @@ pub fn tls_client(authorizer: tlsconfig::Authorizer) -> DialMode { } } -pub fn tls_client_with_source(authorizer: tlsconfig::Authorizer, source: Arc) -> DialMode { +/// Like [`tls_client`], but uses an existing [`workloadapi::X509Source`]. +pub fn tls_client_with_source( + authorizer: tlsconfig::Authorizer, + source: Arc, +) -> DialMode { DialMode { mode: ClientMode::Tls, source_unneeded: false, @@ -69,6 +86,11 @@ pub fn tls_client_with_source(authorizer: tlsconfig::Authorizer, source: Arc>, @@ -85,6 +107,9 @@ pub fn tls_client_with_source_options( } } +/// Like [`tls_client`], but bypasses the Workload API entirely. +/// +/// The provided `bundle` is used to verify the server certificate chain. pub fn tls_client_with_raw_config( authorizer: tlsconfig::Authorizer, bundle: Arc, @@ -101,6 +126,7 @@ pub fn tls_client_with_raw_config( } } +/// Dial a server using SPIFFE mTLS (both sides authenticate via SPIFFE IDs). pub fn mtls_client(authorizer: tlsconfig::Authorizer) -> DialMode { DialMode { mode: ClientMode::Mtls, @@ -114,7 +140,11 @@ pub fn mtls_client(authorizer: tlsconfig::Authorizer) -> DialMode { } } -pub fn mtls_client_with_source(authorizer: tlsconfig::Authorizer, source: Arc) -> DialMode { +/// Like [`mtls_client`], but uses an existing [`workloadapi::X509Source`]. +pub fn mtls_client_with_source( + authorizer: tlsconfig::Authorizer, + source: Arc, +) -> DialMode { DialMode { mode: ClientMode::Mtls, source_unneeded: false, @@ -127,6 +157,8 @@ pub fn mtls_client_with_source(authorizer: tlsconfig::Authorizer, source: Arc>, @@ -143,6 +175,10 @@ pub fn mtls_client_with_source_options( } } +/// Like [`mtls_client`], but bypasses the Workload API entirely. +/// +/// The provided `svid` is presented as the client certificate chain, and the +/// provided `bundle` is used to verify the server certificate chain. pub fn mtls_client_with_raw_config( authorizer: tlsconfig::Authorizer, svid: Arc, @@ -160,6 +196,10 @@ pub fn mtls_client_with_raw_config( } } +/// Dial a “web” server using Web PKI verification, while presenting an X.509-SVID. +/// +/// - The server is verified using `roots` (or system roots if `None`). +/// - No SPIFFE ID authorization is performed for the server identity. pub fn mtls_web_client(roots: Option) -> DialMode { DialMode { mode: ClientMode::MtlsWeb, @@ -173,6 +213,7 @@ pub fn mtls_web_client(roots: Option) -> DialMode { } } +/// Like [`mtls_web_client`], but uses an existing [`workloadapi::X509Source`]. pub fn mtls_web_client_with_source( roots: Option, source: Arc, @@ -189,6 +230,8 @@ pub fn mtls_web_client_with_source( } } +/// Like [`mtls_web_client`], but configures the internal [`workloadapi::X509Source`] +/// via source options. pub fn mtls_web_client_with_source_options( roots: Option, options: Vec>, @@ -205,6 +248,9 @@ pub fn mtls_web_client_with_source_options( } } +/// Like [`mtls_web_client`], but bypasses the Workload API entirely. +/// +/// The provided `svid` is presented as the client certificate chain. pub fn mtls_web_client_with_raw_config( roots: Option, svid: Arc, @@ -221,6 +267,7 @@ pub fn mtls_web_client_with_raw_config( } } +/// Listen using TLS (present an X.509-SVID, do not authenticate clients). pub fn tls_server() -> ListenMode { ListenMode { mode: ServerMode::Tls, @@ -234,6 +281,7 @@ pub fn tls_server() -> ListenMode { } } +/// Like [`tls_server`], but uses an existing [`workloadapi::X509Source`]. pub fn tls_server_with_source(source: Arc) -> ListenMode { ListenMode { mode: ServerMode::Tls, @@ -247,7 +295,11 @@ pub fn tls_server_with_source(source: Arc) -> ListenMod } } -pub fn tls_server_with_source_options(options: Vec>) -> ListenMode { +/// Like [`tls_server`], but configures the internal [`workloadapi::X509Source`] +/// via source options. +pub fn tls_server_with_source_options( + options: Vec>, +) -> ListenMode { ListenMode { mode: ServerMode::Tls, source_unneeded: false, @@ -260,6 +312,9 @@ pub fn tls_server_with_source_options(options: Vec) -> ListenMode { ListenMode { mode: ServerMode::Tls, @@ -273,6 +328,7 @@ pub fn tls_server_with_raw_config(svid: Arc) } } +/// Listen using SPIFFE mTLS (require and authorize client certificates). pub fn mtls_server(authorizer: tlsconfig::Authorizer) -> ListenMode { ListenMode { mode: ServerMode::Mtls, @@ -286,7 +342,11 @@ pub fn mtls_server(authorizer: tlsconfig::Authorizer) -> ListenMode { } } -pub fn mtls_server_with_source(authorizer: tlsconfig::Authorizer, source: Arc) -> ListenMode { +/// Like [`mtls_server`], but uses an existing [`workloadapi::X509Source`]. +pub fn mtls_server_with_source( + authorizer: tlsconfig::Authorizer, + source: Arc, +) -> ListenMode { ListenMode { mode: ServerMode::Mtls, source_unneeded: false, @@ -299,6 +359,8 @@ pub fn mtls_server_with_source(authorizer: tlsconfig::Authorizer, source: Arc>, @@ -315,6 +377,10 @@ pub fn mtls_server_with_source_options( } } +/// Like [`mtls_server`], but bypasses the Workload API entirely. +/// +/// The provided `svid` is presented as the server certificate chain and +/// `bundle` is used to verify/authorize client certificates. pub fn mtls_server_with_raw_config( authorizer: tlsconfig::Authorizer, svid: Arc, @@ -332,6 +398,10 @@ pub fn mtls_server_with_raw_config( } } +/// Listen using Web PKI identity, while requiring and authorizing SPIFFE clients. +/// +/// The server presents `cert` (typically a public TLS certificate chain) but +/// authenticates clients via SPIFFE mTLS using `authorizer`. pub fn mtls_web_server(authorizer: tlsconfig::Authorizer, cert: tlsconfig::WebCert) -> ListenMode { ListenMode { mode: ServerMode::MtlsWeb, @@ -345,6 +415,7 @@ pub fn mtls_web_server(authorizer: tlsconfig::Authorizer, cert: tlsconfig::WebCe } } +/// Like [`mtls_web_server`], but uses an existing [`workloadapi::X509Source`]. pub fn mtls_web_server_with_source( authorizer: tlsconfig::Authorizer, cert: tlsconfig::WebCert, @@ -362,6 +433,8 @@ pub fn mtls_web_server_with_source( } } +/// Like [`mtls_web_server`], but configures the internal [`workloadapi::X509Source`] +/// via source options. pub fn mtls_web_server_with_source_options( authorizer: tlsconfig::Authorizer, cert: tlsconfig::WebCert, @@ -379,6 +452,9 @@ pub fn mtls_web_server_with_source_options( } } +/// Like [`mtls_web_server`], but bypasses the Workload API entirely. +/// +/// The provided `bundle` is used to verify/authorize client certificates. pub fn mtls_web_server_with_raw_config( authorizer: tlsconfig::Authorizer, cert: tlsconfig::WebCert, diff --git a/src/spiffetls/peerid.rs b/src/spiffetls/peerid.rs index aa4e0c8..bc9bd4d 100644 --- a/src/spiffetls/peerid.rs +++ b/src/spiffetls/peerid.rs @@ -11,8 +11,9 @@ pub fn peer_id_from_stream(certs: Option<&[rustls::Certificate]>) -> Result let cert = certs .first() .ok_or_else(|| crate::spiffetls::wrap_error("no peer certificates"))?; - let (_rest, parsed) = x509_parser::parse_x509_certificate(&cert.0) - .map_err(|err| crate::spiffetls::wrap_error(format!("invalid peer certificate: {}", err)))?; + let (_rest, parsed) = x509_parser::parse_x509_certificate(&cert.0).map_err(|err| { + crate::spiffetls::wrap_error(format!("invalid peer certificate: {}", err)) + })?; let san = parsed .subject_alternative_name() .map_err(|_| crate::spiffetls::wrap_error("invalid peer certificate: invalid URI SAN"))? diff --git a/src/spiffetls/tlsconfig.rs b/src/spiffetls/tlsconfig.rs index 5e13a19..8f9a739 100644 --- a/src/spiffetls/tlsconfig.rs +++ b/src/spiffetls/tlsconfig.rs @@ -1,36 +1,70 @@ +//! Helpers for building `rustls` configurations that enforce SPIFFE identities. +//! +//! This module wires together three steps that are easy to conflate: +//! +//! - **Chain verification**: verify the peer certificate chain against an X.509 +//! bundle source (trust domain authorities). +//! - **SPIFFE ID extraction**: parse the peer SPIFFE ID from the leaf SVID. +//! - **Authorization**: decide whether the extracted ID is acceptable for the +//! connection. +//! +//! The [`Authorizer`] type represents the final authorization step. + use crate::bundle::x509bundle; use crate::spiffeid; use crate::spiffeid::ID; use crate::svid::x509svid; use rustls::client::{ServerCertVerified, ServerCertVerifier}; use rustls::server::{ClientCertVerified, ClientCertVerifier}; -use rustls::{Certificate, ClientConfig, Error as RustlsError, PrivateKey, RootCertStore, ServerConfig}; +use rustls::{ + Certificate, ClientConfig, Error as RustlsError, PrivateKey, RootCertStore, ServerConfig, +}; use std::sync::Arc; +/// Authorization callback used by SPIFFE-TLS verification. +/// +/// The callback receives: +/// - **`id`**: the SPIFFE ID extracted from the peer X.509-SVID leaf. +/// - **`chains`**: the verified chain(s), where each chain is a list of DER +/// certificates (leaf first). Some verifiers may provide multiple candidate +/// chains; most callers can ignore this parameter. +/// +/// Return `Ok(())` to authorize the peer, otherwise return an error to fail the +/// TLS handshake. pub type Authorizer = Arc>]) -> std::result::Result<(), super::Error> + Send + Sync>; +/// Optional tracing hooks invoked when fetching an SVID for a TLS config. +/// +/// This is intended for diagnostics (e.g. logging), not for modifying the +/// handshake. #[derive(Clone, Default)] pub struct Trace { pub get_certificate: Option>, pub got_certificate: Option>, } +/// Input to [`Trace::get_certificate`]. #[derive(Clone, Default)] pub struct GetCertificateInfo; +/// Input to [`Trace::got_certificate`]. #[derive(Clone, Default)] pub struct GotCertificateInfo { + /// The leaf certificate that will be used (if any). pub cert: Option, + /// An error string if fetching the SVID failed. pub err: Option, } +/// Optional configuration applied when constructing `rustls` configs. #[derive(Clone, Default)] pub struct TlsOption { trace: std::option::Option, } impl TlsOption { + /// Enables [`Trace`] hooks for SVID retrieval. pub fn with_trace(trace: Trace) -> Self { Self { trace: std::option::Option::Some(trace), @@ -40,32 +74,44 @@ impl TlsOption { #[derive(Clone)] pub struct WebCert { + /// Certificate chain to present (leaf first). pub certs: Vec, + /// Private key corresponding to the leaf certificate. pub key: PrivateKey, } +/// Authorizes any valid SPIFFE ID. pub fn authorize_any() -> Authorizer { adapt_matcher(spiffeid::match_any()) } +/// Authorizes only the given SPIFFE ID. pub fn authorize_id(allowed: ID) -> Authorizer { adapt_matcher(spiffeid::match_id(allowed)) } +/// Authorizes if the peer matches any ID in `allowed`. pub fn authorize_one_of(allowed: &[ID]) -> Authorizer { adapt_matcher(spiffeid::match_one_of(allowed)) } +/// Authorizes if the peer ID is a member of the given trust domain. pub fn authorize_member_of(allowed: spiffeid::TrustDomain) -> Authorizer { adapt_matcher(spiffeid::match_member_of(allowed)) } +/// Adapts a [`spiffeid::Matcher`] into an [`Authorizer`]. +/// +/// The resulting authorizer ignores verified chains and only evaluates the +/// matcher against the extracted SPIFFE ID. pub fn adapt_matcher(matcher: spiffeid::Matcher) -> Authorizer { - Arc::new(move |actual, _chains| { - matcher(actual).map_err(|err| super::wrap_error(err)) - }) + Arc::new(move |actual, _chains| matcher(actual).map_err(|err| super::wrap_error(err))) } +/// Builds a `rustls` client config that authenticates the server via SPIFFE-TLS. +/// +/// This configuration **does not** present a client certificate. Use +/// [`mtls_client_config`] for mutual authentication. pub fn tls_client_config( bundle_source: Arc, authorizer: Authorizer, @@ -77,6 +123,7 @@ pub fn tls_client_config( .with_no_client_auth()) } +/// Builds a `rustls` client config for SPIFFE mTLS (client presents an X.509-SVID). pub fn mtls_client_config( svid_source: &dyn x509svid::Source, bundle_source: Arc, @@ -85,6 +132,7 @@ pub fn mtls_client_config( mtls_client_config_with_options(svid_source, bundle_source, authorizer, &[]) } +/// Like [`mtls_client_config`], but allows passing [`TlsOption`] (e.g. tracing). pub fn mtls_client_config_with_options( svid_source: &dyn x509svid::Source, bundle_source: Arc, @@ -100,6 +148,7 @@ pub fn mtls_client_config_with_options( .map_err(|err| super::wrap_error(format!("unable to set client auth cert: {}", err))) } +/// Builds a `rustls` server config for SPIFFE mTLS (server requires client certs). pub fn mtls_server_config( svid_source: &dyn x509svid::Source, bundle_source: Arc, @@ -108,6 +157,7 @@ pub fn mtls_server_config( mtls_server_config_with_options(svid_source, bundle_source, authorizer, &[]) } +/// Like [`mtls_server_config`], but allows passing [`TlsOption`] (e.g. tracing). pub fn mtls_server_config_with_options( svid_source: &dyn x509svid::Source, bundle_source: Arc, @@ -123,10 +173,13 @@ pub fn mtls_server_config_with_options( .map_err(|err| super::wrap_error(format!("unable to set server cert: {}", err))) } +/// Builds a `rustls` server config that presents an X.509-SVID but does not +/// authenticate clients. pub fn tls_server_config(svid_source: &dyn x509svid::Source) -> super::Result { tls_server_config_with_options(svid_source, &[]) } +/// Like [`tls_server_config`], but allows passing [`TlsOption`] (e.g. tracing). pub fn tls_server_config_with_options( svid_source: &dyn x509svid::Source, opts: &[TlsOption], @@ -139,6 +192,11 @@ pub fn tls_server_config_with_options( .map_err(|err| super::wrap_error(format!("unable to set server cert: {}", err))) } +/// Builds a `rustls` client config that presents an X.509-SVID and verifies the +/// server using Web PKI roots. +/// +/// This is intended for talking to conventional HTTPS servers that use Web PKI +/// identities, while still authenticating the client with an X.509-SVID. pub fn mtls_web_client_config( svid_source: &dyn x509svid::Source, roots: std::option::Option, @@ -146,6 +204,7 @@ pub fn mtls_web_client_config( mtls_web_client_config_with_options(svid_source, roots, &[]) } +/// Like [`mtls_web_client_config`], but allows passing [`TlsOption`] (e.g. tracing). pub fn mtls_web_client_config_with_options( svid_source: &dyn x509svid::Source, roots: std::option::Option, @@ -161,6 +220,11 @@ pub fn mtls_web_client_config_with_options( Ok(config) } +/// Builds a `rustls` server config for “web mTLS”: +/// +/// - The server presents a Web PKI-style certificate chain (`cert`). +/// - The server requires a client certificate and enforces a SPIFFE ID policy +/// using `bundle_source` + `authorizer`. pub fn mtls_web_server_config( cert: WebCert, bundle_source: Arc, @@ -174,7 +238,13 @@ pub fn mtls_web_server_config( .map_err(|err| super::wrap_error(format!("unable to set server cert: {}", err))) } -pub fn webpki_client_config(roots: std::option::Option) -> super::Result { +/// Builds a Web PKI `rustls` client config (no client certificate). +/// +/// This is *not* SPIFFE-aware; it exists to support federation bundle fetching +/// over HTTPS when SPIFFE authentication is not configured. +pub fn webpki_client_config( + roots: std::option::Option, +) -> super::Result { let mut config = ClientConfig::builder() .with_safe_defaults() .with_root_certificates(roots.unwrap_or_else(system_roots)) @@ -264,8 +334,7 @@ impl SpiffeServerVerifier { let raw = raw_chain(end_entity, intermediates); let (id, chains) = x509svid::parse_and_verify(&raw, self.bundle_source.as_ref(), &[]) .map_err(|err| RustlsError::General(format!("spiffe verification failed: {}", err)))?; - (self.authorizer)(&id, &chains) - .map_err(|err| RustlsError::General(err.to_string()))?; + (self.authorizer)(&id, &chains).map_err(|err| RustlsError::General(err.to_string()))?; Ok(ServerCertVerified::assertion()) } } @@ -308,8 +377,7 @@ impl SpiffeClientVerifier { let raw = raw_chain(end_entity, intermediates); let (id, chains) = x509svid::parse_and_verify(&raw, self.bundle_source.as_ref(), &[]) .map_err(|err| RustlsError::General(format!("spiffe verification failed: {}", err)))?; - (self.authorizer)(&id, &chains) - .map_err(|err| RustlsError::General(err.to_string()))?; + (self.authorizer)(&id, &chains).map_err(|err| RustlsError::General(err.to_string()))?; Ok(ClientCertVerified::assertion()) } } diff --git a/src/svid/jwtsvid.rs b/src/svid/jwtsvid.rs index f28a639..0b23780 100644 --- a/src/svid/jwtsvid.rs +++ b/src/svid/jwtsvid.rs @@ -3,16 +3,18 @@ use crate::bundle::jwtbundle::JwtKey; use crate::spiffeid::ID; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; -use p256::ecdsa::{signature::Verifier, Signature as P256Signature, VerifyingKey as P256VerifyingKey}; +use p256::ecdsa::{ + signature::Verifier, Signature as P256Signature, VerifyingKey as P256VerifyingKey, +}; use p384::ecdsa::{Signature as P384Signature, VerifyingKey as P384VerifyingKey}; use p521::ecdsa::{Signature as P521Signature, VerifyingKey as P521VerifyingKey}; +use pkcs8::AssociatedOid; use rsa::pkcs1v15::{Signature as RsaSignature, VerifyingKey as RsaVerifyingKey}; use rsa::pss::{Signature as RsaPssSignature, VerifyingKey as RsaPssVerifyingKey}; +use rsa::signature::digest::FixedOutputReset; use rsa::RsaPublicKey; use serde_json::{Map, Value}; use sha2::{Digest, Sha256, Sha384, Sha512}; -use pkcs8::AssociatedOid; -use rsa::signature::digest::FixedOutputReset; use std::collections::HashMap; use std::time::{Duration, SystemTime}; @@ -103,29 +105,34 @@ pub fn parse_and_validate( bundles: &dyn jwtbundle::Source, audience: &[String], ) -> Result { - parse(token, audience, |header, signing_input, signature, trust_domain| { - let key_id = header - .kid - .as_deref() - .ok_or_else(|| wrap_error("token header missing key id"))?; - let bundle = bundles - .get_jwt_bundle_for_trust_domain(trust_domain.clone()) - .map_err(|_| { - wrap_error(format!("no bundle found for trust domain \"{}\"", trust_domain)) - })?; - let authority = bundle - .find_jwt_authority(key_id) - .ok_or_else(|| { + parse( + token, + audience, + |header, signing_input, signature, trust_domain| { + let key_id = header + .kid + .as_deref() + .ok_or_else(|| wrap_error("token header missing key id"))?; + let bundle = bundles + .get_jwt_bundle_for_trust_domain(trust_domain.clone()) + .map_err(|_| { + wrap_error(format!( + "no bundle found for trust domain \"{}\"", + trust_domain + )) + })?; + let authority = bundle.find_jwt_authority(key_id).ok_or_else(|| { wrap_error(format!( "no JWT authority \"{}\" found for trust domain \"{}\"", key_id, trust_domain )) })?; - verify_signature(&header.alg, &authority, signing_input, signature).map_err(|_| { + verify_signature(&header.alg, &authority, signing_input, signature).map_err(|_| { wrap_error("unable to get claims from token: go-jose/go-jose: error in cryptographic primitive") })?; - Ok(()) - }) + Ok(()) + }, + ) } /// Parses a JWT SVID token without validating its signature. @@ -133,7 +140,11 @@ pub fn parse_and_validate( /// **WARNING**: This should only be used if the token has already been validated /// by other means. pub fn parse_insecure(token: &str, audience: &[String]) -> Result { - parse(token, audience, |_header, _signing_input, _signature, _td| Ok(())) + parse( + token, + audience, + |_header, _signing_input, _signature, _td| Ok(()), + ) } fn parse(token: &str, audience: &[String], verify: F) -> Result @@ -154,20 +165,22 @@ where .decode(parts[2].as_bytes()) .map_err(|_| wrap_error("unable to parse JWT token"))?; - let header: Header = - serde_json::from_slice(&header_bytes).map_err(|_| wrap_error("unable to parse JWT token"))?; + let header: Header = serde_json::from_slice(&header_bytes) + .map_err(|_| wrap_error("unable to parse JWT token"))?; if !is_allowed_alg(&header.alg) { return Err(wrap_error("unable to parse JWT token")); } if let Some(typ) = header.typ.as_deref() { if typ != "JWT" && typ != "JOSE" { - return Err(wrap_error("token header type not equal to either JWT or JOSE")); + return Err(wrap_error( + "token header type not equal to either JWT or JOSE", + )); } } - let claims: Map = - serde_json::from_slice(&payload_bytes).map_err(|_| wrap_error("unable to parse JWT token"))?; + let claims: Map = serde_json::from_slice(&payload_bytes) + .map_err(|_| wrap_error("unable to parse JWT token"))?; let subject = claims .get("sub") .and_then(|v| v.as_str()) @@ -182,7 +195,12 @@ where .map_err(|err| wrap_error(format!("token has an invalid subject claim: {}", err)))?; let trust_domain = id.trust_domain(); - verify(&header, &format!("{}.{}", parts[0], parts[1]), &signature, &trust_domain)?; + verify( + &header, + &format!("{}.{}", parts[0], parts[1]), + &signature, + &trust_domain, + )?; validate_claims(expiry, &aud, audience)?; @@ -233,9 +251,15 @@ fn is_allowed_alg(alg: &str) -> bool { fn verify_signature(alg: &str, key: &JwtKey, signing_input: &str, signature: &[u8]) -> Result<()> { match (alg, key) { - ("RS256", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::(n, e, signing_input, signature), - ("RS384", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::(n, e, signing_input, signature), - ("RS512", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::(n, e, signing_input, signature), + ("RS256", JwtKey::Rsa { n, e }) => { + verify_rsa_pkcs1::(n, e, signing_input, signature) + } + ("RS384", JwtKey::Rsa { n, e }) => { + verify_rsa_pkcs1::(n, e, signing_input, signature) + } + ("RS512", JwtKey::Rsa { n, e }) => { + verify_rsa_pkcs1::(n, e, signing_input, signature) + } ("PS256", JwtKey::Rsa { n, e }) => verify_rsa_pss::(n, e, signing_input, signature), ("PS384", JwtKey::Rsa { n, e }) => verify_rsa_pss::(n, e, signing_input, signature), ("PS512", JwtKey::Rsa { n, e }) => verify_rsa_pss::(n, e, signing_input, signature), @@ -246,12 +270,7 @@ fn verify_signature(alg: &str, key: &JwtKey, signing_input: &str, signature: &[u } } -fn verify_rsa_pkcs1( - n: &[u8], - e: &[u8], - signing_input: &str, - signature: &[u8], -) -> Result<()> +fn verify_rsa_pkcs1(n: &[u8], e: &[u8], signing_input: &str, signature: &[u8]) -> Result<()> where D: Digest + AssociatedOid, { @@ -264,12 +283,7 @@ where Ok(()) } -fn verify_rsa_pss( - n: &[u8], - e: &[u8], - signing_input: &str, - signature: &[u8], -) -> Result<()> +fn verify_rsa_pss(n: &[u8], e: &[u8], signing_input: &str, signature: &[u8]) -> Result<()> where D: Digest + FixedOutputReset, { @@ -288,45 +302,30 @@ fn rsa_public_key(n: &[u8], e: &[u8]) -> Result { RsaPublicKey::new(n, e).map_err(|_| wrap_error("invalid RSA key")) } -fn verify_ecdsa_p256( - x: &[u8], - y: &[u8], - signing_input: &str, - signature: &[u8], -) -> Result<()> { +fn verify_ecdsa_p256(x: &[u8], y: &[u8], signing_input: &str, signature: &[u8]) -> Result<()> { let public_key = ecdsa_public_key(x, y)?; - let key = P256VerifyingKey::from_sec1_bytes(&public_key) - .map_err(|_| wrap_error("invalid EC key"))?; + let key = + P256VerifyingKey::from_sec1_bytes(&public_key).map_err(|_| wrap_error("invalid EC key"))?; let sig = P256Signature::from_slice(signature).map_err(|_| wrap_error("invalid signature"))?; key.verify(signing_input.as_bytes(), &sig) .map_err(|_| wrap_error("invalid signature"))?; Ok(()) } -fn verify_ecdsa_p384( - x: &[u8], - y: &[u8], - signing_input: &str, - signature: &[u8], -) -> Result<()> { +fn verify_ecdsa_p384(x: &[u8], y: &[u8], signing_input: &str, signature: &[u8]) -> Result<()> { let public_key = ecdsa_public_key(x, y)?; - let key = P384VerifyingKey::from_sec1_bytes(&public_key) - .map_err(|_| wrap_error("invalid EC key"))?; + let key = + P384VerifyingKey::from_sec1_bytes(&public_key).map_err(|_| wrap_error("invalid EC key"))?; let sig = P384Signature::from_slice(signature).map_err(|_| wrap_error("invalid signature"))?; key.verify(signing_input.as_bytes(), &sig) .map_err(|_| wrap_error("invalid signature"))?; Ok(()) } -fn verify_ecdsa_p521( - x: &[u8], - y: &[u8], - signing_input: &str, - signature: &[u8], -) -> Result<()> { +fn verify_ecdsa_p521(x: &[u8], y: &[u8], signing_input: &str, signature: &[u8]) -> Result<()> { let public_key = ecdsa_public_key(x, y)?; - let key = P521VerifyingKey::from_sec1_bytes(&public_key) - .map_err(|_| wrap_error("invalid EC key"))?; + let key = + P521VerifyingKey::from_sec1_bytes(&public_key).map_err(|_| wrap_error("invalid EC key"))?; let sig = P521Signature::from_slice(signature).map_err(|_| wrap_error("invalid signature"))?; key.verify(signing_input.as_bytes(), &sig) .map_err(|_| wrap_error("invalid signature"))?; diff --git a/src/svid/x509svid.rs b/src/svid/x509svid.rs index 9ffe2b7..4d766cf 100644 --- a/src/svid/x509svid.rs +++ b/src/svid/x509svid.rs @@ -54,17 +54,15 @@ impl SVID { pub fn load(cert_file: &str, key_file: &str) -> Result { let cert_bytes = fs::read(cert_file) .map_err(|err| wrap_error(format!("cannot read certificate file: {}", err)))?; - let key_bytes = - fs::read(key_file).map_err(|err| wrap_error(format!("cannot read key file: {}", err)))?; + let key_bytes = fs::read(key_file) + .map_err(|err| wrap_error(format!("cannot read key file: {}", err)))?; SVID::parse(&cert_bytes, &key_bytes) } /// Parses an X.509 SVID from PEM encoded bytes. pub fn parse(cert_bytes: &[u8], key_bytes: &[u8]) -> Result { - let certs = - pemutil::parse_certificates(cert_bytes).map_err(|err| { - wrap_error(format!("cannot parse PEM encoded certificate: {}", err)) - })?; + let certs = pemutil::parse_certificates(cert_bytes) + .map_err(|err| wrap_error(format!("cannot parse PEM encoded certificate: {}", err)))?; let key = parse_private_key_pem(key_bytes) .map_err(|err| wrap_error(format!("cannot parse PEM encoded private key: {}", err)))?; new_svid(certs, key) @@ -100,7 +98,9 @@ impl SVID { return Err(wrap_error("no certificates to marshal")); } if self.private_key.is_empty() { - return Err(wrap_error("cannot marshal private key: missing private key")); + return Err(wrap_error( + "cannot marshal private key: missing private key", + )); } let mut certs = Vec::new(); for cert in &self.certificates { @@ -215,8 +215,7 @@ fn validate_certificates(certs: &[Vec]) -> Result { "leaf certificate must not have CA flag set to true".to_string(), )); } - validate_leaf_key_usage(&leaf) - .map_err(|err| Error(err.to_string()))?; + validate_leaf_key_usage(&leaf).map_err(|err| Error(err.to_string()))?; for cert_bytes in certs.iter().skip(1) { let cert = parse_certificate(cert_bytes)?; @@ -283,7 +282,9 @@ fn id_from_cert(cert: &X509Certificate<'_>) -> Result { return Err(Error("certificate contains no URI SAN".to_string())); } if uris.len() > 1 { - return Err(Error("certificate contains more than one URI SAN".to_string())); + return Err(Error( + "certificate contains more than one URI SAN".to_string(), + )); } ID::from_string(uris.remove(0)).map_err(|err| Error(err.to_string())) } @@ -298,8 +299,8 @@ fn parse_raw_certificates(bytes: &[u8]) -> std::result::Result>, Str let mut remaining = bytes; let mut certs = Vec::new(); while !remaining.is_empty() { - let (rest, _cert) = x509_parser::parse_x509_certificate(remaining) - .map_err(|err| err.to_string())?; + let (rest, _cert) = + x509_parser::parse_x509_certificate(remaining).map_err(|err| err.to_string())?; let consumed = remaining .len() .checked_sub(rest.len()) @@ -347,7 +348,9 @@ fn validate_private_key(key_bytes: &[u8], cert_bytes: &[u8]) -> Result<()> { if n == modulus && e == exponent { return Ok(()); } - return Err(Error("leaf certificate does not match private key".to_string())); + return Err(Error( + "leaf certificate does not match private key".to_string(), + )); } } @@ -358,7 +361,9 @@ fn validate_private_key(key_bytes: &[u8], cert_bytes: &[u8]) -> Result<()> { if bytes == ec.data() { return Ok(()); } - return Err(Error("leaf certificate does not match private key".to_string())); + return Err(Error( + "leaf certificate does not match private key".to_string(), + )); } } @@ -369,7 +374,9 @@ fn validate_private_key(key_bytes: &[u8], cert_bytes: &[u8]) -> Result<()> { if bytes == ec.data() { return Ok(()); } - return Err(Error("leaf certificate does not match private key".to_string())); + return Err(Error( + "leaf certificate does not match private key".to_string(), + )); } } @@ -380,7 +387,9 @@ fn validate_private_key(key_bytes: &[u8], cert_bytes: &[u8]) -> Result<()> { if bytes == ec.data() { return Ok(()); } - return Err(Error("leaf certificate does not match private key".to_string())); + return Err(Error( + "leaf certificate does not match private key".to_string(), + )); } } @@ -404,7 +413,8 @@ fn verify_chain( let now = now .duration_since(SystemTime::UNIX_EPOCH) .map_err(|_| "invalid time".to_string())?; - let now = ASN1Time::from_timestamp(now.as_secs() as i64).map_err(|_| "invalid time".to_string())?; + let now = + ASN1Time::from_timestamp(now.as_secs() as i64).map_err(|_| "invalid time".to_string())?; let parsed = certs .iter() diff --git a/src/workloadapi/addr.rs b/src/workloadapi/addr.rs index 4d13b36..ca39a15 100644 --- a/src/workloadapi/addr.rs +++ b/src/workloadapi/addr.rs @@ -1,6 +1,6 @@ use crate::workloadapi::{wrap_error, Result}; -use std::net::IpAddr; use std::env; +use std::net::IpAddr; use url::Url; #[allow(non_upper_case_globals)] @@ -72,12 +72,12 @@ fn parse_target_from_url(url: &Url) -> Result { let host = url .host_str() .ok_or_else(|| wrap_error("workload endpoint tcp socket URI must include a host"))?; - let ip: IpAddr = host - .parse() - .map_err(|_| wrap_error("workload endpoint tcp socket URI host component must be an IP:port"))?; - let port = url - .port() - .ok_or_else(|| wrap_error("workload endpoint tcp socket URI host component must include a port"))?; + let ip: IpAddr = host.parse().map_err(|_| { + wrap_error("workload endpoint tcp socket URI host component must be an IP:port") + })?; + let port = url.port().ok_or_else(|| { + wrap_error("workload endpoint tcp socket URI host component must include a port") + })?; return Ok(format!("{}:{}", ip, port)); } diff --git a/src/workloadapi/backoff.rs b/src/workloadapi/backoff.rs index 8c49df0..982dd7c 100644 --- a/src/workloadapi/backoff.rs +++ b/src/workloadapi/backoff.rs @@ -1,14 +1,35 @@ +//! Backoff strategies used for Workload API watch retries. +//! +//! The Workload API watch RPCs are long-lived streams. When a stream errors or +//! disconnects, the client retries with a delay determined by a [`Backoff`]. +//! Implementations are expected to be cheap and deterministic (no sleeping) and +//! to be reset after a successful receive. + use std::time::Duration; +/// Creates independent [`Backoff`] instances with the same policy. +/// +/// A strategy is stored in client configuration and used to construct a fresh +/// backoff state for each watch loop. pub trait BackoffStrategy: Send + Sync { + /// Returns a new backoff state machine. fn new_backoff(&self) -> Box; } +/// A stateful backoff timer. +/// +/// Implementations return the next delay to wait before retrying, and can be +/// reset after a successful attempt. pub trait Backoff: Send { + /// Returns the delay to wait before the next retry. fn next(&mut self) -> Duration; + /// Resets the backoff to its initial state (e.g. after a successful call). fn reset(&mut self); } +/// A simple linear backoff strategy (`delay = initial * n`, capped). +/// +/// This is the crate default for Workload API watches. #[derive(Default)] pub struct LinearBackoffStrategy; @@ -25,6 +46,13 @@ pub struct LinearBackoff { } impl LinearBackoff { + /// Creates a linear backoff with defaults: + /// + /// - **initial**: 1s + /// - **max**: 30s + /// + /// The first call to [`Backoff::next`] returns 1s, then 2s, etc., capped at + /// 30s. pub fn new() -> Self { Self { initial_delay: Duration::from_secs(1), diff --git a/src/workloadapi/bundlesource.rs b/src/workloadapi/bundlesource.rs index 010c18c..fb1650c 100644 --- a/src/workloadapi/bundlesource.rs +++ b/src/workloadapi/bundlesource.rs @@ -47,7 +47,8 @@ impl BundleSource { } }); - let watcher = Watcher::new(ctx, config.watcher, Some(x509_handler), Some(jwt_handler)).await?; + let watcher = + Watcher::new(ctx, config.watcher, Some(x509_handler), Some(jwt_handler)).await?; Ok(BundleSource { watcher, x509_authorities, @@ -61,7 +62,10 @@ impl BundleSource { self.watcher.close().await } - pub fn get_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { self.check_closed()?; let x509 = self .x509_authorities @@ -90,7 +94,10 @@ impl BundleSource { Ok(bundle) } - pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_x509_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { self.check_closed()?; let x509 = self .x509_authorities @@ -103,10 +110,16 @@ impl BundleSource { trust_domain )) })?; - Ok(x509bundle::Bundle::from_x509_authorities(trust_domain, &x509)) + Ok(x509bundle::Bundle::from_x509_authorities( + trust_domain, + &x509, + )) } - pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result { + pub fn get_jwt_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> Result { self.check_closed()?; let jwt = self .jwt_authorities @@ -132,28 +145,39 @@ impl BundleSource { fn check_closed(&self) -> Result<()> { if self.closed.load(std::sync::atomic::Ordering::SeqCst) { - return Err(crate::workloadapi::Error::new("bundlesource: source is closed")); + return Err(crate::workloadapi::Error::new( + "bundlesource: source is closed", + )); } Ok(()) } } impl spiffebundle::Source for BundleSource { - fn get_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> spiffebundle::Result { + fn get_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> spiffebundle::Result { self.get_bundle_for_trust_domain(trust_domain) .map_err(|err| spiffebundle::Error::new(err.to_string())) } } impl x509bundle::Source for BundleSource { - fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> x509bundle::Result { + fn get_x509_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> x509bundle::Result { self.get_x509_bundle_for_trust_domain(trust_domain) .map_err(|err| x509bundle::Error::new(err.to_string())) } } impl jwtbundle::Source for BundleSource { - fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> jwtbundle::Result { + fn get_jwt_bundle_for_trust_domain( + &self, + trust_domain: TrustDomain, + ) -> jwtbundle::Result { self.get_jwt_bundle_for_trust_domain(trust_domain) .map_err(|err| jwtbundle::Error::new(err.to_string())) } diff --git a/src/workloadapi/client.rs b/src/workloadapi/client.rs index 0cb096c..2d712ab 100644 --- a/src/workloadapi/client.rs +++ b/src/workloadapi/client.rs @@ -7,15 +7,15 @@ use crate::workloadapi::proto::{ JwtBundlesRequest, JwtBundlesResponse, JwtsvidRequest, JwtsvidResponse, ValidateJwtsvidRequest, X509BundlesRequest, X509BundlesResponse, X509svidRequest, X509svidResponse, }; -use crate::workloadapi::{target_from_address, wrap_error, Backoff, Error, Result}; use crate::workloadapi::{option::ClientConfig, Context}; -use tower::service_fn; +use crate::workloadapi::{target_from_address, wrap_error, Backoff, Error, Result}; use std::collections::HashSet; use std::sync::Arc; use tokio::net::UnixStream; use tonic::metadata::MetadataValue; use tonic::transport::{Channel, Endpoint}; use tonic::{Code, Request, Status}; +use tower::service_fn; /// A client for the SPIFFE Workload API. /// @@ -39,9 +39,8 @@ impl Client { let address = match config.address.clone() { Some(addr) => addr, - None => crate::workloadapi::get_default_address().ok_or_else(|| { - wrap_error("workload endpoint socket address is not configured") - })?, + None => crate::workloadapi::get_default_address() + .ok_or_else(|| wrap_error("workload endpoint socket address is not configured"))?, }; let target = target_from_address(&address)?; let channel = connect_channel(&target, &config.dial_options).await?; @@ -58,8 +57,12 @@ impl Client { pub async fn fetch_x509_svid(&self, ctx: &Context) -> Result { let mut client = self.inner.clone(); let request = with_header(Request::new(X509svidRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner(); - let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let mut stream = cancelable(ctx, client.fetch_x509svid(request)) + .await? + .into_inner(); + let response = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; let svids = parse_x509_svids(response, true)?; Ok(svids .into_iter() @@ -71,8 +74,12 @@ impl Client { pub async fn fetch_x509_svids(&self, ctx: &Context) -> Result> { let mut client = self.inner.clone(); let request = with_header(Request::new(X509svidRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner(); - let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let mut stream = cancelable(ctx, client.fetch_x509svid(request)) + .await? + .into_inner(); + let response = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; parse_x509_svids(response, false) } @@ -80,16 +87,27 @@ impl Client { pub async fn fetch_x509_bundles(&self, ctx: &Context) -> Result { let mut client = self.inner.clone(); let request = with_header(Request::new(X509BundlesRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner(); - let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)) + .await? + .into_inner(); + let resp = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; parse_x509_bundles_response(resp) } /// Watches for X.509 bundle updates from the Workload API. - pub async fn watch_x509_bundles(&self, ctx: &Context, watcher: Arc) -> Result<()> { + pub async fn watch_x509_bundles( + &self, + ctx: &Context, + watcher: Arc, + ) -> Result<()> { let mut backoff = self.config.backoff_strategy.new_backoff(); loop { - if let Err(err) = self.watch_x509_bundles_once(ctx, watcher.clone(), &mut *backoff).await { + if let Err(err) = self + .watch_x509_bundles_once(ctx, watcher.clone(), &mut *backoff) + .await + { watcher.on_x509_bundles_watch_error(err.clone()); if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await { return Err(err); @@ -99,11 +117,18 @@ impl Client { } /// Fetches the X.509 context (SVIDs and bundles) from the Workload API. - pub async fn fetch_x509_context(&self, ctx: &Context) -> Result { + pub async fn fetch_x509_context( + &self, + ctx: &Context, + ) -> Result { let mut client = self.inner.clone(); let request = with_header(Request::new(X509svidRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner(); - let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let mut stream = cancelable(ctx, client.fetch_x509svid(request)) + .await? + .into_inner(); + let response = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; parse_x509_context(response) } @@ -115,7 +140,10 @@ impl Client { ) -> Result<()> { let mut backoff = self.config.backoff_strategy.new_backoff(); loop { - if let Err(err) = self.watch_x509_context_once(ctx, watcher.clone(), &mut *backoff).await { + if let Err(err) = self + .watch_x509_context_once(ctx, watcher.clone(), &mut *backoff) + .await + { watcher.on_x509_context_watch_error(err.clone()); if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await { return Err(err); @@ -125,7 +153,11 @@ impl Client { } /// Fetches a single JWT SVID from the Workload API. - pub async fn fetch_jwt_svid(&self, ctx: &Context, params: jwtsvid::Params) -> Result { + pub async fn fetch_jwt_svid( + &self, + ctx: &Context, + params: jwtsvid::Params, + ) -> Result { let mut client = self.inner.clone(); let audience = params.audience_list(); let request = with_header(Request::new(JwtsvidRequest { @@ -141,7 +173,11 @@ impl Client { } /// Fetches multiple JWT SVIDs from the Workload API. - pub async fn fetch_jwt_svids(&self, ctx: &Context, params: jwtsvid::Params) -> Result> { + pub async fn fetch_jwt_svids( + &self, + ctx: &Context, + params: jwtsvid::Params, + ) -> Result> { let mut client = self.inner.clone(); let audience = params.audience_list(); let request = with_header(Request::new(JwtsvidRequest { @@ -156,16 +192,27 @@ impl Client { pub async fn fetch_jwt_bundles(&self, ctx: &Context) -> Result { let mut client = self.inner.clone(); let request = with_header(Request::new(JwtBundlesRequest {})); - let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner(); - let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)) + .await? + .into_inner(); + let resp = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; parse_jwt_bundles(resp) } /// Watches for JWT bundle updates from the Workload API. - pub async fn watch_jwt_bundles(&self, ctx: &Context, watcher: Arc) -> Result<()> { + pub async fn watch_jwt_bundles( + &self, + ctx: &Context, + watcher: Arc, + ) -> Result<()> { let mut backoff = self.config.backoff_strategy.new_backoff(); loop { - if let Err(err) = self.watch_jwt_bundles_once(ctx, watcher.clone(), &mut *backoff).await { + if let Err(err) = self + .watch_jwt_bundles_once(ctx, watcher.clone(), &mut *backoff) + .await + { watcher.on_jwt_bundles_watch_error(err.clone()); if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await { return Err(err); @@ -175,7 +222,12 @@ impl Client { } /// Validates a JWT SVID token using the Workload API. - pub async fn validate_jwt_svid(&self, ctx: &Context, token: &str, audience: &str) -> Result { + pub async fn validate_jwt_svid( + &self, + ctx: &Context, + token: &str, + audience: &str, + ) -> Result { let mut client = self.inner.clone(); let request = with_header(Request::new(ValidateJwtsvidRequest { svid: token.to_string(), @@ -191,7 +243,10 @@ impl Client { err: Error, backoff: &mut dyn Backoff, ) -> Option { - let status = err.status().cloned().unwrap_or_else(|| Status::unknown(err.to_string())); + let status = err + .status() + .cloned() + .unwrap_or_else(|| Status::unknown(err.to_string())); match status.code() { Code::Cancelled => return Some(err), Code::InvalidArgument => { @@ -225,10 +280,16 @@ impl Client { ) -> Result<()> { let mut client = self.inner.clone(); let request = with_header(Request::new(X509svidRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner(); - self.config.log.debugf(format_args!("Watching X.509 contexts")); + let mut stream = cancelable(ctx, client.fetch_x509svid(request)) + .await? + .into_inner(); + self.config + .log + .debugf(format_args!("Watching X.509 contexts")); loop { - let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let resp = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; backoff.reset(); match parse_x509_context(resp) { Ok(context) => watcher.on_x509_context_update(context), @@ -250,10 +311,14 @@ impl Client { ) -> Result<()> { let mut client = self.inner.clone(); let request = with_header(Request::new(JwtBundlesRequest {})); - let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner(); + let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)) + .await? + .into_inner(); self.config.log.debugf(format_args!("Watching JWT bundles")); loop { - let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let resp = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; backoff.reset(); match parse_jwt_bundles(resp) { Ok(bundles) => watcher.on_jwt_bundles_update(bundles), @@ -275,17 +340,24 @@ impl Client { ) -> Result<()> { let mut client = self.inner.clone(); let request = with_header(Request::new(X509BundlesRequest {})); - let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner(); - self.config.log.debugf(format_args!("Watching X.509 bundles")); + let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)) + .await? + .into_inner(); + self.config + .log + .debugf(format_args!("Watching X.509 bundles")); loop { - let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?; + let resp = cancelable(ctx, stream.message()) + .await? + .ok_or_else(|| wrap_error("stream closed"))?; backoff.reset(); match parse_x509_bundles_response(resp) { Ok(bundles) => watcher.on_x509_bundles_update(bundles), Err(err) => { - self.config - .log - .errorf(format_args!("Failed to parse X.509 bundle response: {}", err)); + self.config.log.errorf(format_args!( + "Failed to parse X.509 bundle response: {}", + err + )); watcher.on_x509_bundles_watch_error(err); } } @@ -300,7 +372,10 @@ fn with_header(mut request: Request) -> Request { request } -async fn connect_channel(target: &str, options: &[Arc]) -> Result { +async fn connect_channel( + target: &str, + options: &[Arc], +) -> Result { if let Ok(url) = url::Url::parse(target) { if url.scheme() == "unix" { let path = unix_path_from_url(&url)?; @@ -331,7 +406,9 @@ async fn connect_channel(target: &str, options: &[Arc Result { if url.cannot_be_a_base() { - return Err(wrap_error("workload endpoint unix socket URI must not be opaque")); + return Err(wrap_error( + "workload endpoint unix socket URI must not be opaque", + )); } let host = url.host_str().unwrap_or(""); let raw_path = if host.is_empty() { @@ -342,7 +419,9 @@ fn unix_path_from_url(url: &url::Url) -> Result { format!("/{host}{}", url.path()) }; if raw_path.is_empty() || raw_path == "/" { - return Err(wrap_error("workload endpoint unix socket URI must include a path")); + return Err(wrap_error( + "workload endpoint unix socket URI must include a path", + )); } Ok(std::path::PathBuf::from(raw_path)) } @@ -392,7 +471,8 @@ fn parse_x509_bundles(resp: X509svidResponse) -> Result { let td = ID::from_string(&svid.spiffe_id) .map_err(|err| wrap_error(err))? .trust_domain(); - bundles.push(x509bundle::Bundle::parse_raw(td, &svid.bundle).map_err(|err| wrap_error(err))?); + bundles + .push(x509bundle::Bundle::parse_raw(td, &svid.bundle).map_err(|err| wrap_error(err))?); } for (td_id, bundle) in resp.federated_bundles { let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?; @@ -410,7 +490,11 @@ fn parse_x509_bundles_response(resp: X509BundlesResponse) -> Result Result> { +fn parse_jwt_svids( + resp: JwtsvidResponse, + audience: &[String], + first_only: bool, +) -> Result> { let mut svids = resp.svids; if svids.is_empty() { return Err(wrap_error("there were no SVIDs in the response")); @@ -425,7 +509,8 @@ fn parse_jwt_svids(resp: JwtsvidResponse, audience: &[String], first_only: bool) if !svid.hint.is_empty() && !seen.insert(svid.hint.clone()) { continue; } - let mut parsed = jwtsvid::parse_insecure(&svid.svid, audience).map_err(|err| wrap_error(err))?; + let mut parsed = + jwtsvid::parse_insecure(&svid.svid, audience).map_err(|err| wrap_error(err))?; parsed.hint = svid.hint; out.push(parsed); } diff --git a/src/workloadapi/convenience.rs b/src/workloadapi/convenience.rs index 255b235..86956d6 100644 --- a/src/workloadapi/convenience.rs +++ b/src/workloadapi/convenience.rs @@ -60,7 +60,11 @@ where result } -pub async fn fetch_jwt_svid(ctx: &Context, params: jwtsvid::Params, options: I) -> Result +pub async fn fetch_jwt_svid( + ctx: &Context, + params: jwtsvid::Params, + options: I, +) -> Result where I: IntoIterator>, { @@ -70,7 +74,11 @@ where result } -pub async fn fetch_jwt_svids(ctx: &Context, params: jwtsvid::Params, options: I) -> Result> +pub async fn fetch_jwt_svids( + ctx: &Context, + params: jwtsvid::Params, + options: I, +) -> Result> where I: IntoIterator>, { @@ -118,7 +126,12 @@ where result } -pub async fn validate_jwt_svid(ctx: &Context, token: &str, audience: &str, options: I) -> Result +pub async fn validate_jwt_svid( + ctx: &Context, + token: &str, + audience: &str, + options: I, +) -> Result where I: IntoIterator>, { diff --git a/src/workloadapi/jwtsource.rs b/src/workloadapi/jwtsource.rs index 7d1e948..2615a05 100644 --- a/src/workloadapi/jwtsource.rs +++ b/src/workloadapi/jwtsource.rs @@ -82,7 +82,11 @@ impl JWTSource { self.bundles .read() .ok() - .and_then(|guard| guard.as_ref().and_then(|b| b.get_jwt_bundle_for_trust_domain(trust_domain).ok())) + .and_then(|guard| { + guard + .as_ref() + .and_then(|b| b.get_jwt_bundle_for_trust_domain(trust_domain).ok()) + }) .ok_or_else(|| crate::workloadapi::Error::new("jwtsource: no JWT bundle found")) } @@ -98,7 +102,9 @@ impl JWTSource { fn check_closed(&self) -> Result<()> { if self.closed.load(std::sync::atomic::Ordering::SeqCst) { - return Err(crate::workloadapi::Error::new("jwtsource: source is closed")); + return Err(crate::workloadapi::Error::new( + "jwtsource: source is closed", + )); } Ok(()) } diff --git a/src/workloadapi/option.rs b/src/workloadapi/option.rs index cfaab9d..d04ff5e 100644 --- a/src/workloadapi/option.rs +++ b/src/workloadapi/option.rs @@ -36,9 +36,7 @@ pub trait DialOption: Send + Sync { /// Sets the address of the Workload API endpoint. pub fn with_addr(addr: impl Into) -> Arc { - Arc::new(WithAddr { - addr: addr.into(), - }) + Arc::new(WithAddr { addr: addr.into() }) } /// Sets the gRPC dial options. @@ -57,23 +55,38 @@ pub fn with_backoff_strategy(strategy: Arc) -> Arc } /// Sets an existing Workload API client to be used by the source. +/// +/// Use this when you want to share a single [`Client`] across multiple sources +/// (e.g. X.509 and JWT), or when you manage the client lifecycle yourself. pub fn with_client(client: Arc) -> Arc { Arc::new(WithClient { client }) } /// Sets the options for the Workload API client created by the source. +/// +/// This is used when the source is responsible for creating the underlying +/// [`Client`]. If you already have a client instance, prefer [`with_client`]. pub fn with_client_options(options: Vec>) -> Arc { Arc::new(WithClientOptions { options }) } /// Sets a custom picker for X.509 SVIDs. +/// +/// The Workload API may return multiple X.509-SVIDs. The picker decides which +/// one is returned by higher-level helpers that need a single SVID (e.g. +/// SPIFFE-TLS server configs). pub fn with_default_x509_svid_picker( - picker: Arc crate::svid::x509svid::SVID + Send + Sync>, + picker: Arc< + dyn Fn(&[crate::svid::x509svid::SVID]) -> crate::svid::x509svid::SVID + Send + Sync, + >, ) -> Arc { Arc::new(WithDefaultX509SVIDPicker { picker }) } /// Sets a custom picker for JWT SVIDs. +/// +/// The Workload API may return multiple JWT-SVIDs. The picker decides which one +/// is returned by helpers that expect a single SVID. pub fn with_default_jwt_svid_picker( picker: Arc crate::svid::jwtsvid::SVID + Send + Sync>, ) -> Arc { @@ -131,8 +144,9 @@ impl Default for X509SourceConfig { pub struct JWTSourceConfig { pub watcher: WatcherConfig, - pub picker: - Option crate::svid::jwtsvid::SVID + Send + Sync>>, + pub picker: Option< + Arc crate::svid::jwtsvid::SVID + Send + Sync>, + >, } impl Default for JWTSourceConfig { @@ -233,7 +247,8 @@ impl SourceOption for WithClientOptions { } struct WithDefaultX509SVIDPicker { - picker: Arc crate::svid::x509svid::SVID + Send + Sync>, + picker: + Arc crate::svid::x509svid::SVID + Send + Sync>, } impl X509SourceOption for WithDefaultX509SVIDPicker { diff --git a/src/workloadapi/watcher.rs b/src/workloadapi/watcher.rs index fb42e26..2d1b824 100644 --- a/src/workloadapi/watcher.rs +++ b/src/workloadapi/watcher.rs @@ -1,10 +1,19 @@ -use crate::workloadapi::{JWTBundleWatcher, Result, X509Context, X509ContextWatcher}; use crate::workloadapi::{option::WatcherConfig, Client, Context}; +use crate::workloadapi::{JWTBundleWatcher, Result, X509Context, X509ContextWatcher}; use std::sync::{Arc, Mutex}; use tokio::sync::{oneshot, watch}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +/// High-level helper that spawns Workload API watch tasks and exposes an +/// "updated" signal. +/// +/// A `Watcher` can: +/// - create or reuse a [`Client`] +/// - spawn background tasks that watch X.509 contexts and/or JWT bundles +/// - provide a [`watch::Receiver`] that ticks whenever an update is observed +/// +/// Call [`Watcher::close`] to stop the background tasks. pub struct Watcher { updated_tx: watch::Sender, updated_rx: watch::Receiver, @@ -15,6 +24,13 @@ pub struct Watcher { } impl Watcher { + /// Creates a watcher and starts watch tasks immediately. + /// + /// - If `config.client` is `None`, a new Workload API [`Client`] is created + /// using `config.client_options`. + /// - If a handler is `Some`, the corresponding watch task is spawned. + /// - This method waits for each spawned watch to deliver its first update + /// before returning successfully. pub async fn new( ctx: &Context, config: WatcherConfig, @@ -43,6 +59,7 @@ impl Watcher { Ok(watcher) } + /// Cancels all watch tasks and (if owned) closes the underlying client. pub async fn close(&self) -> Result<()> { self.cancel.cancel(); if let Ok(mut tasks) = self.tasks.lock() { @@ -56,6 +73,10 @@ impl Watcher { Ok(()) } + /// Waits until at least one update is observed. + /// + /// This returns after the internal `updated` counter changes. If you need a + /// stream of updates, use [`Watcher::updated`]. pub async fn wait_until_updated(&self, ctx: &Context) -> Result<()> { let mut rx = self.updated_rx.clone(); tokio::select! { @@ -64,11 +85,14 @@ impl Watcher { } } + /// Returns a receiver that increments on each observed update. + /// + /// Consumers can call `changed().await` or read the counter value to detect + /// updates. pub fn updated(&self) -> watch::Receiver { self.updated_rx.clone() } - async fn spawn_watchers( &self, ctx: &Context, diff --git a/src/workloadapi/x509source.rs b/src/workloadapi/x509source.rs index 71d806c..b405a9e 100644 --- a/src/workloadapi/x509source.rs +++ b/src/workloadapi/x509source.rs @@ -81,7 +81,11 @@ impl X509Source { self.bundles .read() .ok() - .and_then(|guard| guard.as_ref().and_then(|b| b.get_x509_bundle_for_trust_domain(trust_domain).ok())) + .and_then(|guard| { + guard + .as_ref() + .and_then(|b| b.get_x509_bundle_for_trust_domain(trust_domain).ok()) + }) .ok_or_else(|| crate::workloadapi::Error::new("x509source: no X.509 bundle found")) } @@ -97,7 +101,9 @@ impl X509Source { fn check_closed(&self) -> Result<()> { if self.closed.load(std::sync::atomic::Ordering::SeqCst) { - return Err(crate::workloadapi::Error::new("x509source: source is closed")); + return Err(crate::workloadapi::Error::new( + "x509source: source is closed", + )); } Ok(()) } diff --git a/tests/compat_spiffebundle_go.rs b/tests/compat_spiffebundle_go.rs index fbd9df6..cf965f1 100644 --- a/tests/compat_spiffebundle_go.rs +++ b/tests/compat_spiffebundle_go.rs @@ -17,8 +17,11 @@ fn spiffebundle_marshal_matches_go() { return; } - let temp_dir = std::env::temp_dir() - .join(format!("spiffe_rs_compat_{}_{}", std::process::id(), chrono_stamp())); + let temp_dir = std::env::temp_dir().join(format!( + "spiffe_rs_compat_{}_{}", + std::process::id(), + chrono_stamp() + )); fs::create_dir_all(&temp_dir).expect("create temp dir"); let go_mod = format!( @@ -28,9 +31,7 @@ fn spiffebundle_marshal_matches_go() { fs::write(temp_dir.join("go.mod"), go_mod).expect("write go.mod"); let input = PathBuf::from("tests/testdata/spiffebundle/spiffebundle_valid_1.json"); - let input_abs = input - .canonicalize() - .expect("canonicalize test bundle path"); + let input_abs = input.canonicalize().expect("canonicalize test bundle path"); let trust_domain = "domain.test"; let main = format!( r#" diff --git a/tests/compat_spiffetls_go.rs b/tests/compat_spiffetls_go.rs index 54f6e1b..51a4045 100644 --- a/tests/compat_spiffetls_go.rs +++ b/tests/compat_spiffetls_go.rs @@ -17,13 +17,15 @@ async fn spiffetls_accepts_go_svid_tls_server() { return; } - let temp_dir = std::env::temp_dir() - .join(format!("spiffe_rs_tls_{}_{}", std::process::id(), chrono_stamp())); + let temp_dir = std::env::temp_dir().join(format!( + "spiffe_rs_tls_{}_{}", + std::process::id(), + chrono_stamp() + )); fs::create_dir_all(&temp_dir).expect("create temp dir"); let ca_path = temp_dir.join("ca.pem"); - fs::write(temp_dir.join("go.mod"), "module compat\n\ngo 1.20\n") - .expect("write go.mod"); + fs::write(temp_dir.join("go.mod"), "module compat\n\ngo 1.20\n").expect("write go.mod"); let main = format!( r#" @@ -140,9 +142,9 @@ func main() {{ } let bundle = load_ca_bundle(&ca_path); - let authorizer = spiffetls::tlsconfig::authorize_id( - spiffeid::require_from_string("spiffe://example.org/workload-1"), - ); + let authorizer = spiffetls::tlsconfig::authorize_id(spiffeid::require_from_string( + "spiffe://example.org/workload-1", + )); let ctx = workloadapi::background(); let server_name = rustls::ServerName::try_from("example.org").expect("server name"); let mode = spiffetls::tls_client_with_raw_config(authorizer, Arc::new(bundle)); @@ -168,8 +170,8 @@ fn load_ca_bundle(path: &Path) -> x509bundle::Bundle { if pem.tag() != "CERTIFICATE" { continue; } - let (_rest, cert) = x509_parser::parse_x509_certificate(pem.contents()) - .expect("parse cert"); + let (_rest, cert) = + x509_parser::parse_x509_certificate(pem.contents()).expect("parse cert"); if cert.is_ca() { bundle.add_x509_authority(pem.contents()); } diff --git a/tests/compat_workloadapi_go.rs b/tests/compat_workloadapi_go.rs index 485d218..13fcfa3 100644 --- a/tests/compat_workloadapi_go.rs +++ b/tests/compat_workloadapi_go.rs @@ -18,8 +18,11 @@ async fn workloadapi_fetches_from_go_server() { return; } - let temp_dir = std::env::temp_dir() - .join(format!("spiffe_rs_wl_{}_{}", std::process::id(), chrono_stamp())); + let temp_dir = std::env::temp_dir().join(format!( + "spiffe_rs_wl_{}_{}", + std::process::id(), + chrono_stamp() + )); fs::create_dir_all(&temp_dir).expect("create temp dir"); let go_mod = format!( @@ -151,9 +154,14 @@ func main() {{ let svid = client.fetch_x509_svid(&ctx).await.expect("fetch svid"); assert_eq!(svid.id.to_string(), "spiffe://example.org/workload-1"); - let bundles = client.fetch_x509_bundles(&ctx).await.expect("fetch bundles"); + let bundles = client + .fetch_x509_bundles(&ctx) + .await + .expect("fetch bundles"); let td = spiffeid::require_trust_domain_from_string("example.org"); - let bundle = bundles.get_x509_bundle_for_trust_domain(td).expect("bundle"); + let bundle = bundles + .get_x509_bundle_for_trust_domain(td) + .expect("bundle"); assert!(!bundle.x509_authorities().is_empty()); let _ = child.kill(); diff --git a/tests/federation_fetch_tests.rs b/tests/federation_fetch_tests.rs index 3ab6106..ed84bfa 100644 --- a/tests/federation_fetch_tests.rs +++ b/tests/federation_fetch_tests.rs @@ -30,8 +30,8 @@ fn start_test_server(body: Vec) -> (String, thread::JoinHandle<()>) { #[test] fn fetch_bundle_over_http() { let trust_domain = spiffeid::require_trust_domain_from_string("domain.test"); - let body = - fs::read("tests/testdata/spiffebundle/spiffebundle_valid_1.json").expect("read test bundle"); + let body = fs::read("tests/testdata/spiffebundle/spiffebundle_valid_1.json") + .expect("read test bundle"); let expected = spiffebundle::Bundle::parse(trust_domain.clone(), &body).expect("parse bundle"); let (url, handle) = start_test_server(body); @@ -48,7 +48,9 @@ fn fetch_bundle_option_conflict() { let bundle_source = Arc::new(x509bundle::Bundle::from_x509_authorities(trust_domain, &[])); let options: Vec> = vec![ Box::new(federation::with_spiffe_auth(bundle_source, id)), - Box::new(federation::with_web_pki_roots(rustls::RootCertStore::empty())), + Box::new(federation::with_web_pki_roots( + rustls::RootCertStore::empty(), + )), ]; let err = federation::fetch_bundle( diff --git a/tests/federation_watch_handler_tests.rs b/tests/federation_watch_handler_tests.rs index f6368ae..041d529 100644 --- a/tests/federation_watch_handler_tests.rs +++ b/tests/federation_watch_handler_tests.rs @@ -52,7 +52,13 @@ fn start_sequence_server(bodies: Vec>) -> (String, std::thread::JoinHand let body = queue .lock() .ok() - .and_then(|mut q| if !q.is_empty() { Some(q.remove(0)) } else { None }) + .and_then(|mut q| { + if !q.is_empty() { + Some(q.remove(0)) + } else { + None + } + }) .unwrap_or_default(); let response = format!( "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", @@ -95,8 +101,8 @@ async fn handler_serves_bundle() { let body = fs::read("tests/testdata/spiffebundle/spiffebundle_valid_1.json").expect("bundle"); let bundle = spiffebundle::Bundle::parse(trust_domain.clone(), &body).expect("parse bundle"); let source = spiffebundle::Set::new(&[bundle.clone_bundle()]); - let mut handler = federation::new_handler(trust_domain, Arc::new(source), Vec::new()) - .expect("handler"); + let mut handler = + federation::new_handler(trust_domain, Arc::new(source), Vec::new()).expect("handler"); let response = Service::call(&mut handler, hyper::Request::new(hyper::Body::empty())) .await diff --git a/tests/jwtbundle_tests.rs b/tests/jwtbundle_tests.rs index 6f7f50e..12b2244 100644 --- a/tests/jwtbundle_tests.rs +++ b/tests/jwtbundle_tests.rs @@ -1,5 +1,5 @@ -use spiffe_rs::bundle::jwtbundle::{Bundle, Set}; use spiffe_rs::bundle::jwtbundle::JwtKey; +use spiffe_rs::bundle::jwtbundle::{Bundle, Set}; use spiffe_rs::spiffeid::require_trust_domain_from_string; use std::collections::HashMap; use std::fs; @@ -11,7 +11,8 @@ fn load_file(path: &str) -> Vec { #[test] fn bundle_load_read_parse() { let td = require_trust_domain_from_string("example.org"); - let bundle = Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_1.json").expect("load"); + let bundle = + Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_1.json").expect("load"); assert_eq!(bundle.jwt_authorities().len(), 1); let bytes = load_file("tests/testdata/jwtbundle/jwks_valid_2.json"); @@ -54,14 +55,16 @@ fn bundle_crud_and_equal() { bundle.remove_jwt_authority("key-1"); assert!(!bundle.has_jwt_authority("key-1")); - bundle.add_jwt_authority( - "key-1", - JwtKey::Ec { - crv: "P-256".to_string(), - x: vec![1], - y: vec![2], - }, - ).expect("add"); + bundle + .add_jwt_authority( + "key-1", + JwtKey::Ec { + crv: "P-256".to_string(), + x: vec![1], + y: vec![2], + }, + ) + .expect("add"); let cloned = bundle.clone_bundle(); assert!(bundle.equal(&cloned)); @@ -70,7 +73,8 @@ fn bundle_crud_and_equal() { #[test] fn bundle_marshal_roundtrip() { let td = require_trust_domain_from_string("example.org"); - let bundle = Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_2.json").expect("load"); + let bundle = + Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_2.json").expect("load"); let bytes = bundle.marshal().expect("marshal"); let parsed = Bundle::parse(td, &bytes).expect("parse"); assert!(bundle.equal(&parsed)); @@ -81,14 +85,19 @@ fn bundle_get_for_trust_domain() { let td = require_trust_domain_from_string("example.org"); let td2 = require_trust_domain_from_string("example-2.org"); let bundle = Bundle::new(td.clone()); - let ok = bundle.get_jwt_bundle_for_trust_domain(td.clone()).expect("bundle"); + let ok = bundle + .get_jwt_bundle_for_trust_domain(td.clone()) + .expect("bundle"); assert!(bundle.equal(&ok)); let err = bundle .get_jwt_bundle_for_trust_domain(td2) .unwrap_err() .to_string(); - assert_eq!(err, "jwtbundle: no JWT bundle for trust domain \"example-2.org\""); + assert_eq!( + err, + "jwtbundle: no JWT bundle for trust domain \"example-2.org\"" + ); } #[test] @@ -109,5 +118,8 @@ fn set_ops() { .get_jwt_bundle_for_trust_domain(require_trust_domain_from_string("missing.test")) .unwrap_err() .to_string(); - assert_eq!(err, "jwtbundle: no JWT bundle for trust domain \"missing.test\""); + assert_eq!( + err, + "jwtbundle: no JWT bundle for trust domain \"missing.test\"" + ); } diff --git a/tests/jwtsvid_tests.rs b/tests/jwtsvid_tests.rs index 710baa9..e9defd8 100644 --- a/tests/jwtsvid_tests.rs +++ b/tests/jwtsvid_tests.rs @@ -1,13 +1,13 @@ +use base64::Engine; use rand::rngs::OsRng; use rsa::pkcs1v15::SigningKey as RsaSigningKey; +use rsa::signature::{SignatureEncoding, Signer}; +use rsa::traits::PublicKeyParts; use rsa::RsaPrivateKey; use sha2::Sha256; -use rsa::signature::{Signer, SignatureEncoding}; -use rsa::traits::PublicKeyParts; -use base64::Engine; use spiffe_rs::bundle::jwtbundle::{Bundle, JwtKey}; -use spiffe_rs::svid::jwtsvid; use spiffe_rs::spiffeid::require_trust_domain_from_string; +use spiffe_rs::svid::jwtsvid; use std::time::{Duration, SystemTime}; fn generate_rsa_key() -> RsaPrivateKey { @@ -32,12 +32,10 @@ fn build_jwt( if let Some(typ) = typ { header["typ"] = serde_json::Value::String(typ.to_string()); } - let header_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode( - serde_json::to_vec(&header).expect("header json"), - ); - let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode( - serde_json::to_vec(&claims).expect("claims json"), - ); + let header_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&header).expect("header json")); + let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&claims).expect("claims json")); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature = signer.sign(signing_input.as_bytes(), alg); let sig_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature); @@ -112,9 +110,15 @@ fn parse_and_validate_success() { "exp": exp_secs, "iat": exp_secs - 10, }); - let token = build_jwt("ES384", Some("authority1"), None, claims, &SignerKey::P384(p384_key)); - let svid = jwtsvid::parse_and_validate(&token, &bundle, &["audience".to_string()]) - .expect("parse"); + let token = build_jwt( + "ES384", + Some("authority1"), + None, + claims, + &SignerKey::P384(p384_key), + ); + let svid = + jwtsvid::parse_and_validate(&token, &bundle, &["audience".to_string()]).expect("parse"); assert_eq!(svid.id.to_string(), "spiffe://trustdomain/host"); } @@ -175,7 +179,13 @@ fn parse_insecure_success() { "aud": ["audience"], "exp": exp_secs, }); - let token = build_jwt("ES384", Some("key1"), None, claims, &SignerKey::P384(p384_key)); + let token = build_jwt( + "ES384", + Some("key1"), + None, + claims, + &SignerKey::P384(p384_key), + ); let svid = jwtsvid::parse_insecure(&token, &["audience".to_string()]).expect("parse"); assert_eq!(svid.id.to_string(), "spiffe://trustdomain/host"); } diff --git a/tests/spiffebundle_tests.rs b/tests/spiffebundle_tests.rs index 5a34097..9ef0ba3 100644 --- a/tests/spiffebundle_tests.rs +++ b/tests/spiffebundle_tests.rs @@ -1,7 +1,7 @@ use spiffe_rs::bundle::jwtbundle; +use spiffe_rs::bundle::spiffebundle::JwtKey; use spiffe_rs::bundle::spiffebundle::{Bundle, Set}; use spiffe_rs::bundle::x509bundle; -use spiffe_rs::bundle::spiffebundle::JwtKey; use spiffe_rs::spiffeid::require_trust_domain_from_string; use std::fs; use std::time::Duration; @@ -56,8 +56,11 @@ fn bundle_refresh_hint_and_sequence() { #[test] fn bundle_marshal_roundtrip() { let td = require_trust_domain_from_string("domain.test"); - let bundle = Bundle::load(td.clone(), "tests/testdata/spiffebundle/spiffebundle_valid_2.json") - .expect("load"); + let bundle = Bundle::load( + td.clone(), + "tests/testdata/spiffebundle/spiffebundle_valid_2.json", + ) + .expect("load"); let bytes = bundle.marshal().expect("marshal"); let parsed = Bundle::parse(td, &bytes).expect("parse"); assert!(bundle.equal(&parsed)); @@ -66,11 +69,14 @@ fn bundle_marshal_roundtrip() { #[test] fn bundle_x509_and_jwt_views() { let td = require_trust_domain_from_string("domain.test"); - let x509_bundle = x509bundle::Bundle::load(td.clone(), "tests/testdata/x509bundle/cert.pem").expect("load"); + let x509_bundle = + x509bundle::Bundle::load(td.clone(), "tests/testdata/x509bundle/cert.pem").expect("load"); let bundle = Bundle::from_x509_bundle(&x509_bundle); assert_eq!(bundle.x509_authorities().len(), 1); - let jwt_bundle = jwtbundle::Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_1.json").expect("load"); + let jwt_bundle = + jwtbundle::Bundle::load(td.clone(), "tests/testdata/jwtbundle/jwks_valid_1.json") + .expect("load"); let bundle = Bundle::from_jwt_bundle(&jwt_bundle); assert_eq!(bundle.jwt_authorities().len(), 1); } @@ -104,14 +110,19 @@ fn bundle_get_for_trust_domain() { let td = require_trust_domain_from_string("domain.test"); let td2 = require_trust_domain_from_string("domain2.test"); let bundle = Bundle::new(td.clone()); - let ok = bundle.get_bundle_for_trust_domain(td.clone()).expect("bundle"); + let ok = bundle + .get_bundle_for_trust_domain(td.clone()) + .expect("bundle"); assert!(bundle.equal(&ok)); let err = bundle .get_bundle_for_trust_domain(td2) .unwrap_err() .to_string(); - assert_eq!(err, "spiffebundle: no SPIFFE bundle for trust domain \"domain2.test\""); + assert_eq!( + err, + "spiffebundle: no SPIFFE bundle for trust domain \"domain2.test\"" + ); } #[test] @@ -132,5 +143,8 @@ fn set_ops() { .get_bundle_for_trust_domain(require_trust_domain_from_string("missing.test")) .unwrap_err() .to_string(); - assert_eq!(err, "spiffebundle: no SPIFFE bundle for trust domain \"missing.test\""); + assert_eq!( + err, + "spiffebundle: no SPIFFE bundle for trust domain \"missing.test\"" + ); } diff --git a/tests/spiffeid_parity.rs b/tests/spiffeid_parity.rs index 483b243..544a91e 100644 --- a/tests/spiffeid_parity.rs +++ b/tests/spiffeid_parity.rs @@ -4,7 +4,7 @@ use spiffe_rs::spiffeid::{ require_format_path, require_from_path, require_from_pathf, require_from_segments, require_from_string, require_from_stringf, require_from_uri, require_join_path_segments, require_trust_domain_from_string, require_trust_domain_from_uri, trust_domain_from_string, - trust_domain_from_uri, validate_path, validate_path_segment, Error, ID, SpiffeUrl, TrustDomain, + trust_domain_from_uri, validate_path, validate_path_segment, Error, SpiffeUrl, TrustDomain, ID, }; use std::cmp::Ordering; use std::collections::HashSet; @@ -18,7 +18,10 @@ fn assert_error_contains(err: Result<(), Error>, contains: &str) { fn assert_id_equal(id: &ID, expect_td: &TrustDomain, expect_path: &str) { assert_eq!(&id.trust_domain(), expect_td, "unexpected trust domain"); assert_eq!(id.path(), expect_path, "unexpected path"); - assert_eq!(id.to_string(), format!("{}{}", expect_td.id_string(), expect_path)); + assert_eq!( + id.to_string(), + format!("{}{}", expect_td.id_string(), expect_path) + ); assert_eq!(id.url().to_string(), id.to_string()); } @@ -69,7 +72,9 @@ fn from_string_validation_matches_go() { let err = ID::from_stringf(format_args!("{}", input)).unwrap_err(); assert!(err.to_string().contains(expect_err)); assert!(std::panic::catch_unwind(|| require_from_string(input)).is_err()); - assert!(std::panic::catch_unwind(|| require_from_stringf(format_args!("{}", input))).is_err()); + assert!( + std::panic::catch_unwind(|| require_from_stringf(format_args!("{}", input))).is_err() + ); }; assert_fail("", "cannot be empty"); @@ -83,7 +88,11 @@ fn from_string_validation_matches_go() { let s = c.to_string(); if td_chars.contains(&s) { let td_with_char = require_trust_domain_from_string(&format!("trustdomain{s}")); - assert_ok(&format!("spiffe://trustdomain{s}/path"), &td_with_char, "/path"); + assert_ok( + &format!("spiffe://trustdomain{s}/path"), + &td_with_char, + "/path", + ); } else { assert_fail( &format!("spiffe://trustdomain{s}/path"), @@ -92,7 +101,11 @@ fn from_string_validation_matches_go() { } if path_chars.contains(&s) { - assert_ok(&format!("spiffe://trustdomain/path{s}"), &td, &format!("/path{s}")); + assert_ok( + &format!("spiffe://trustdomain/path{s}"), + &td, + &format!("/path{s}"), + ); } else { assert_fail( &format!("spiffe://trustdomain/path{s}"), @@ -107,17 +120,44 @@ fn from_string_validation_matches_go() { assert_fail("spiffe://", "trust domain is missing"); assert_fail("spiffe:///", "trust domain is missing"); assert_fail("spiffe://trustdomain/", "path cannot have a trailing slash"); - assert_fail("spiffe://trustdomain//", "path cannot contain empty segments"); - assert_fail("spiffe://trustdomain//path", "path cannot contain empty segments"); - assert_fail("spiffe://trustdomain/path/", "path cannot have a trailing slash"); + assert_fail( + "spiffe://trustdomain//", + "path cannot contain empty segments", + ); + assert_fail( + "spiffe://trustdomain//path", + "path cannot contain empty segments", + ); + assert_fail( + "spiffe://trustdomain/path/", + "path cannot have a trailing slash", + ); assert_fail("spiffe://trustdomain/.", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/./path", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/path/./other", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/path/..", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/..", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/../path", "path cannot contain dot segments"); - assert_fail("spiffe://trustdomain/path/../other", "path cannot contain dot segments"); + assert_fail( + "spiffe://trustdomain/./path", + "path cannot contain dot segments", + ); + assert_fail( + "spiffe://trustdomain/path/./other", + "path cannot contain dot segments", + ); + assert_fail( + "spiffe://trustdomain/path/..", + "path cannot contain dot segments", + ); + assert_fail( + "spiffe://trustdomain/..", + "path cannot contain dot segments", + ); + assert_fail( + "spiffe://trustdomain/../path", + "path cannot contain dot segments", + ); + assert_fail( + "spiffe://trustdomain/path/../other", + "path cannot contain dot segments", + ); assert_ok("spiffe://trustdomain/.path", &td, "/.path"); assert_ok("spiffe://trustdomain/..path", &td, "/..path"); @@ -163,7 +203,10 @@ fn trust_domain_from_string_validation_matches_go() { assert_fail("spiffe://", "trust domain is missing"); assert_fail("spiffe:///path", "trust domain is missing"); assert_fail("spiffe://trustdomain/", "path cannot have a trailing slash"); - assert_fail("spiffe://trustdomain/path/", "path cannot have a trailing slash"); + assert_fail( + "spiffe://trustdomain/path/", + "path cannot have a trailing slash", + ); assert_fail( "spiffe://%F0%9F%A4%AF/path", "trust domain characters are limited to lowercase letters, numbers, dots, dashes, and underscores", @@ -203,7 +246,10 @@ fn trust_domain_from_uri_matches_go() { let assert_ok = |s: &str| { let url = parse(s); let td = trust_domain_from_uri(&url).expect("valid trust domain"); - assert_eq!(td, require_trust_domain_from_string(url.host_str().unwrap_or(""))); + assert_eq!( + td, + require_trust_domain_from_string(url.host_str().unwrap_or("")) + ); }; let assert_fail = |url: Url, expect_err: &str| { let err = trust_domain_from_uri(&url).unwrap_err(); @@ -212,7 +258,10 @@ fn trust_domain_from_uri_matches_go() { assert_ok("spiffe://trustdomain"); assert_ok("spiffe://trustdomain/path"); - assert_fail(Url::parse("spiffe://").expect("url"), "trust domain is missing"); + assert_fail( + Url::parse("spiffe://").expect("url"), + "trust domain is missing", + ); assert_fail( Url::parse("http://trustdomain").expect("url"), "scheme is missing or invalid", @@ -285,7 +334,10 @@ fn from_uri_matches_go() { assert_ok("spiffe://trustdomain"); assert_ok("spiffe://trustdomain/path"); - assert_fail(Url::parse("spiffe://").expect("url"), "trust domain is missing"); + assert_fail( + Url::parse("spiffe://").expect("url"), + "trust domain is missing", + ); assert_fail( Url::parse("http://trustdomain").expect("url"), "scheme is missing or invalid", @@ -402,7 +454,10 @@ fn id_replace_append_matches_go() { .unwrap_err(); assert_eq!(err.to_string(), "path cannot contain empty segments"); let err = ID::zero().replace_segments(&["/"]).unwrap_err(); - assert_eq!(err.to_string(), "cannot replace path segments on a zero ID value"); + assert_eq!( + err.to_string(), + "cannot replace path segments on a zero ID value" + ); let id = require_from_path(td.clone(), "/path") .append_path("/foo") @@ -435,7 +490,10 @@ fn id_replace_append_matches_go() { .unwrap_err(); assert_eq!(err.to_string(), "path cannot contain empty segments"); let err = ID::zero().append_segments(&["/"]).unwrap_err(); - assert_eq!(err.to_string(), "cannot append path segments on a zero ID value"); + assert_eq!( + err.to_string(), + "cannot append path segments on a zero ID value" + ); } #[test] @@ -447,28 +505,29 @@ fn matcher_behavior_matches_go() { let foo_c = require_from_string("spiffe://foo.test/sub/C"); let bar_a = require_from_string("spiffe://bar.test/A"); - let test_match = |matcher: Box Result<(), spiffe_rs::spiffeid::MatcherError>>, - zero_err: &str, - foo_err: &str, - foo_a_err: &str, - foo_b_err: &str, - foo_c_err: &str, - bar_a_err: &str| { - let check = |id: &ID, expect_err: &str| { - let result = matcher(id); - if expect_err.is_empty() { - assert!(result.is_ok()); - } else { - assert_eq!(result.unwrap_err().to_string(), expect_err); - } + let test_match = + |matcher: Box Result<(), spiffe_rs::spiffeid::MatcherError>>, + zero_err: &str, + foo_err: &str, + foo_a_err: &str, + foo_b_err: &str, + foo_c_err: &str, + bar_a_err: &str| { + let check = |id: &ID, expect_err: &str| { + let result = matcher(id); + if expect_err.is_empty() { + assert!(result.is_ok()); + } else { + assert_eq!(result.unwrap_err().to_string(), expect_err); + } + }; + check(&zero, zero_err); + check(&foo, foo_err); + check(&foo_a, foo_a_err); + check(&foo_b, foo_b_err); + check(&foo_c, foo_c_err); + check(&bar_a, bar_a_err); }; - check(&zero, zero_err); - check(&foo, foo_err); - check(&foo_a, foo_a_err); - check(&foo_b, foo_b_err); - check(&foo_c, foo_c_err); - check(&bar_a, bar_a_err); - }; test_match(match_any(), "", "", "", "", "", ""); test_match( @@ -536,11 +595,17 @@ fn require_helpers_match_go() { let id = require_from_pathf(td.clone(), format_args!("/{}", "path")); assert_eq!(id.to_string(), "spiffe://trustdomain/path"); - assert!(std::panic::catch_unwind(|| require_from_pathf(td.clone(), format_args!("{}", "relative"))).is_err()); + assert!(std::panic::catch_unwind(|| require_from_pathf( + td.clone(), + format_args!("{}", "relative") + )) + .is_err()); let id = require_from_segments(td.clone(), &["path"]); assert_eq!(id.to_string(), "spiffe://trustdomain/path"); - assert!(std::panic::catch_unwind(|| require_from_segments(td.clone(), &["/absolute"])).is_err()); + assert!( + std::panic::catch_unwind(|| require_from_segments(td.clone(), &["/absolute"])).is_err() + ); let id = require_from_string("spiffe://trustdomain/path"); assert_eq!(id.to_string(), "spiffe://trustdomain/path"); @@ -548,19 +613,33 @@ fn require_helpers_match_go() { let id = require_from_stringf(format_args!("spiffe://trustdomain/{}", "path")); assert_eq!(id.to_string(), "spiffe://trustdomain/path"); - assert!(std::panic::catch_unwind(|| require_from_stringf(format_args!("{}://trustdomain/path", "sparfe"))).is_err()); + assert!( + std::panic::catch_unwind(|| require_from_stringf(format_args!( + "{}://trustdomain/path", + "sparfe" + ))) + .is_err() + ); let id = require_from_uri(&Url::parse("spiffe://trustdomain/path").unwrap()); assert_eq!(id.to_string(), "spiffe://trustdomain/path"); - assert!(std::panic::catch_unwind(|| require_from_uri(&Url::parse("spiffe://").unwrap())).is_err()); + assert!( + std::panic::catch_unwind(|| require_from_uri(&Url::parse("spiffe://").unwrap())).is_err() + ); let td = require_trust_domain_from_string("spiffe://trustdomain/path"); assert_eq!(td.name(), "trustdomain"); - assert!(std::panic::catch_unwind(|| require_trust_domain_from_string("spiffe://TRUSTDOMAIN/path")).is_err()); + assert!( + std::panic::catch_unwind(|| require_trust_domain_from_string("spiffe://TRUSTDOMAIN/path")) + .is_err() + ); let td = require_trust_domain_from_uri(&Url::parse("spiffe://trustdomain/path").unwrap()); assert_eq!(td.name(), "trustdomain"); - assert!(std::panic::catch_unwind(|| require_trust_domain_from_uri(&Url::parse("spiffe://").unwrap())).is_err()); + assert!(std::panic::catch_unwind(|| require_trust_domain_from_uri( + &Url::parse("spiffe://").unwrap() + )) + .is_err()); let path = require_format_path(format_args!("/{}", "path")); assert_eq!(path, "/path"); @@ -584,8 +663,14 @@ fn path_helpers_match_go() { validate_path_segment(""), "path cannot contain empty segments", ); - assert_error_contains(validate_path_segment("."), "path cannot contain dot segments"); - assert_error_contains(validate_path_segment(".."), "path cannot contain dot segments"); + assert_error_contains( + validate_path_segment("."), + "path cannot contain dot segments", + ); + assert_error_contains( + validate_path_segment(".."), + "path cannot contain dot segments", + ); assert_error_contains( validate_path_segment("/"), "path segment characters are limited to letters, numbers, dots, dashes, and underscores", diff --git a/tests/x509bundle_tests.rs b/tests/x509bundle_tests.rs index c601718..4ebaa62 100644 --- a/tests/x509bundle_tests.rs +++ b/tests/x509bundle_tests.rs @@ -81,7 +81,9 @@ fn bundle_get_for_trust_domain() { let td = require_trust_domain_from_string("domain.test"); let td2 = require_trust_domain_from_string("domain2.test"); let bundle = Bundle::new(td.clone()); - let ok = bundle.get_x509_bundle_for_trust_domain(td.clone()).expect("bundle"); + let ok = bundle + .get_x509_bundle_for_trust_domain(td.clone()) + .expect("bundle"); assert!(bundle.equal(&ok)); let err = bundle diff --git a/tests/x509svid_tests.rs b/tests/x509svid_tests.rs index 900c994..817158b 100644 --- a/tests/x509svid_tests.rs +++ b/tests/x509svid_tests.rs @@ -1,5 +1,5 @@ -use spiffe_rs::svid::x509svid; use spiffe_rs::spiffeid::ID; +use spiffe_rs::svid::x509svid; use std::fs; fn load_file(path: &str) -> Vec { @@ -76,7 +76,10 @@ fn marshal_roundtrip() { let cert_single = "tests/testdata/x509svid/good-leaf-only.pem"; let svid = x509svid::SVID::load(cert_single, key_rsa).expect("load"); let (certs, key) = svid.marshal().expect("marshal"); - assert_eq!(normalize_pem(&certs), normalize_pem(&load_file(cert_single))); + assert_eq!( + normalize_pem(&certs), + normalize_pem(&load_file(cert_single)) + ); assert_eq!(normalize_pem(&key), normalize_pem(&load_file(key_rsa))); } @@ -113,5 +116,8 @@ fn parse_raw_roundtrip() { .expect("private key") }; let svid = x509svid::SVID::parse_raw(&cert_raw, &key_der).expect("parse raw"); - assert_eq!(svid.id, ID::from_string("spiffe://example.org/workload-1").unwrap()); + assert_eq!( + svid.id, + ID::from_string("spiffe://example.org/workload-1").unwrap() + ); }