Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.compile_well_known_types(true)
.compile(&["proto/spiffe/workload/workload.proto"], &["proto"])?;

tonic_build::configure()
.compile(&["proto/examples/helloworld.proto"], &["proto/examples"])?;
tonic_build::configure().compile(&["proto/examples/helloworld.proto"], &["proto/examples"])?;
Ok(())
}
21 changes: 11 additions & 10 deletions examples/spiffe-grpc/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use hyper::Uri;
use spiffe_rs::spiffeid;
use spiffe_rs::spiffetls;
use spiffe_rs::workloadapi;
use std::sync::Arc;
use hyper::Uri;
use std::io;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tonic::transport::{Channel, Endpoint};
Expand All @@ -24,7 +24,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?);
let server_id = spiffeid::require_from_string("spiffe://example.org/server");
let authorizer = spiffetls::tlsconfig::authorize_id(server_id);
let mut tls_config = spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?;
let mut tls_config =
spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?;
tls_config.alpn_protocols = vec![b"h2".to_vec()];
let connector = TlsConnector::from(Arc::new(tls_config));

Expand All @@ -33,16 +34,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.connect_with_connector(service_fn(move |uri: Uri| {
let connector = connector.clone();
async move {
let authority = uri
.authority()
.map(|auth| auth.as_str())
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing authority"))?;
let authority = uri.authority().map(|auth| auth.as_str()).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "missing authority")
})?;
let stream = TcpStream::connect(authority).await?;
let server_name = rustls::ServerName::try_from("example.org")
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
connector.connect(server_name, stream).await.map_err(|err| {
io::Error::new(io::ErrorKind::Other, err)
})
connector
.connect(server_name, stream)
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
}
}))
.await?;
Expand Down
26 changes: 12 additions & 14 deletions examples/spiffe-grpc/server.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use spiffe_rs::spiffeid;
use spiffe_rs::spiffetls;
use spiffe_rs::workloadapi;
use std::pin::Pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::server::Connected;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
use tonic::transport::server::Connected;
use std::pin::Pin;
use std::task::{Context, Poll};

pub mod helloworld {
tonic::include_proto!("helloworld");
Expand All @@ -24,7 +24,10 @@ struct GreeterService;

#[tonic::async_trait]
impl Greeter for GreeterService {
async fn say_hello(&self, request: Request<HelloRequest>) -> Result<Response<HelloReply>, Status> {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let name = request.into_inner().name;
let reply = HelloReply {
message: format!("Hello {}", name),
Expand All @@ -41,7 +44,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?);
let client_id = spiffeid::require_from_string("spiffe://example.org/client");
let authorizer = spiffetls::tlsconfig::authorize_id(client_id);
let mut tls_config = spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?;
let mut tls_config =
spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?;
tls_config.alpn_protocols = vec![b"h2".to_vec()];
let acceptor = TlsAcceptor::from(Arc::new(tls_config));

Expand Down Expand Up @@ -110,17 +114,11 @@ impl AsyncWrite for TlsIo {
Pin::new(&mut self.0).poll_write(cx, data)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
Expand Down
3 changes: 2 additions & 1 deletion examples/spiffe-http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?);
let server_id = spiffeid::require_from_string("spiffe://example.org/server");
let authorizer = spiffetls::tlsconfig::authorize_id(server_id);
let tls_config = spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?;
let tls_config =
spiffetls::tlsconfig::mtls_client_config(source.as_ref(), source.clone(), authorizer)?;
let connector = TlsConnector::from(Arc::new(tls_config));

let stream = TcpStream::connect("127.0.0.1:8443").await?;
Expand Down
3 changes: 2 additions & 1 deletion examples/spiffe-http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let source = Arc::new(workloadapi::X509Source::new(&ctx, Vec::new()).await?);
let client_id = spiffeid::require_from_string("spiffe://example.org/client");
let authorizer = spiffetls::tlsconfig::authorize_id(client_id);
let tls_config = spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?;
let tls_config =
spiffetls::tlsconfig::mtls_server_config(source.as_ref(), source.clone(), authorizer)?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));

let listener = TcpListener::bind("127.0.0.1:8443").await?;
Expand Down
3 changes: 2 additions & 1 deletion examples/spiffe-jwt-using-proxy/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
return;
}
};
let service = service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone()));
let service =
service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone()));
if let Err(err) = hyper::server::conn::Http::new()
.serve_connection(tls, service)
.await
Expand Down
3 changes: 2 additions & 1 deletion examples/spiffe-jwt/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
return;
}
};
let service = service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone()));
let service =
service_fn(move |req| handle_request(req, jwt_source.clone(), audience.clone()));
if let Err(err) = hyper::server::conn::Http::new()
.serve_connection(tls, service)
.await
Expand Down
3 changes: 2 additions & 1 deletion examples/spiffe-tls/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let server_id = spiffeid::require_from_string("spiffe://example.org/server");
let authorizer = spiffetls::tlsconfig::authorize_id(server_id);
let server_name = rustls::ServerName::try_from("example.org")?;
let mut stream = spiffetls::dial(&ctx, "127.0.0.1:55555", server_name, authorizer, Vec::new()).await?;
let mut stream =
spiffetls::dial(&ctx, "127.0.0.1:55555", server_name, authorizer, Vec::new()).await?;

stream.write_all(b"Hello server")?;
stream.flush()?;
Expand Down
14 changes: 7 additions & 7 deletions src/bundle/jwtbundle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ impl Bundle {

/// Loads a JWT bundle from a JSON file (JWKS).
pub fn load(trust_domain: TrustDomain, path: &str) -> Result<Bundle> {
let bytes =
fs::read(path).map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?;
let bytes = fs::read(path)
.map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?;
Bundle::parse(trust_domain, &bytes)
}

Expand All @@ -82,15 +82,15 @@ impl Bundle {

/// Parses a JWT bundle from JSON bytes (JWKS).
pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
let jwks: JwkDocument =
serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
let jwks: JwkDocument = serde_json::from_slice(bytes)
.map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
let bundle = Bundle::new(trust_domain);
let keys = jwks.keys.unwrap_or_default();
for (idx, key) in keys.iter().enumerate() {
let key_id = key.key_id().unwrap_or_default();
let jwt_key = key
.to_jwt_key()
.map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?;
let jwt_key = key.to_jwt_key().map_err(|err| {
wrap_error(format!("error adding authority {} of JWKS: {}", idx, err))
})?;
if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) {
return Err(wrap_error(format!(
"error adding authority {} of JWKS: {}",
Expand Down
58 changes: 36 additions & 22 deletions src/bundle/spiffebundle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use crate::internal::jwtutil;
use crate::internal::x509util;
use crate::spiffeid::TrustDomain;
use base64::Engine;
use oid_registry::{OID_EC_P256, OID_KEY_TYPE_EC_PUBLIC_KEY, OID_NIST_EC_P384, OID_NIST_EC_P521};
use serde::Serialize;
use std::collections::HashMap;
use std::fs;
use std::io::Read;
use std::sync::RwLock;
use std::time::Duration;
use oid_registry::{OID_EC_P256, OID_KEY_TYPE_EC_PUBLIC_KEY, OID_NIST_EC_P384, OID_NIST_EC_P521};
use x509_parser::prelude::X509Certificate;

const X509_SVID_USE: &str = "x509-svid";
Expand Down Expand Up @@ -70,8 +70,8 @@ impl Bundle {

/// Loads a SPIFFE bundle from a JSON file (JWKS).
pub fn load(trust_domain: TrustDomain, path: &str) -> Result<Bundle> {
let bytes =
fs::read(path).map_err(|err| wrap_error(format!("unable to read SPIFFE bundle: {}", err)))?;
let bytes = fs::read(path)
.map_err(|err| wrap_error(format!("unable to read SPIFFE bundle: {}", err)))?;
Bundle::parse(trust_domain, &bytes)
}

Expand All @@ -86,8 +86,8 @@ impl Bundle {

/// Parses a SPIFFE bundle from JSON bytes (JWKS).
pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
let jwks: JwkDocument =
serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
let jwks: JwkDocument = serde_json::from_slice(bytes)
.map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
let bundle = Bundle::new(trust_domain);
if let Some(hint) = jwks.spiffe_refresh_hint {
bundle.set_refresh_hint(Duration::from_secs(hint as u64));
Expand All @@ -96,18 +96,18 @@ impl Bundle {
bundle.set_sequence_number(seq);
}

let keys = jwks.keys.ok_or_else(|| wrap_error("no authorities found"))?;
let keys = jwks
.keys
.ok_or_else(|| wrap_error("no authorities found"))?;
for (idx, key) in keys.iter().enumerate() {
match key.use_field.as_deref() {
Some(X509_SVID_USE) => {
let cert = key
.x509_certificate_der()
.ok_or_else(|| {
wrap_error(format!(
"expected a single certificate in {} entry {}; got 0",
X509_SVID_USE, idx
))
})?;
let cert = key.x509_certificate_der().ok_or_else(|| {
wrap_error(format!(
"expected a single certificate in {} entry {}; got 0",
X509_SVID_USE, idx
))
})?;
if let Some(count) = key.x5c.as_ref().map(|x| x.len()) {
if count != 1 {
return Err(wrap_error(format!(
Expand All @@ -120,9 +120,9 @@ impl Bundle {
}
Some(JWT_SVID_USE) => {
let key_id = key.key_id().unwrap_or_default();
let jwt_key = key
.to_jwt_key()
.map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?;
let jwt_key = key.to_jwt_key().map_err(|err| {
wrap_error(format!("error adding authority {} of JWKS: {}", idx, err))
})?;
if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) {
return Err(wrap_error(format!(
"error adding authority {} of JWKS: {}",
Expand Down Expand Up @@ -379,7 +379,10 @@ impl Bundle {
}

/// Returns the X.509 bundle for the given trust domain if it matches.
pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<x509bundle::Bundle> {
pub fn get_x509_bundle_for_trust_domain(
&self,
trust_domain: TrustDomain,
) -> Result<x509bundle::Bundle> {
if self.trust_domain != trust_domain {
return Err(wrap_error(format!(
"no X.509 bundle for trust domain \"{}\"",
Expand All @@ -390,7 +393,10 @@ impl Bundle {
}

/// Returns the JWT bundle for the given trust domain if it matches.
pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<jwtbundle::Bundle> {
pub fn get_jwt_bundle_for_trust_domain(
&self,
trust_domain: TrustDomain,
) -> Result<jwtbundle::Bundle> {
if self.trust_domain != trust_domain {
return Err(wrap_error(format!(
"no JWT bundle for trust domain \"{}\"",
Expand Down Expand Up @@ -496,7 +502,10 @@ impl Set {
}

/// Returns the X.509 bundle for the given trust domain.
pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<x509bundle::Bundle> {
pub fn get_x509_bundle_for_trust_domain(
&self,
trust_domain: TrustDomain,
) -> Result<x509bundle::Bundle> {
let guard = self
.bundles
.read()
Expand All @@ -511,7 +520,10 @@ impl Set {
}

/// Returns the JWT bundle for the given trust domain.
pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<jwtbundle::Bundle> {
pub fn get_jwt_bundle_for_trust_domain(
&self,
trust_domain: TrustDomain,
) -> Result<jwtbundle::Bundle> {
let guard = self
.bundles
.read()
Expand Down Expand Up @@ -620,7 +632,9 @@ fn ec_public_key_parameters(cert: &X509Certificate<'_>) -> Result<(String, Vec<u
.parameters
.as_ref()
.ok_or_else(|| wrap_error("missing EC parameters"))?;
let oid = params.as_oid().map_err(|_| wrap_error("invalid EC parameters"))?;
let oid = params
.as_oid()
.map_err(|_| wrap_error("invalid EC parameters"))?;
if oid == OID_EC_P256 {
"P-256".to_string()
} else if oid == OID_NIST_EC_P384 {
Expand Down
4 changes: 2 additions & 2 deletions src/bundle/x509bundle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ fn parse_raw_certificates(bytes: &[u8]) -> std::result::Result<Vec<Vec<u8>>, Str
let mut remaining = bytes;
let mut certs = Vec::new();
while !remaining.is_empty() {
let (rest, _cert) = x509_parser::parse_x509_certificate(remaining)
.map_err(|err| err.to_string())?;
let (rest, _cert) =
x509_parser::parse_x509_certificate(remaining).map_err(|err| err.to_string())?;
let consumed = remaining
.len()
.checked_sub(rest.len())
Expand Down
Loading