From a1260f209f40025b4b1ecfadcf273293ec57d34a Mon Sep 17 00:00:00 2001 From: Guanhao Yin Date: Wed, 11 Mar 2026 15:51:59 +0800 Subject: [PATCH] feat: make openssl optional, add RustCrypto RSA backend Make OpenSSL optional, add a RustCrypto RSA backend, split RSA backend implementation files, simplify JWK conversion paths (including backend interop tests), preserve SomeKey PEM helpers/comments, and keep CI feature-matrix coverage aligned with supported combinations. --- .github/workflows/ci.yml | 52 ++-- Cargo.toml | 30 +- README.md | 31 +- benches/sig.rs | 23 +- examples/jwks.rs | 12 + examples/signing_and_verification.rs | 20 +- src/jwk.rs | 300 +++++++++--------- src/lib.rs | 34 ++- src/rsa.rs | 438 +++++---------------------- src/rsa/openssl_imp.rs | 413 +++++++++++++++++++++++++ src/rsa/rustcrypto_imp.rs | 344 +++++++++++++++++++++ src/some.rs | 150 +++++---- 12 files changed, 1233 insertions(+), 614 deletions(-) create mode 100644 src/rsa/openssl_imp.rs create mode 100644 src/rsa/rustcrypto_imp.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c4eed06..734374f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,30 +7,46 @@ on: jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + features: + - "" + - "rsa" + - "openssl" + - "rsa,openssl" + - "rsa,remote-jwks" + - "openssl,remote-jwks" + - "rsa,openssl,remote-jwks" steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable with: toolchain: 1.85.0 - default: true - profile: minimal components: rustfmt, clippy - - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - profile: minimal + - name: cargo test (features=${{ matrix.features || 'none' }}) + env: + RUSTFLAGS: -D warnings + run: cargo test --no-default-features --features "${{ matrix.features }}" + + - name: cargo clippy (features=${{ matrix.features || 'none' }}) + run: cargo clippy --no-default-features --features "${{ matrix.features }}" -- -D clippy::all + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: 1.85.0 + components: rustfmt - name: cargo fmt run: cargo fmt -- --check - - name: cargo test --benches - run: cargo +nightly test --benches && rm -r benches - - - name: cargo test - env: - RUSTFLAGS: -D warnings - run: cargo test --all-targets - - - name: cargo clippy - run: cargo clippy --all-targets -- -D clippy::all && cargo clippy --no-default-features --all-targets -- -D clippy::all + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + - name: cargo test --benches (openssl) + run: cargo +nightly test --no-default-features --features openssl --benches diff --git a/Cargo.toml b/Cargo.toml index 65c4445..53f4951 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,28 +1,38 @@ [package] name = "jwtk" -version = "0.4.0" -edition = "2018" +version = "0.5.0" +edition = "2021" repository = "https://github.com/blckngm/jwtk" license = "MIT" description = "JWT signing (JWS) and verification, with first class JWK and JWK Set (JWKS) support." -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] -default = ["remote-jwks"] -remote-jwks = ["reqwest", "tokio"] +default = ["rsa", "remote-jwks"] +rsa = ["dep:rsa_crt", "dep:signature", "dep:rand_core"] +openssl = ["dep:openssl", "dep:openssl-sys", "dep:foreign-types"] +remote-jwks = ["dep:reqwest", "dep:tokio"] [dependencies] base64 = "0.22.1" -openssl = "0.10.64" serde = { version = "1.0.200", features = ["derive"] } serde_json = "1.0.116" smallvec = "1.13.2" +serde_with = "3.1.0" +sha2 = { version = "0.10", features = ["oid"] } + +# RSA (RustCrypto fallback when openssl is not enabled) +rsa_crt = { package = "rsa", version = "0.9", optional = true, default-features = false, features = ["std", "getrandom"] } +signature = { version = "2", optional = true } +rand_core = { version = "0.6", features = ["std"], optional = true } + +# OpenSSL +openssl = { version = "0.10.64", optional = true } +openssl-sys = { version = "0.9.102", optional = true } +foreign-types = { version = "0.3.2", optional = true } + +# Remote JWKS reqwest = { version = "0.12.4", features = ["json"], optional = true } tokio = { version = "1.37.0", features = ["sync"], optional = true } -openssl-sys = "0.9.102" -foreign-types = "0.3.2" -serde_with = "3.1.0" [dev-dependencies] axum = "0.7" diff --git a/README.md b/README.md index 70e3e8d..5baa535 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,33 @@ JWT signing (JWS) and verification, with first class JWK and JWK Set (JWKS) support. -Supports almost all JWS algorithms: +## Algorithms -* HS256, HS384, HS512 -* Ed25519 -* ES256, ES384, ES512, ES256K -* RS256, RS384, RS512 -* PS256, PS384, PS512 +* RS256, RS384, RS512, PS256, PS384, PS512 (feature: `rsa` or `openssl`) +* HS256, HS384, HS512 (feature: `openssl`) +* ES256, ES384, ES512, ES256K (feature: `openssl`) +* Ed25519 (feature: `openssl`) Supports `exp` and `nbf` validations. (Other validations will not be supported, because they are mostly application specific and can be easily implemented by applications.) -Supports converting public/private keys to/from PEM/JWK. Supports working with -generic keys (where the algorithm is determined at runtime), i.e. +Supports converting public/private keys to/from JWK. PEM support is available +when the `openssl` feature is enabled. Supports working with generic keys +(where the algorithm is determined at runtime), i.e. `SomePrivateKey`/`SomePublicKey`. -Uses good old openssl for crypto. +## Features -See the `examples` folder for some examples. +| Feature | Default | Description | +|---------|---------|-------------| +| `rsa` | Yes | RSA signing/verification via [RustCrypto](https://github.com/RustCrypto). No C dependencies. | +| `openssl` | No | Full algorithm support (RSA, HMAC, ECDSA, EdDSA) via OpenSSL. When enabled, RSA uses OpenSSL instead of RustCrypto. | +| `remote-jwks` | Yes | `RemoteJwksVerifier` for fetching and caching remote JWK Sets. | + +With the default features (`rsa` + `remote-jwks`), RS256 JWT verification +works out of the box with no C dependencies. + +## Examples + +See the `examples` folder for usage examples. diff --git a/benches/sig.rs b/benches/sig.rs index 1121008..5c01dc0 100644 --- a/benches/sig.rs +++ b/benches/sig.rs @@ -2,18 +2,15 @@ use std::time::Duration; -use jwtk::{ - ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}, - eddsa::Ed25519PrivateKey, - hmac::{HmacAlgorithm, HmacKey}, - rsa::RsaPrivateKey, - HeaderAndClaims, -}; +use jwtk::HeaderAndClaims; extern crate test; +#[cfg(feature = "openssl")] #[bench] fn bench_sig_es256(b: &mut test::Bencher) { + use jwtk::ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}; + let k = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256).unwrap(); b.iter(|| { @@ -29,8 +26,11 @@ fn bench_sig_es256(b: &mut test::Bencher) { }); } +#[cfg(any(feature = "rsa", feature = "openssl"))] #[bench] fn bench_sig_rs256(b: &mut test::Bencher) { + use jwtk::rsa::RsaPrivateKey; + let k = RsaPrivateKey::generate(2048, jwtk::rsa::RsaAlgorithm::RS256).unwrap(); b.iter(|| { @@ -46,8 +46,11 @@ fn bench_sig_rs256(b: &mut test::Bencher) { }); } +#[cfg(any(feature = "rsa", feature = "openssl"))] #[bench] fn bench_sig_ps256(b: &mut test::Bencher) { + use jwtk::rsa::RsaPrivateKey; + let k = RsaPrivateKey::generate(2048, jwtk::rsa::RsaAlgorithm::PS256).unwrap(); b.iter(|| { @@ -63,8 +66,11 @@ fn bench_sig_ps256(b: &mut test::Bencher) { }); } +#[cfg(feature = "openssl")] #[bench] fn bench_sig_hs256(b: &mut test::Bencher) { + use jwtk::hmac::{HmacAlgorithm, HmacKey}; + let k = HmacKey::generate(HmacAlgorithm::HS256).unwrap(); b.iter(|| { @@ -80,8 +86,11 @@ fn bench_sig_hs256(b: &mut test::Bencher) { }); } +#[cfg(feature = "openssl")] #[bench] fn bench_sig_ed25519(b: &mut test::Bencher) { + use jwtk::eddsa::Ed25519PrivateKey; + let k = Ed25519PrivateKey::generate().unwrap(); b.iter(|| { diff --git a/examples/jwks.rs b/examples/jwks.rs index 0553c2a..df54b40 100644 --- a/examples/jwks.rs +++ b/examples/jwks.rs @@ -7,28 +7,34 @@ //! //! Tokens will be issued at http://127.0.0.1:3000/token +#[cfg(feature = "openssl")] use axum::{ extract::State, response::{IntoResponse, Json}, routing::get, Router, }; +#[cfg(feature = "openssl")] use jwtk::{ jwk::{JwkSet, WithKid}, rsa::RsaAlgorithm, sign, HeaderAndClaims, PublicKeyToJwk, SomePrivateKey, }; +#[cfg(feature = "openssl")] use std::{sync::Arc, time::Duration}; +#[cfg(feature = "openssl")] struct AppState { k: WithKid, jwks: JwkSet, } +#[cfg(feature = "openssl")] async fn jwks_handler(state: State>) -> impl IntoResponse { Json(&state.jwks).into_response() } +#[cfg(feature = "openssl")] async fn token_handler(state: State>) -> impl IntoResponse { let mut token = HeaderAndClaims::new_dynamic(); token @@ -43,6 +49,7 @@ async fn token_handler(state: State>) -> impl IntoResponse { })) } +#[cfg(feature = "openssl")] #[tokio::main] async fn main() -> jwtk::Result<()> { let k = std::fs::read("key.pem")?; @@ -74,3 +81,8 @@ async fn main() -> jwtk::Result<()> { Ok(()) } + +#[cfg(not(feature = "openssl"))] +fn main() { + eprintln!("This example requires the 'openssl' feature"); +} diff --git a/examples/signing_and_verification.rs b/examples/signing_and_verification.rs index c94afb8..306537a 100644 --- a/examples/signing_and_verification.rs +++ b/examples/signing_and_verification.rs @@ -1,11 +1,12 @@ -use jwtk::{ - ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey}, - sign, verify, HeaderAndClaims, -}; -use serde_json::{Map, Value}; -use std::time::Duration; - +#[cfg(feature = "openssl")] fn main() -> jwtk::Result<()> { + use jwtk::{ + ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey}, + sign, verify, HeaderAndClaims, + }; + use serde_json::{Map, Value}; + use std::time::Duration; + let k = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?; let pem = k.public_key_to_pem()?; @@ -26,3 +27,8 @@ fn main() -> jwtk::Result<()> { Ok(()) } + +#[cfg(not(feature = "openssl"))] +fn main() { + eprintln!("This example requires the 'openssl' feature"); +} diff --git a/src/jwk.rs b/src/jwk.rs index ea02183..9f0573e 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -1,28 +1,20 @@ //! JWK and JWK Set. -//! -//! Only public keys are really supported for now. +#[cfg(feature = "openssl")] use crate::{ ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey}, eddsa::{Ed25519PrivateKey, Ed25519PublicKey}, - rsa::{RsaAlgorithm, RsaPrivateKey, RsaPublicKey}, - some::SomePublicKey, - verify, verify_only, Error, Header, HeaderAndClaims, PublicKeyToJwk, Result, SigningKey, - SomePrivateKey, VerificationKey, URL_SAFE_TRAILING_BITS, }; -use base64::Engine as _; -use openssl::{ - bn::BigNum, - hash::{hash, MessageDigest}, - pkey::PKey, - rsa::{Rsa, RsaPrivateKeyBuilder}, +use crate::{ + rsa::RsaAlgorithm, some::SomePublicKey, verify, verify_only, Error, Header, HeaderAndClaims, + PublicKeyToJwk, Result, SigningKey, SomePrivateKey, VerificationKey, URL_SAFE_TRAILING_BITS, }; +use base64::Engine as _; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Value; +use sha2::{Digest, Sha256}; use std::collections::{BTreeMap, HashMap}; -// TODO: private key jwk. - /// JWK Representation. #[non_exhaustive] #[derive(Debug, Deserialize, Serialize, Default)] @@ -50,7 +42,6 @@ pub struct Jwk { #[serde(skip_serializing_if = "Option::is_none")] pub d: Option, - // RSA private key. #[serde(skip_serializing_if = "Option::is_none")] pub p: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -78,125 +69,150 @@ impl Jwk { // If let would be too long. #[allow(clippy::single_match)] match &*self.kty { - "RSA" => match (self.alg.as_deref(), &self.n, &self.e) { - (alg, Some(ref n), Some(ref e)) => { - let n = URL_SAFE_TRAILING_BITS.decode(n)?; - let e = URL_SAFE_TRAILING_BITS.decode(e)?; - // If `alg` is specified, the key will only verify - // signatures generated by ONLY this specific `alg`, - // otherwise it will verify signatures generated by ANY RSA - // algorithm. - let alg = if let Some(alg) = alg { - Some(RsaAlgorithm::from_name(alg)?) - } else { - None - }; - return Ok(SomePublicKey::Rsa(RsaPublicKey::from_components( - &n, &e, alg, - )?)); - } - _ => {} - }, - "EC" => match (self.crv.as_deref(), &self.x, &self.y) { - // For EC keys `crv` is required. - (Some(crv), Some(ref x), Some(ref y)) => { - let x = URL_SAFE_TRAILING_BITS.decode(x)?; - let y = URL_SAFE_TRAILING_BITS.decode(y)?; - let alg = EcdsaAlgorithm::from_curve_name(crv)?; - return Ok(SomePublicKey::Ecdsa(EcdsaPublicKey::from_coordinates( - &x, &y, alg, - )?)); - } - _ => {} - }, - "OKP" => match (self.crv.as_deref(), &self.x) { - (Some(crv), Some(ref x)) => { - let x = URL_SAFE_TRAILING_BITS.decode(x)?; - match crv { - "Ed25519" => { - return Ok(SomePublicKey::Ed25519(Ed25519PublicKey::from_bytes(&x)?)); - } - _ => {} - } - } - _ => {} - }, + #[cfg(any(feature = "rsa", feature = "openssl"))] + "RSA" => return self.rsa_verification_key(), + #[cfg(feature = "openssl")] + "EC" => return self.ec_verification_key(), + #[cfg(feature = "openssl")] + "OKP" => return self.okp_verification_key(), _ => {} } Err(Error::UnsupportedOrInvalidKey) } - #[allow(clippy::many_single_char_names)] pub fn to_signing_key(&self, rsa_fallback_algorithm: RsaAlgorithm) -> Result { match &*self.kty { - "RSA" => { - let alg = if let Some(ref alg) = self.alg { - RsaAlgorithm::from_name(alg)? - } else { - rsa_fallback_algorithm - }; - match (self.d.as_deref(), self.n.as_deref(), self.e.as_deref()) { - (Some(d), Some(n), Some(e)) => { - fn decode(x: &str) -> Result { - Ok(BigNum::from_slice(&URL_SAFE_TRAILING_BITS.decode(x)?)?) - } - let d = decode(d)?; - let n = decode(n)?; - let e = decode(e)?; - match ( - self.p.as_deref(), - self.q.as_deref(), - self.dp.as_deref(), - self.dq.as_deref(), - self.qi.as_deref(), - self.oth.is_empty(), - ) { - (None, None, None, None, None, true) => { - let rsa = RsaPrivateKeyBuilder::new(n, e, d)?.build(); - let pkey = PKey::from_rsa(rsa)?; - RsaPrivateKey::from_pkey_without_check(pkey, alg).map(Into::into) - } - (Some(p), Some(q), Some(dp), Some(dq), Some(qi), true) => { - let p = decode(p)?; - let q = decode(q)?; - let dp = decode(dp)?; - let dq = decode(dq)?; - let qi = decode(qi)?; - let rsa = Rsa::from_private_components(n, e, d, p, q, dp, dq, qi)?; - let pkey = PKey::from_rsa(rsa)?; - RsaPrivateKey::from_pkey(pkey, alg).map(Into::into) - } - _ => Err(Error::UnsupportedOrInvalidKey), - } - } - _ => Err(Error::UnsupportedOrInvalidKey), - } + #[cfg(any(feature = "rsa", feature = "openssl"))] + "RSA" => self.rsa_signing_key(rsa_fallback_algorithm), + #[cfg(feature = "openssl")] + "EC" => self.ec_signing_key(), + #[cfg(feature = "openssl")] + "OKP" => self.okp_signing_key(), + _ => { + let _ = rsa_fallback_algorithm; + Err(Error::UnsupportedOrInvalidKey) } - "EC" => { + } + } + + #[cfg(any(feature = "rsa", feature = "openssl"))] + #[allow(clippy::many_single_char_names)] + pub(super) fn rsa_signing_key( + &self, + rsa_fallback_algorithm: RsaAlgorithm, + ) -> Result { + let alg = if let Some(ref alg) = self.alg { + RsaAlgorithm::from_name(alg)? + } else { + rsa_fallback_algorithm + }; + match (self.d.as_deref(), self.n.as_deref(), self.e.as_deref()) { + (Some(d), Some(n), Some(e)) => { + let d = URL_SAFE_TRAILING_BITS.decode(d)?; + let n = URL_SAFE_TRAILING_BITS.decode(n)?; + let e = URL_SAFE_TRAILING_BITS.decode(e)?; match ( - self.crv.as_deref(), - self.d.as_deref(), - self.x.as_deref(), - self.y.as_deref(), + self.p.as_deref(), + self.q.as_deref(), + self.dp.as_deref(), + self.dq.as_deref(), + self.qi.as_deref(), + self.oth.is_empty(), ) { - (Some(crv), Some(d), Some(x), Some(y)) => { - let alg = EcdsaAlgorithm::from_curve_name(crv)?; - let d = URL_SAFE_TRAILING_BITS.decode(d)?; - let x = URL_SAFE_TRAILING_BITS.decode(x)?; - let y = URL_SAFE_TRAILING_BITS.decode(y)?; - EcdsaPrivateKey::from_private_components(alg, &d, &x, &y).map(Into::into) + (None, None, None, None, None, true) => { + crate::rsa::RsaPrivateKey::from_components(&n, &e, &d, vec![], alg) + .map(Into::into) + } + (Some(p), Some(q), Some(_dp), Some(_dq), Some(_qi), true) => { + let p = URL_SAFE_TRAILING_BITS.decode(p)?; + let q = URL_SAFE_TRAILING_BITS.decode(q)?; + crate::rsa::RsaPrivateKey::from_components(&n, &e, &d, vec![p, q], alg) + .map(Into::into) } _ => Err(Error::UnsupportedOrInvalidKey), } } - "OKP" => match (self.crv.as_deref(), self.d.as_deref()) { - (Some("Ed25519"), Some(d)) => { - let d = URL_SAFE_TRAILING_BITS.decode(d)?; - Ed25519PrivateKey::from_bytes(&d).map(Into::into) - } - _ => Err(Error::UnsupportedOrInvalidKey), - }, + _ => Err(Error::UnsupportedOrInvalidKey), + } + } + + #[cfg(any(feature = "rsa", feature = "openssl"))] + pub(super) fn rsa_verification_key(&self) -> Result { + match (self.alg.as_deref(), &self.n, &self.e) { + (alg, Some(ref n), Some(ref e)) => { + let n = URL_SAFE_TRAILING_BITS.decode(n)?; + let e = URL_SAFE_TRAILING_BITS.decode(e)?; + // If `alg` is specified, the key will only verify + // signatures generated by ONLY this specific `alg`, + // otherwise it will verify signatures generated by ANY RSA + // algorithm. + let alg = if let Some(alg) = alg { + Some(RsaAlgorithm::from_name(alg)?) + } else { + None + }; + Ok(SomePublicKey::Rsa( + crate::rsa::RsaPublicKey::from_components(&n, &e, alg)?, + )) + } + _ => Err(Error::UnsupportedOrInvalidKey), + } + } + + #[cfg(feature = "openssl")] + pub(super) fn ec_verification_key(&self) -> Result { + match (self.crv.as_deref(), &self.x, &self.y) { + // For EC keys `crv` is required. + (Some(crv), Some(ref x), Some(ref y)) => { + let x = URL_SAFE_TRAILING_BITS.decode(x)?; + let y = URL_SAFE_TRAILING_BITS.decode(y)?; + let alg = EcdsaAlgorithm::from_curve_name(crv)?; + Ok(SomePublicKey::Ecdsa(EcdsaPublicKey::from_coordinates( + &x, &y, alg, + )?)) + } + _ => Err(Error::UnsupportedOrInvalidKey), + } + } + + #[cfg(feature = "openssl")] + pub(super) fn okp_verification_key(&self) -> Result { + match (self.crv.as_deref(), &self.x) { + (Some("Ed25519"), Some(ref x)) => { + let x = URL_SAFE_TRAILING_BITS.decode(x)?; + Ok(SomePublicKey::Ed25519(Ed25519PublicKey::from_bytes(&x)?)) + } + _ => Err(Error::UnsupportedOrInvalidKey), + } + } + + #[cfg(feature = "openssl")] + pub(super) fn ec_signing_key(&self) -> Result { + match ( + self.crv.as_deref(), + self.d.as_deref(), + self.x.as_deref(), + self.y.as_deref(), + ) { + (Some(crv), Some(d), Some(x), Some(y)) => { + let alg = EcdsaAlgorithm::from_curve_name(crv)?; + let d = URL_SAFE_TRAILING_BITS.decode(d)?; + let x = URL_SAFE_TRAILING_BITS.decode(x)?; + let y = URL_SAFE_TRAILING_BITS.decode(y)?; + EcdsaPrivateKey::from_private_components(alg, &d, &x, &y).map(Into::into) + } + _ => Err(Error::UnsupportedOrInvalidKey), + } + } + + #[cfg(feature = "openssl")] + pub(super) fn okp_signing_key(&self) -> Result { + match (self.crv.as_deref(), self.d.as_deref()) { + (Some("Ed25519"), Some(d)) => { + let d = URL_SAFE_TRAILING_BITS.decode(d)?; + Ed25519PrivateKey::from_bytes(&d).map(Into::into) + } _ => Err(Error::UnsupportedOrInvalidKey), } } @@ -249,7 +265,7 @@ impl Jwk { } _ => return Err(Error::UnsupportedOrInvalidKey), }; - let hash = hash(MessageDigest::sha256(), as_json.as_bytes())?; + let hash = Sha256::digest(as_json.as_bytes()); let mut out = [0u8; 32]; out.copy_from_slice(&hash[..]); Ok(out) @@ -287,7 +303,7 @@ impl JwkSet { /// Jwk set parsed and converted, ready to verify tokens. pub struct JwkSetVerifier { keys: HashMap, - require_kid: bool, + pub(crate) require_kid: bool, } impl JwkSetVerifier { @@ -298,11 +314,7 @@ impl JwkSetVerifier { } pub fn find(&self, kid: &str) -> Option<&SomePublicKey> { - if let Some(vk) = self.keys.get(kid) { - Some(vk) - } else { - None - } + self.keys.get(kid) } /// Decode and verify token with keys from this JWK set. @@ -323,9 +335,6 @@ impl JwkSetVerifier { self.find_and_verify(token, verify_only) } - /// Find and verify token with keys from this JWK set. - /// - /// restrict_kid is true will only match keys with the same `kid`. fn find_and_verify( &self, token: &str, @@ -479,7 +488,6 @@ impl RemoteJwksVerifier { async fn get_cache(&self) -> Result> { let cache = self.cache.read().await; - // Cache still valid. if let Some(c) = &*cache { if c.fresher_than(self.cache_duration) { return Ok(tokio::sync::RwLockReadGuard::map(cache, |c| { @@ -632,13 +640,6 @@ impl RemoteJwksVerifierBuilder { #[cfg(test)] mod tests { - use crate::{ - ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}, - eddsa::Ed25519PrivateKey, - rsa::RsaPrivateKey, - sign, - }; - use super::*; #[test] @@ -661,11 +662,23 @@ mod tests { Ok(()) } + #[cfg(any(feature = "rsa", feature = "openssl"))] #[test] - fn test_thumbprint() -> Result<()> { + fn test_rsa_thumbprint() -> Result<()> { + use crate::rsa::{RsaAlgorithm, RsaPrivateKey}; RsaPrivateKey::generate(2048, RsaAlgorithm::RS256)? .public_key_to_jwk()? .get_thumbprint_sha256_base64()?; + Ok(()) + } + + #[cfg(feature = "openssl")] + #[test] + fn test_ec_ed_thumbprint() -> Result<()> { + use crate::{ + ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}, + eddsa::Ed25519PrivateKey, + }; EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)? .public_key_to_jwk()? .get_thumbprint_sha256_base64()?; @@ -675,14 +688,21 @@ mod tests { Ok(()) } - #[derive(Serialize, Deserialize)] + #[cfg(any(feature = "rsa", feature = "openssl"))] + #[derive(serde::Serialize, serde::Deserialize)] struct MyClaim { foo: String, } + #[cfg(any(feature = "rsa", feature = "openssl"))] #[test] fn test_jwks_verify() -> Result<()> { - let k = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES512)?; + use crate::{ + rsa::{RsaAlgorithm, RsaPrivateKey}, + sign, + }; + + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::RS256)?; let kk = WithKid::new("my key".into(), k.clone()); let k_jwk = kk.public_key_to_jwk()?; let jwks = JwkSet { keys: vec![k_jwk] }; diff --git a/src/lib.rs b/src/lib.rs index 26bfa91..53566b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,6 @@ use base64::{ alphabet, engine::{general_purpose::NO_PAD, GeneralPurpose}, }; -use openssl::error::ErrorStack; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{Map, Value}; use serde_with::{serde_as, skip_serializing_none}; @@ -23,10 +22,13 @@ pub use some::*; mod some; +#[cfg(feature = "openssl")] pub mod hmac; +#[cfg(feature = "openssl")] pub mod eddsa; +#[cfg(feature = "openssl")] pub mod ecdsa; pub mod rsa; @@ -75,7 +77,7 @@ impl OneOrMany { /// Iterate over the values regardless of whether it contains one or many. #[inline] - pub fn iter(&self) -> OneOrManyIter { + pub fn iter(&self) -> OneOrManyIter<'_, T> { OneOrManyIter::new(self) } } @@ -399,7 +401,7 @@ pub fn verify_only( // Verify the signature. k.verify( - token[..header_and_payload_len].as_bytes(), + &token.as_bytes()[..header_and_payload_len], &sig, &header.alg, )?; @@ -476,7 +478,9 @@ pub enum Error { UnsupportedOrInvalidKey, Utf8(FromUtf8Error), IoError(std::io::Error), - OpenSsl(ErrorStack), + #[cfg(feature = "openssl")] + OpenSsl(openssl::error::ErrorStack), + Crypto(Box), SerdeJson(serde_json::Error), Decode(base64::DecodeError), #[cfg(feature = "remote-jwks")] @@ -487,7 +491,9 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::IoError(e) => e.fmt(f), + #[cfg(feature = "openssl")] Error::OpenSsl(e) => e.fmt(f), + Error::Crypto(e) => e.fmt(f), Error::SerdeJson(e) => e.fmt(f), Error::Decode(e) => e.fmt(f), #[cfg(feature = "remote-jwks")] @@ -512,7 +518,9 @@ impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Error::IoError(e) => Some(e), + #[cfg(feature = "openssl")] Error::OpenSsl(e) => Some(e), + Error::Crypto(e) => Some(e.as_ref()), Error::SerdeJson(e) => Some(e), Error::Decode(e) => Some(e), Error::Utf8(e) => Some(e), @@ -530,13 +538,22 @@ impl From for Error { } } -impl From for Error { +#[cfg(feature = "openssl")] +impl From for Error { #[inline] - fn from(e: ErrorStack) -> Error { + fn from(e: openssl::error::ErrorStack) -> Error { Error::OpenSsl(e) } } +#[cfg(all(feature = "rsa", not(feature = "openssl")))] +impl From for Error { + #[inline] + fn from(e: signature::Error) -> Error { + Error::Crypto(Box::new(e)) + } +} + impl From for Error { #[inline] fn from(e: serde_json::Error) -> Error { @@ -570,8 +587,6 @@ pub type Result = std::result::Result; #[cfg(test)] mod tests { - use crate::ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}; - use super::*; #[test] @@ -589,8 +604,11 @@ mod tests { assert_eq!(iter.next(), None); } + #[cfg(feature = "openssl")] #[test] fn signing_and_verification() -> Result<()> { + use crate::ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}; + let mut claims = HeaderAndClaims::new_dynamic(); let k = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?; let k1 = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?; diff --git a/src/rsa.rs b/src/rsa.rs index 5836896..1819cfe 100644 --- a/src/rsa.rs +++ b/src/rsa.rs @@ -1,17 +1,4 @@ -use crate::{ - jwk::Jwk, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey, - URL_SAFE_TRAILING_BITS, -}; -use base64::Engine as _; -/// RSASSA-PKCS1-v1_5 using SHA-256. -use openssl::{ - bn::BigNum, - hash::MessageDigest, - pkey::{Id, PKey, Private, Public}, - rsa::{Padding, Rsa}, - sign::{RsaPssSaltlen, Signer, Verifier}, -}; -use smallvec::SmallVec; +use crate::{Error, Result}; /// RSA signature algorithms. #[non_exhaustive] @@ -33,15 +20,6 @@ impl RsaAlgorithm { ) } - fn digest(self) -> MessageDigest { - use RsaAlgorithm::*; - match self { - RS256 | PS256 => MessageDigest::sha256(), - RS384 | PS384 => MessageDigest::sha384(), - RS512 | PS512 => MessageDigest::sha512(), - } - } - pub fn name(self) -> &'static str { use RsaAlgorithm::*; match self { @@ -67,346 +45,48 @@ impl RsaAlgorithm { } } -/// RSA Private Key. -/// -/// By default, it only verifies signatures generated by the same algorithm used -/// for signing. If you want to verify signatures generated by any RSA -/// algorithm, set `verify_any` to `true`. -#[derive(Debug, Clone)] -pub struct RsaPrivateKey { - private_key: PKey, - pub algorithm: RsaAlgorithm, - pub verify_any: bool, -} - -impl RsaPrivateKey { - /// bits >= 2048. - pub fn generate(bits: u32, algorithm: RsaAlgorithm) -> Result { - if bits < 2048 { - return Err(Error::UnsupportedOrInvalidKey); - } - - Ok(Self { - private_key: PKey::from_rsa(Rsa::generate(bits)?)?, - algorithm, - verify_any: false, - }) - } - - pub(crate) fn from_pkey(pkey: PKey, algorithm: RsaAlgorithm) -> Result { - if pkey.bits() < 2048 || !pkey.rsa()?.check_key()? { - return Err(Error::UnsupportedOrInvalidKey); - } - Ok(Self { - private_key: pkey, - algorithm, - verify_any: false, - }) - } - - pub(crate) fn from_pkey_without_check( - pkey: PKey, - algorithm: RsaAlgorithm, - ) -> Result { - if pkey.bits() < 2048 { - return Err(Error::UnsupportedOrInvalidKey); - } - Ok(Self { - private_key: pkey, - algorithm, - verify_any: false, - }) - } - - pub fn from_pem(pem: &[u8], algorithm: RsaAlgorithm) -> Result { - let pk = PKey::private_key_from_pem(pem)?; - Self::from_pkey(pk, algorithm) - } - - pub fn private_key_to_pem_pkcs8(&self) -> Result { - Ok(String::from_utf8( - self.private_key.private_key_to_pem_pkcs8()?, - )?) - } - - pub fn public_key_to_pem(&self) -> Result { - Ok(String::from_utf8(self.private_key.public_key_to_pem()?)?) - } - - pub fn public_key_to_pem_pkcs1(&self) -> Result { - Ok(String::from_utf8( - self.private_key.rsa()?.public_key_to_pem_pkcs1()?, - )?) - } - - pub fn n(&self) -> Result> { - Ok(self.private_key.rsa()?.n().to_vec()) - } - - pub fn e(&self) -> Result> { - Ok(self.private_key.rsa()?.e().to_vec()) - } -} - -impl PrivateKeyToJwk for RsaPrivateKey { - #[allow(clippy::many_single_char_names)] - fn private_key_to_jwk(&self) -> Result { - let n = self.n()?; - let e = self.e()?; - let rsa = self.private_key.rsa()?; - let d = rsa.d().to_vec(); - let p = rsa.p().map(|p| p.to_vec()); - let q = rsa.q().map(|q| q.to_vec()); - let dp = rsa.dmp1().map(|dp| dp.to_vec()); - let dq = rsa.dmq1().map(|dq| dq.to_vec()); - let qi = rsa.iqmp().map(|qi| qi.to_vec()); - fn encode(x: &[u8]) -> String { - URL_SAFE_TRAILING_BITS.encode(x) - } - Ok(Jwk { - kty: "RSA".into(), - alg: if self.verify_any { - None - } else { - Some(self.algorithm.name().into()) - }, - use_: Some("sig".into()), - n: Some(encode(&n)), - e: Some(encode(&e)), - d: Some(encode(&d)), - p: p.map(|p| encode(&p)), - q: q.map(|q| encode(&q)), - dp: dp.map(|dp| encode(&dp)), - dq: dq.map(|dq| encode(&dq)), - qi: qi.map(|qi| encode(&qi)), - ..Default::default() - }) - } -} - -impl PublicKeyToJwk for RsaPrivateKey { - fn public_key_to_jwk(&self) -> Result { - Ok(Jwk { - kty: "RSA".into(), - alg: if self.verify_any { - None - } else { - Some(self.algorithm.name().into()) - }, - use_: Some("sig".into()), - n: Some(URL_SAFE_TRAILING_BITS.encode(self.n()?)), - e: Some(URL_SAFE_TRAILING_BITS.encode(self.e()?)), - ..Jwk::default() - }) - } -} - -/// RSA Public Key. -#[derive(Debug)] -pub struct RsaPublicKey { - public_key: PKey, - /// If this is `None`, this key verifies signatures generated by ANY RSA - /// algorithms. Otherwise it ONLY verifies signatures generated by this - /// algorithm. - pub algorithm: Option, -} - -impl RsaPublicKey { - pub(crate) fn from_pkey(pkey: PKey, algorithm: Option) -> Result { - if pkey.id() != Id::RSA || pkey.bits() < 2048 { - return Err(Error::UnsupportedOrInvalidKey); - } - Ok(Self { - public_key: pkey, - algorithm, - }) - } - - /// Both `BEGIN PUBLIC KEY` and `BEGIN RSA PUBLIC KEY` are OK. - pub fn from_pem(pem: &[u8], algorithm: Option) -> Result { - if std::str::from_utf8(pem).is_ok_and(|pem| pem.contains("BEGIN RSA")) { - let rsa = Rsa::public_key_from_pem_pkcs1(pem)?; - Self::from_pkey(PKey::from_rsa(rsa)?, algorithm) - } else { - let pkey = PKey::public_key_from_pem(pem)?; - Self::from_pkey(pkey, algorithm) - } - } - - pub fn from_components(n: &[u8], e: &[u8], algorithm: Option) -> Result { - let rsa = Rsa::from_public_components(BigNum::from_slice(n)?, BigNum::from_slice(e)?)?; - Self::from_pkey(PKey::from_rsa(rsa)?, algorithm) - } - - /// BEGIN PUBLIC KEY - pub fn to_pem(&self) -> Result { - Ok(String::from_utf8(self.public_key.public_key_to_pem()?)?) - } - - /// BEGIN RSA PUBLIC KEY - pub fn to_pem_pkcs1(&self) -> Result { - Ok(String::from_utf8( - self.public_key.rsa()?.public_key_to_pem_pkcs1()?, - )?) - } - - pub fn n(&self) -> Result> { - Ok(self.public_key.rsa()?.n().to_vec()) - } - - pub fn e(&self) -> Result> { - Ok(self.public_key.rsa()?.e().to_vec()) - } -} - -impl PublicKeyToJwk for RsaPublicKey { - fn public_key_to_jwk(&self) -> Result { - Ok(Jwk { - kty: "RSA".into(), - alg: self.algorithm.map(|alg| alg.name().to_string()), - use_: Some("sig".into()), - n: Some(URL_SAFE_TRAILING_BITS.encode(self.n()?)), - e: Some(URL_SAFE_TRAILING_BITS.encode(self.e()?)), - ..Jwk::default() - }) - } -} - -impl SigningKey for RsaPrivateKey { - fn sign(&self, v: &[u8]) -> Result> { - let mut signer = Signer::new(self.algorithm.digest(), self.private_key.as_ref())?; - if self.algorithm.is_pss() { - signer.set_rsa_padding(Padding::PKCS1_PSS)?; - signer.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; - } - - signer.update(v)?; - Ok(signer.sign_to_vec()?.into()) - } - - fn alg(&self) -> &'static str { - self.algorithm.name() - } -} - -impl VerificationKey for RsaPrivateKey { - fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { - let alg = if self.verify_any { - RsaAlgorithm::from_name(alg)? - } else { - if alg != self.algorithm.name() { - return Err(Error::VerificationError); - } - self.algorithm - }; - - let mut verifier = Verifier::new(alg.digest(), self.private_key.as_ref())?; - if alg.is_pss() { - verifier.set_rsa_padding(Padding::PKCS1_PSS)?; - verifier.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; - } - if verifier.verify_oneshot(sig, v)? { - Ok(()) - } else { - Err(Error::VerificationError) - } - } -} - -impl VerificationKey for RsaPublicKey { - fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { - let alg = if let Some(self_alg) = self.algorithm { - if self_alg.name() != alg { - return Err(Error::VerificationError); - } - self_alg - } else { - RsaAlgorithm::from_name(alg)? - }; - - let mut verifier = Verifier::new(alg.digest(), self.public_key.as_ref())?; - if alg.is_pss() { - verifier.set_rsa_padding(Padding::PKCS1_PSS)?; - verifier.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; - } - if verifier.verify_oneshot(sig, v)? { - Ok(()) - } else { - Err(Error::VerificationError) - } - } -} - -#[cfg(test)] -mod tests { +#[cfg(feature = "openssl")] +mod openssl_imp; +#[cfg(any( + all(feature = "rsa", not(feature = "openssl")), + all(test, feature = "rsa", feature = "openssl") +))] +mod rustcrypto_imp; + +#[cfg(feature = "openssl")] +pub use openssl_imp::*; +#[cfg(all(feature = "rsa", not(feature = "openssl")))] +pub use rustcrypto_imp::*; + +#[cfg(all(test, feature = "rsa", feature = "openssl"))] +mod interop_tests { + use super::{openssl_imp, rustcrypto_imp, RsaAlgorithm}; use crate::{ - ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}, - SomePrivateKey, + PrivateKeyToJwk, PublicKeyToJwk, SigningKey, VerificationKey, URL_SAFE_TRAILING_BITS, }; + use base64::Engine as _; + use openssl::{bn::BigNum, pkey::PKey, rsa::RsaPrivateKeyBuilder}; - use super::*; - - #[test] - fn conversion() -> Result<()> { - let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS384)?; - let pem = k.private_key_to_pem_pkcs8()?; - RsaPrivateKey::from_pem(pem.as_bytes(), RsaAlgorithm::PS384)?; - - let es256key_pem = - EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?.private_key_to_pem_pkcs8()?; - assert!(RsaPrivateKey::from_pem(es256key_pem.as_bytes(), RsaAlgorithm::PS384).is_err()); - - let pk_pem = k.public_key_to_pem()?; - let pk_pem_pkcs1 = k.public_key_to_pem_pkcs1()?; - - let pk = RsaPublicKey::from_pem(pk_pem.as_bytes(), None)?; - let pk1 = RsaPublicKey::from_pem(pk_pem_pkcs1.as_bytes(), None)?; - - println!("pk: {:?}", pk); - - let pk_pem1 = pk1.to_pem()?; - let pk_pem_pkcs1_1 = pk.to_pem_pkcs1()?; - - assert_eq!(pk_pem, pk_pem1); - assert_eq!(pk_pem_pkcs1, pk_pem_pkcs1_1); - - assert_eq!(k.alg(), "PS384"); - - if let SomePrivateKey::Rsa(k1) = k - .private_key_to_jwk()? - .to_signing_key(RsaAlgorithm::RS512)? - { - assert!(k.private_key.public_eq(k1.private_key.as_ref())); - } else { - panic!("expected rsa private key"); - } - - k.public_key_to_jwk()?.to_verification_key()?; - pk.public_key_to_jwk()?; - - Ok(()) + fn decode_pub_jwk(jwk: &crate::jwk::Jwk) -> (Vec, Vec) { + let n = URL_SAFE_TRAILING_BITS + .decode(jwk.n.as_deref().expect("RSA n is present")) + .expect("decode n"); + let e = URL_SAFE_TRAILING_BITS + .decode(jwk.e.as_deref().expect("RSA e is present")) + .expect("decode e"); + (n, e) } - #[test] - fn test_private_key_from_jwk_n_e_d_only() -> Result<()> { - let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS256)?; - let mut jwk = k.private_key_to_jwk()?; - jwk.p = None; - jwk.q = None; - jwk.dp = None; - jwk.dq = None; - jwk.qi = None; - let k1 = jwk.to_signing_key(RsaAlgorithm::RS256)?; - let sig = k1.sign(b"msg")?; - k.verify(b"msg", &sig, "PS256")?; - k1.verify(b"msg", &sig, "PS256")?; - let sig = k.sign(b"msg")?; - k1.verify(b"msg", &sig, "PS256")?; - Ok(()) + fn decode_private_jwk(jwk: &crate::jwk::Jwk) -> (Vec, Vec, Vec) { + let (n, e) = decode_pub_jwk(jwk); + let d = URL_SAFE_TRAILING_BITS + .decode(jwk.d.as_deref().expect("RSA d is present")) + .expect("decode d"); + (n, e, d) } #[test] - fn sign_verify() -> Result<()> { + fn public_jwk_interop_between_backends() -> crate::Result<()> { for alg in [ RsaAlgorithm::RS256, RsaAlgorithm::RS384, @@ -415,14 +95,46 @@ mod tests { RsaAlgorithm::PS384, RsaAlgorithm::PS512, ] { - let k = RsaPrivateKey::generate(2048, alg)?; - let pk = RsaPublicKey::from_pem(k.public_key_to_pem()?.as_bytes(), None)?; - let sig = k.sign(b"...")?; - assert!(k.verify(b"...", &sig, alg.name()).is_ok()); - assert!(k.verify(b"...", &sig, "WRONG ALG").is_err()); - assert!(k.verify(b"....", &sig, alg.name()).is_err()); - assert!(pk.verify(b"...", &sig, alg.name()).is_ok()); - assert!(pk.verify(b"....", &sig, alg.name()).is_err()); + let openssl_k = openssl_imp::RsaPrivateKey::generate(2048, alg)?; + let openssl_pub_jwk = openssl_k.public_key_to_jwk()?; + let (n, e) = decode_pub_jwk(&openssl_pub_jwk); + let rust_pk = rustcrypto_imp::RsaPublicKey::from_components(&n, &e, Some(alg))?; + let sig = openssl_k.sign(b"openssl->rustcrypto")?; + rust_pk.verify(b"openssl->rustcrypto", &sig, alg.name())?; + + let rust_k = rustcrypto_imp::RsaPrivateKey::generate(2048, alg)?; + let rust_pub_jwk = rust_k.public_key_to_jwk()?; + let (n, e) = decode_pub_jwk(&rust_pub_jwk); + let openssl_pk = openssl_imp::RsaPublicKey::from_components(&n, &e, Some(alg))?; + let sig = rust_k.sign(b"rustcrypto->openssl")?; + openssl_pk.verify(b"rustcrypto->openssl", &sig, alg.name())?; + } + Ok(()) + } + + #[test] + fn private_jwk_interop_between_backends() -> crate::Result<()> { + for alg in [RsaAlgorithm::RS256, RsaAlgorithm::PS256] { + let openssl_k = openssl_imp::RsaPrivateKey::generate(2048, alg)?; + let openssl_jwk = openssl_k.private_key_to_jwk()?; + let (n, e, d) = decode_private_jwk(&openssl_jwk); + let rust_k = rustcrypto_imp::RsaPrivateKey::from_components(&n, &e, &d, vec![], alg)?; + let sig = rust_k.sign(b"openssl-jwk->rustcrypto-key")?; + openssl_k.verify(b"openssl-jwk->rustcrypto-key", &sig, alg.name())?; + + let rust_jwk = rust_k.private_key_to_jwk()?; + let (n, e, d) = decode_private_jwk(&rust_jwk); + let rsa = RsaPrivateKeyBuilder::new( + BigNum::from_slice(&n)?, + BigNum::from_slice(&e)?, + BigNum::from_slice(&d)?, + )? + .build(); + let pkey = PKey::from_rsa(rsa)?; + let openssl_k_from_jwk = + openssl_imp::RsaPrivateKey::from_pkey_without_check(pkey, alg)?; + let sig = openssl_k_from_jwk.sign(b"rustcrypto-jwk->openssl-key")?; + rust_k.verify(b"rustcrypto-jwk->openssl-key", &sig, alg.name())?; } Ok(()) } diff --git a/src/rsa/openssl_imp.rs b/src/rsa/openssl_imp.rs new file mode 100644 index 0000000..e8a5899 --- /dev/null +++ b/src/rsa/openssl_imp.rs @@ -0,0 +1,413 @@ +use super::*; +use crate::{ + jwk::Jwk, PrivateKeyToJwk, PublicKeyToJwk, SigningKey, VerificationKey, URL_SAFE_TRAILING_BITS, +}; +use base64::Engine as _; +use openssl::{ + bn::BigNum, + hash::MessageDigest, + pkey::{Id, PKey, Private, Public}, + rsa::{Padding, Rsa, RsaPrivateKeyBuilder}, + sign::{RsaPssSaltlen, Signer, Verifier}, +}; +use smallvec::SmallVec; + +impl RsaAlgorithm { + pub(crate) fn digest(self) -> MessageDigest { + use RsaAlgorithm::*; + match self { + RS256 | PS256 => MessageDigest::sha256(), + RS384 | PS384 => MessageDigest::sha384(), + RS512 | PS512 => MessageDigest::sha512(), + } + } +} + +/// RSA Private Key. +/// +/// By default, it only verifies signatures generated by the same algorithm used +/// for signing. If you want to verify signatures generated by any RSA +/// algorithm, set `verify_any` to `true`. +#[derive(Debug, Clone)] +pub struct RsaPrivateKey { + pub(crate) private_key: PKey, + pub algorithm: RsaAlgorithm, + pub verify_any: bool, +} + +impl RsaPrivateKey { + /// bits >= 2048. + pub fn generate(bits: usize, algorithm: RsaAlgorithm) -> Result { + if bits < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + + Ok(Self { + private_key: PKey::from_rsa(Rsa::generate(bits as u32)?)?, + algorithm, + verify_any: false, + }) + } + + pub(crate) fn from_pkey(pkey: PKey, algorithm: RsaAlgorithm) -> Result { + if pkey.bits() < 2048 || !pkey.rsa()?.check_key()? { + return Err(Error::UnsupportedOrInvalidKey); + } + Ok(Self { + private_key: pkey, + algorithm, + verify_any: false, + }) + } + + pub(crate) fn from_pkey_without_check( + pkey: PKey, + algorithm: RsaAlgorithm, + ) -> Result { + if pkey.bits() < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + Ok(Self { + private_key: pkey, + algorithm, + verify_any: false, + }) + } + + pub fn from_pem(pem: &[u8], algorithm: RsaAlgorithm) -> Result { + let pk = PKey::private_key_from_pem(pem)?; + Self::from_pkey(pk, algorithm) + } + + pub fn from_components( + n: &[u8], + e: &[u8], + d: &[u8], + primes: Vec>, + algorithm: RsaAlgorithm, + ) -> Result { + if !matches!(primes.len(), 0 | 2) { + return Err(Error::UnsupportedOrInvalidKey); + } + let rsa = RsaPrivateKeyBuilder::new( + BigNum::from_slice(n)?, + BigNum::from_slice(e)?, + BigNum::from_slice(d)?, + )? + .build(); + let pkey = PKey::from_rsa(rsa)?; + Self::from_pkey_without_check(pkey, algorithm) + } + + pub fn private_key_to_pem_pkcs8(&self) -> Result { + Ok(String::from_utf8( + self.private_key.private_key_to_pem_pkcs8()?, + )?) + } + + pub fn public_key_to_pem(&self) -> Result { + Ok(String::from_utf8(self.private_key.public_key_to_pem()?)?) + } + + pub fn public_key_to_pem_pkcs1(&self) -> Result { + Ok(String::from_utf8( + self.private_key.rsa()?.public_key_to_pem_pkcs1()?, + )?) + } + + fn rsa(&self) -> openssl::rsa::Rsa { + self.private_key.rsa().expect("key is RSA") + } + + pub fn n(&self) -> Vec { + self.rsa().n().to_vec() + } + + pub fn e(&self) -> Vec { + self.rsa().e().to_vec() + } +} + +impl PrivateKeyToJwk for RsaPrivateKey { + #[allow(clippy::many_single_char_names)] + fn private_key_to_jwk(&self) -> Result { + let n = self.n(); + let e = self.e(); + let rsa = self.rsa(); + let d = rsa.d().to_vec(); + let p = rsa.p().map(|p| p.to_vec()); + let q = rsa.q().map(|q| q.to_vec()); + let dp = rsa.dmp1().map(|dp| dp.to_vec()); + let dq = rsa.dmq1().map(|dq| dq.to_vec()); + let qi = rsa.iqmp().map(|qi| qi.to_vec()); + fn encode(x: &[u8]) -> String { + URL_SAFE_TRAILING_BITS.encode(x) + } + Ok(Jwk { + kty: "RSA".into(), + alg: if self.verify_any { + None + } else { + Some(self.algorithm.name().into()) + }, + use_: Some("sig".into()), + n: Some(encode(&n)), + e: Some(encode(&e)), + d: Some(encode(&d)), + p: p.map(|p| encode(&p)), + q: q.map(|q| encode(&q)), + dp: dp.map(|dp| encode(&dp)), + dq: dq.map(|dq| encode(&dq)), + qi: qi.map(|qi| encode(&qi)), + ..Default::default() + }) + } +} + +impl PublicKeyToJwk for RsaPrivateKey { + fn public_key_to_jwk(&self) -> Result { + Ok(Jwk { + kty: "RSA".into(), + alg: if self.verify_any { + None + } else { + Some(self.algorithm.name().into()) + }, + use_: Some("sig".into()), + n: Some(URL_SAFE_TRAILING_BITS.encode(self.n())), + e: Some(URL_SAFE_TRAILING_BITS.encode(self.e())), + ..Jwk::default() + }) + } +} + +/// RSA Public Key. +#[derive(Debug)] +pub struct RsaPublicKey { + pub(crate) public_key: PKey, + /// If this is `None`, this key verifies signatures generated by ANY RSA + /// algorithms. Otherwise it ONLY verifies signatures generated by this + /// algorithm. + pub algorithm: Option, +} + +impl RsaPublicKey { + pub(crate) fn from_pkey(pkey: PKey, algorithm: Option) -> Result { + if pkey.id() != Id::RSA || pkey.bits() < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + Ok(Self { + public_key: pkey, + algorithm, + }) + } + + /// Both `BEGIN PUBLIC KEY` and `BEGIN RSA PUBLIC KEY` are OK. + pub fn from_pem(pem: &[u8], algorithm: Option) -> Result { + if std::str::from_utf8(pem).is_ok_and(|pem| pem.contains("BEGIN RSA")) { + let rsa = Rsa::public_key_from_pem_pkcs1(pem)?; + Self::from_pkey(PKey::from_rsa(rsa)?, algorithm) + } else { + let pkey = PKey::public_key_from_pem(pem)?; + Self::from_pkey(pkey, algorithm) + } + } + + pub fn from_components(n: &[u8], e: &[u8], algorithm: Option) -> Result { + let rsa = Rsa::from_public_components(BigNum::from_slice(n)?, BigNum::from_slice(e)?)?; + Self::from_pkey(PKey::from_rsa(rsa)?, algorithm) + } + + /// BEGIN PUBLIC KEY + pub fn to_pem(&self) -> Result { + Ok(String::from_utf8(self.public_key.public_key_to_pem()?)?) + } + + /// BEGIN RSA PUBLIC KEY + pub fn to_pem_pkcs1(&self) -> Result { + Ok(String::from_utf8( + self.public_key.rsa()?.public_key_to_pem_pkcs1()?, + )?) + } + + fn rsa(&self) -> openssl::rsa::Rsa { + self.public_key.rsa().expect("key is RSA") + } + + pub fn n(&self) -> Vec { + self.rsa().n().to_vec() + } + + pub fn e(&self) -> Vec { + self.rsa().e().to_vec() + } +} + +impl PublicKeyToJwk for RsaPublicKey { + fn public_key_to_jwk(&self) -> Result { + Ok(Jwk { + kty: "RSA".into(), + alg: self.algorithm.map(|alg| alg.name().to_string()), + use_: Some("sig".into()), + n: Some(URL_SAFE_TRAILING_BITS.encode(self.n())), + e: Some(URL_SAFE_TRAILING_BITS.encode(self.e())), + ..Jwk::default() + }) + } +} + +impl SigningKey for RsaPrivateKey { + fn sign(&self, v: &[u8]) -> Result> { + let mut signer = Signer::new(self.algorithm.digest(), self.private_key.as_ref())?; + if self.algorithm.is_pss() { + signer.set_rsa_padding(Padding::PKCS1_PSS)?; + signer.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; + } + + signer.update(v)?; + Ok(signer.sign_to_vec()?.into()) + } + + fn alg(&self) -> &'static str { + self.algorithm.name() + } +} + +impl VerificationKey for RsaPrivateKey { + fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { + let alg = if self.verify_any { + RsaAlgorithm::from_name(alg)? + } else { + if alg != self.algorithm.name() { + return Err(Error::VerificationError); + } + self.algorithm + }; + + let mut verifier = Verifier::new(alg.digest(), self.private_key.as_ref())?; + if alg.is_pss() { + verifier.set_rsa_padding(Padding::PKCS1_PSS)?; + verifier.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; + } + if verifier.verify_oneshot(sig, v)? { + Ok(()) + } else { + Err(Error::VerificationError) + } + } +} + +impl VerificationKey for RsaPublicKey { + fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { + let alg = if let Some(self_alg) = self.algorithm { + if self_alg.name() != alg { + return Err(Error::VerificationError); + } + self_alg + } else { + RsaAlgorithm::from_name(alg)? + }; + + let mut verifier = Verifier::new(alg.digest(), self.public_key.as_ref())?; + if alg.is_pss() { + verifier.set_rsa_padding(Padding::PKCS1_PSS)?; + verifier.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)?; + } + if verifier.verify_oneshot(sig, v)? { + Ok(()) + } else { + Err(Error::VerificationError) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey}, + SomePrivateKey, + }; + + use super::*; + + #[test] + fn conversion() -> Result<()> { + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS384)?; + let pem = k.private_key_to_pem_pkcs8()?; + RsaPrivateKey::from_pem(pem.as_bytes(), RsaAlgorithm::PS384)?; + + let es256key_pem = + EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?.private_key_to_pem_pkcs8()?; + assert!(RsaPrivateKey::from_pem(es256key_pem.as_bytes(), RsaAlgorithm::PS384).is_err()); + + let pk_pem = k.public_key_to_pem()?; + let pk_pem_pkcs1 = k.public_key_to_pem_pkcs1()?; + + let pk = RsaPublicKey::from_pem(pk_pem.as_bytes(), None)?; + let pk1 = RsaPublicKey::from_pem(pk_pem_pkcs1.as_bytes(), None)?; + + println!("pk: {:?}", pk); + + let pk_pem1 = pk1.to_pem()?; + let pk_pem_pkcs1_1 = pk.to_pem_pkcs1()?; + + assert_eq!(pk_pem, pk_pem1); + assert_eq!(pk_pem_pkcs1, pk_pem_pkcs1_1); + + assert_eq!(k.alg(), "PS384"); + + if let SomePrivateKey::Rsa(k1) = k + .private_key_to_jwk()? + .to_signing_key(RsaAlgorithm::RS512)? + { + assert!(k.private_key.public_eq(k1.private_key.as_ref())); + } else { + panic!("expected rsa private key"); + } + + k.public_key_to_jwk()?.to_verification_key()?; + pk.public_key_to_jwk()?; + + Ok(()) + } + + #[test] + fn test_private_key_from_jwk_n_e_d_only() -> Result<()> { + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS256)?; + let mut jwk = k.private_key_to_jwk()?; + jwk.p = None; + jwk.q = None; + jwk.dp = None; + jwk.dq = None; + jwk.qi = None; + let k1 = jwk.to_signing_key(RsaAlgorithm::RS256)?; + let sig = k1.sign(b"msg")?; + k.verify(b"msg", &sig, "PS256")?; + k1.verify(b"msg", &sig, "PS256")?; + let sig = k.sign(b"msg")?; + k1.verify(b"msg", &sig, "PS256")?; + Ok(()) + } + + #[test] + fn sign_verify() -> Result<()> { + for alg in [ + RsaAlgorithm::RS256, + RsaAlgorithm::RS384, + RsaAlgorithm::RS512, + RsaAlgorithm::PS256, + RsaAlgorithm::PS384, + RsaAlgorithm::PS512, + ] { + let k = RsaPrivateKey::generate(2048, alg)?; + let pk = RsaPublicKey::from_pem(k.public_key_to_pem()?.as_bytes(), None)?; + let sig = k.sign(b"...")?; + assert!(k.verify(b"...", &sig, alg.name()).is_ok()); + assert!(k.verify(b"...", &sig, "WRONG ALG").is_err()); + assert!(k.verify(b"....", &sig, alg.name()).is_err()); + assert!(pk.verify(b"...", &sig, alg.name()).is_ok()); + assert!(pk.verify(b"....", &sig, alg.name()).is_err()); + } + Ok(()) + } +} diff --git a/src/rsa/rustcrypto_imp.rs b/src/rsa/rustcrypto_imp.rs new file mode 100644 index 0000000..43f4ff0 --- /dev/null +++ b/src/rsa/rustcrypto_imp.rs @@ -0,0 +1,344 @@ +use super::*; +use crate::{ + jwk::Jwk, PrivateKeyToJwk, PublicKeyToJwk, SigningKey, VerificationKey, URL_SAFE_TRAILING_BITS, +}; +use base64::Engine as _; +use rsa_crt::{ + pkcs1v15, pss, + traits::{PrivateKeyParts, PublicKeyParts}, + BigUint, RsaPrivateKey as CratePrivateKey, RsaPublicKey as CratePublicKey, +}; +use sha2::{Sha256, Sha384, Sha512}; +use signature::{SignatureEncoding, Signer, Verifier}; +use smallvec::SmallVec; + +/// RSA Private Key. +/// +/// By default, it only verifies signatures generated by the same algorithm +/// used for signing. If you want to verify signatures generated by any RSA +/// algorithm, set `verify_any` to `true`. +#[derive(Debug, Clone)] +pub struct RsaPrivateKey { + pub(crate) private_key: CratePrivateKey, + pub algorithm: RsaAlgorithm, + pub verify_any: bool, +} + +impl RsaPrivateKey { + pub fn generate(bits: usize, algorithm: RsaAlgorithm) -> Result { + if bits < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + + let private_key = CratePrivateKey::new(&mut rand_core::OsRng, bits) + .map_err(|e| Error::Crypto(Box::new(e)))?; + Ok(Self { + private_key, + algorithm, + verify_any: false, + }) + } + + pub fn from_components( + n: &[u8], + e: &[u8], + d: &[u8], + primes: Vec>, + algorithm: RsaAlgorithm, + ) -> Result { + let n = BigUint::from_bytes_be(n); + let e = BigUint::from_bytes_be(e); + let d = BigUint::from_bytes_be(d); + let primes: Vec = primes + .into_iter() + .map(|p| BigUint::from_bytes_be(&p)) + .collect(); + let private_key = CratePrivateKey::from_components(n, e, d, primes) + .map_err(|e| Error::Crypto(Box::new(e)))?; + if private_key.n().bits() < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + Ok(Self { + private_key, + algorithm, + verify_any: false, + }) + } + + pub fn n(&self) -> Vec { + self.private_key.n().to_bytes_be() + } + + pub fn e(&self) -> Vec { + self.private_key.e().to_bytes_be() + } +} + +impl PrivateKeyToJwk for RsaPrivateKey { + fn private_key_to_jwk(&self) -> Result { + let n = self.n(); + let e = self.e(); + let d = self.private_key.d().to_bytes_be(); + let primes = self.private_key.primes(); + + fn encode(x: &[u8]) -> String { + URL_SAFE_TRAILING_BITS.encode(x) + } + + let (p, q, dp, dq, qi) = if primes.len() >= 2 { + let p = &primes[0]; + let q = &primes[1]; + let one = BigUint::from(1u32); + let dp = self.private_key.d() % (p - &one); + let dq = self.private_key.d() % (q - &one); + // qi = q^(p-2) mod p (Fermat's little theorem) + let qi = q.modpow(&(p - BigUint::from(2u32)), p); + ( + Some(encode(&p.to_bytes_be())), + Some(encode(&q.to_bytes_be())), + Some(encode(&dp.to_bytes_be())), + Some(encode(&dq.to_bytes_be())), + Some(encode(&qi.to_bytes_be())), + ) + } else { + (None, None, None, None, None) + }; + + Ok(Jwk { + kty: "RSA".into(), + alg: if self.verify_any { + None + } else { + Some(self.algorithm.name().into()) + }, + use_: Some("sig".into()), + n: Some(encode(&n)), + e: Some(encode(&e)), + d: Some(encode(&d)), + p, + q, + dp, + dq, + qi, + ..Default::default() + }) + } +} + +impl PublicKeyToJwk for RsaPrivateKey { + fn public_key_to_jwk(&self) -> Result { + Ok(Jwk { + kty: "RSA".into(), + alg: if self.verify_any { + None + } else { + Some(self.algorithm.name().into()) + }, + use_: Some("sig".into()), + n: Some(URL_SAFE_TRAILING_BITS.encode(self.n())), + e: Some(URL_SAFE_TRAILING_BITS.encode(self.e())), + ..Jwk::default() + }) + } +} + +/// RSA Public Key. +#[derive(Debug, Clone)] +pub struct RsaPublicKey { + pub(crate) public_key: CratePublicKey, + /// If this is `None`, this key verifies signatures generated by ANY RSA + /// algorithms. Otherwise it ONLY verifies signatures generated by this + /// algorithm. + pub algorithm: Option, +} + +impl RsaPublicKey { + pub fn from_components(n: &[u8], e: &[u8], algorithm: Option) -> Result { + let n = BigUint::from_bytes_be(n); + let e = BigUint::from_bytes_be(e); + let public_key = CratePublicKey::new(n, e).map_err(|e| Error::Crypto(Box::new(e)))?; + if public_key.n().bits() < 2048 { + return Err(Error::UnsupportedOrInvalidKey); + } + Ok(Self { + public_key, + algorithm, + }) + } + + pub fn n(&self) -> Vec { + self.public_key.n().to_bytes_be() + } + + pub fn e(&self) -> Vec { + self.public_key.e().to_bytes_be() + } +} + +impl PublicKeyToJwk for RsaPublicKey { + fn public_key_to_jwk(&self) -> Result { + Ok(Jwk { + kty: "RSA".into(), + alg: self.algorithm.map(|alg| alg.name().to_string()), + use_: Some("sig".into()), + n: Some(URL_SAFE_TRAILING_BITS.encode(self.n())), + e: Some(URL_SAFE_TRAILING_BITS.encode(self.e())), + ..Jwk::default() + }) + } +} + +macro_rules! rsa_sign { + (pkcs1v15, $hash:ty, $key:expr, $msg:expr) => {{ + let sk = pkcs1v15::SigningKey::<$hash>::new($key.clone()); + let sig: pkcs1v15::Signature = sk.try_sign($msg).map_err(|e| Error::Crypto(Box::new(e)))?; + Ok(sig.to_bytes().to_vec().into()) + }}; + (pss, $hash:ty, $key:expr, $msg:expr) => {{ + let sk = pss::SigningKey::<$hash>::new($key.clone()); + let sig: pss::Signature = sk.try_sign($msg).map_err(|e| Error::Crypto(Box::new(e)))?; + Ok(sig.to_bytes().to_vec().into()) + }}; +} + +impl SigningKey for RsaPrivateKey { + fn sign(&self, v: &[u8]) -> Result> { + use RsaAlgorithm::*; + match self.algorithm { + RS256 => rsa_sign!(pkcs1v15, Sha256, self.private_key, v), + RS384 => rsa_sign!(pkcs1v15, Sha384, self.private_key, v), + RS512 => rsa_sign!(pkcs1v15, Sha512, self.private_key, v), + PS256 => rsa_sign!(pss, Sha256, self.private_key, v), + PS384 => rsa_sign!(pss, Sha384, self.private_key, v), + PS512 => rsa_sign!(pss, Sha512, self.private_key, v), + } + } + + fn alg(&self) -> &'static str { + self.algorithm.name() + } +} + +macro_rules! rsa_verify_pkcs1v15 { + ($hash:ty, $key:expr, $msg:expr, $sig:expr) => {{ + let vk = pkcs1v15::VerifyingKey::<$hash>::new($key.clone()); + let sig = pkcs1v15::Signature::try_from($sig).map_err(|_| Error::VerificationError)?; + vk.verify($msg, &sig).map_err(|_| Error::VerificationError) + }}; +} + +macro_rules! rsa_verify_pss { + ($hash:ty, $key:expr, $msg:expr, $sig:expr) => {{ + let vk = pss::VerifyingKey::<$hash>::new($key.clone()); + let sig = pss::Signature::try_from($sig).map_err(|_| Error::VerificationError)?; + vk.verify($msg, &sig).map_err(|_| Error::VerificationError) + }}; +} + +fn rsa_verify(alg: RsaAlgorithm, public_key: &CratePublicKey, v: &[u8], sig: &[u8]) -> Result<()> { + use RsaAlgorithm::*; + match alg { + RS256 => rsa_verify_pkcs1v15!(Sha256, public_key, v, sig), + RS384 => rsa_verify_pkcs1v15!(Sha384, public_key, v, sig), + RS512 => rsa_verify_pkcs1v15!(Sha512, public_key, v, sig), + PS256 => rsa_verify_pss!(Sha256, public_key, v, sig), + PS384 => rsa_verify_pss!(Sha384, public_key, v, sig), + PS512 => rsa_verify_pss!(Sha512, public_key, v, sig), + } +} + +impl VerificationKey for RsaPrivateKey { + fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { + let alg = if self.verify_any { + RsaAlgorithm::from_name(alg)? + } else { + if alg != self.algorithm.name() { + return Err(Error::VerificationError); + } + self.algorithm + }; + rsa_verify(alg, self.private_key.as_ref(), v, sig) + } +} + +impl VerificationKey for RsaPublicKey { + fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> { + let alg = if let Some(self_alg) = self.algorithm { + if self_alg.name() != alg { + return Err(Error::VerificationError); + } + self_alg + } else { + RsaAlgorithm::from_name(alg)? + }; + rsa_verify(alg, &self.public_key, v, sig) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sign_verify() -> Result<()> { + for alg in [ + RsaAlgorithm::RS256, + RsaAlgorithm::RS384, + RsaAlgorithm::RS512, + RsaAlgorithm::PS256, + RsaAlgorithm::PS384, + RsaAlgorithm::PS512, + ] { + let k = RsaPrivateKey::generate(2048, alg)?; + let pk = RsaPublicKey::from_components(&k.n(), &k.e(), None)?; + let sig = k.sign(b"...")?; + assert!(k.verify(b"...", &sig, alg.name()).is_ok()); + assert!(k.verify(b"...", &sig, "WRONG ALG").is_err()); + assert!(k.verify(b"....", &sig, alg.name()).is_err()); + assert!(pk.verify(b"...", &sig, alg.name()).is_ok()); + assert!(pk.verify(b"....", &sig, alg.name()).is_err()); + } + Ok(()) + } + + #[test] + fn jwk_roundtrip() -> Result<()> { + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS256)?; + let jwk = k.private_key_to_jwk()?; + let k1 = jwk.to_signing_key(RsaAlgorithm::RS256)?; + let sig = k1.sign(b"msg")?; + k.verify(b"msg", &sig, "PS256")?; + k1.verify(b"msg", &sig, "PS256")?; + let sig = k.sign(b"msg")?; + k1.verify(b"msg", &sig, "PS256")?; + Ok(()) + } + + #[test] + fn jwk_n_e_d_only() -> Result<()> { + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::PS256)?; + let mut jwk = k.private_key_to_jwk()?; + jwk.p = None; + jwk.q = None; + jwk.dp = None; + jwk.dq = None; + jwk.qi = None; + let k1 = jwk.to_signing_key(RsaAlgorithm::RS256)?; + let sig = k1.sign(b"msg")?; + k.verify(b"msg", &sig, "PS256")?; + k1.verify(b"msg", &sig, "PS256")?; + let sig = k.sign(b"msg")?; + k1.verify(b"msg", &sig, "PS256")?; + Ok(()) + } + + #[test] + fn public_key_jwk() -> Result<()> { + let k = RsaPrivateKey::generate(2048, RsaAlgorithm::RS256)?; + let jwk = k.public_key_to_jwk()?; + let pk = jwk.to_verification_key()?; + let sig = k.sign(b"hello")?; + pk.verify(b"hello", &sig, "RS256")?; + Ok(()) + } +} diff --git a/src/some.rs b/src/some.rs index a5ea5b1..f8604d2 100644 --- a/src/some.rs +++ b/src/some.rs @@ -1,14 +1,15 @@ -//! Enum of HMAC / EC / RSA / Ed Keys. +//! Enum of EC / RSA / Ed keys. -use openssl::pkey::{Id, PKey}; +#[cfg(feature = "openssl")] +use crate::ecdsa::{EcdsaPrivateKey, EcdsaPublicKey}; +#[cfg(feature = "openssl")] +use crate::eddsa::{Ed25519PrivateKey, Ed25519PublicKey}; +#[cfg(any(feature = "rsa", feature = "openssl"))] +use crate::rsa::{RsaPrivateKey, RsaPublicKey}; -use crate::{ - ecdsa::{EcdsaPrivateKey, EcdsaPublicKey}, - eddsa::{Ed25519PrivateKey, Ed25519PublicKey}, - jwk::Jwk, - rsa::{RsaAlgorithm, RsaPrivateKey, RsaPublicKey}, - Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey, -}; +#[cfg(feature = "openssl")] +use crate::Error; +use crate::{jwk::Jwk, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey}; /// An RSA, EC or Ed25519 private key. /// @@ -17,23 +18,30 @@ use crate::{ #[non_exhaustive] #[derive(Debug)] pub enum SomePrivateKey { + #[cfg(feature = "openssl")] Ed25519(Ed25519PrivateKey), + #[cfg(feature = "openssl")] Ecdsa(EcdsaPrivateKey), + #[cfg(any(feature = "rsa", feature = "openssl"))] Rsa(RsaPrivateKey), } -/// An RSA, EC or Ed25519 public. +/// An RSA, EC or Ed25519 public key. /// /// Use this if you just want to load SOME public key from an external pem file /// or JWK. #[non_exhaustive] #[derive(Debug)] pub enum SomePublicKey { + #[cfg(feature = "openssl")] Ed25519(Ed25519PublicKey), + #[cfg(feature = "openssl")] Ecdsa(EcdsaPublicKey), + #[cfg(any(feature = "rsa", feature = "openssl"))] Rsa(RsaPublicKey), } +#[cfg(feature = "openssl")] impl From for SomePrivateKey { #[inline] fn from(k: Ed25519PrivateKey) -> SomePrivateKey { @@ -41,6 +49,7 @@ impl From for SomePrivateKey { } } +#[cfg(feature = "openssl")] impl From for SomePrivateKey { #[inline] fn from(k: EcdsaPrivateKey) -> SomePrivateKey { @@ -48,6 +57,7 @@ impl From for SomePrivateKey { } } +#[cfg(any(feature = "rsa", feature = "openssl"))] impl From for SomePrivateKey { #[inline] fn from(k: RsaPrivateKey) -> SomePrivateKey { @@ -55,6 +65,7 @@ impl From for SomePrivateKey { } } +#[cfg(feature = "openssl")] impl From for SomePublicKey { #[inline] fn from(k: Ed25519PublicKey) -> SomePublicKey { @@ -62,6 +73,7 @@ impl From for SomePublicKey { } } +#[cfg(feature = "openssl")] impl From for SomePublicKey { #[inline] fn from(k: EcdsaPublicKey) -> SomePublicKey { @@ -69,6 +81,7 @@ impl From for SomePublicKey { } } +#[cfg(any(feature = "rsa", feature = "openssl"))] impl From for SomePublicKey { #[inline] fn from(k: RsaPublicKey) -> SomePublicKey { @@ -76,6 +89,7 @@ impl From for SomePublicKey { } } +#[cfg(feature = "openssl")] impl SomePrivateKey { /// Read an RSA/EC/Ed25519 private key from PEM. /// @@ -83,14 +97,13 @@ impl SomePrivateKey { /// P-256 -> ES256. /// /// For an RSA private key, `if_rsa_algorithm` is used. - pub fn from_pem(pem: &[u8], if_rsa_algorithm: RsaAlgorithm) -> Result { + pub fn from_pem(pem: &[u8], if_rsa_algorithm: crate::rsa::RsaAlgorithm) -> Result { + use openssl::pkey::{Id, PKey}; + let pk = PKey::private_key_from_pem(pem)?; match pk.id() { - Id::RSA => { - let k = RsaPrivateKey::from_pkey(pk, if_rsa_algorithm)?; - Ok(Self::Rsa(k)) - } + Id::RSA => Ok(Self::Rsa(RsaPrivateKey::from_pkey(pk, if_rsa_algorithm)?)), Id::EC => { let k = EcdsaPrivateKey::from_pkey(pk)?; Ok(Self::Ecdsa(k)) @@ -120,26 +133,7 @@ impl SomePrivateKey { } } -impl PublicKeyToJwk for SomePrivateKey { - fn public_key_to_jwk(&self) -> Result { - match self { - SomePrivateKey::Ed25519(ed) => ed.public_key_to_jwk(), - SomePrivateKey::Ecdsa(ec) => ec.public_key_to_jwk(), - SomePrivateKey::Rsa(rsa) => rsa.public_key_to_jwk(), - } - } -} - -impl PrivateKeyToJwk for SomePrivateKey { - fn private_key_to_jwk(&self) -> Result { - match self { - SomePrivateKey::Ed25519(ed) => ed.private_key_to_jwk(), - SomePrivateKey::Ecdsa(ec) => ec.private_key_to_jwk(), - SomePrivateKey::Rsa(rsa) => rsa.private_key_to_jwk(), - } - } -} - +#[cfg(feature = "openssl")] impl SomePublicKey { /// Read an RSA/EC/Ed25519 public key from PEM. /// @@ -149,12 +143,11 @@ impl SomePublicKey { /// For an RSA public key, signatures generated by any RSA algorithms can be /// verified. pub fn from_pem(pem: &[u8]) -> Result { + use openssl::pkey::{Id, PKey}; + let pk = PKey::public_key_from_pem(pem)?; match pk.id() { - Id::RSA => { - let k = RsaPublicKey::from_pkey(pk, None)?; - Ok(Self::Rsa(k)) - } + Id::RSA => Ok(Self::Rsa(RsaPublicKey::from_pkey(pk, None)?)), Id::EC => { let k = EcdsaPublicKey::from_pkey(pk)?; Ok(Self::Ecdsa(k)) @@ -176,40 +169,90 @@ impl SomePublicKey { } } +impl PublicKeyToJwk for SomePrivateKey { + fn public_key_to_jwk(&self) -> Result { + match self { + #[cfg(feature = "openssl")] + SomePrivateKey::Ed25519(ed) => ed.public_key_to_jwk(), + #[cfg(feature = "openssl")] + SomePrivateKey::Ecdsa(ec) => ec.public_key_to_jwk(), + #[cfg(any(feature = "rsa", feature = "openssl"))] + SomePrivateKey::Rsa(rsa) => rsa.public_key_to_jwk(), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), + } + } +} + +impl PrivateKeyToJwk for SomePrivateKey { + fn private_key_to_jwk(&self) -> Result { + match self { + #[cfg(feature = "openssl")] + SomePrivateKey::Ed25519(ed) => ed.private_key_to_jwk(), + #[cfg(feature = "openssl")] + SomePrivateKey::Ecdsa(ec) => ec.private_key_to_jwk(), + #[cfg(any(feature = "rsa", feature = "openssl"))] + SomePrivateKey::Rsa(rsa) => rsa.private_key_to_jwk(), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), + } + } +} + impl SigningKey for SomePrivateKey { fn alg(&self) -> &'static str { match self { + #[cfg(feature = "openssl")] SomePrivateKey::Ed25519(ed) => ed.alg(), + #[cfg(feature = "openssl")] SomePrivateKey::Ecdsa(ec) => ec.alg(), + #[cfg(any(feature = "rsa", feature = "openssl"))] SomePrivateKey::Rsa(rsa) => rsa.alg(), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), } } - fn sign(&self, v: &[u8]) -> crate::Result> { + fn sign(&self, _v: &[u8]) -> crate::Result> { match self { - SomePrivateKey::Ed25519(ed) => ed.sign(v), - SomePrivateKey::Ecdsa(ec) => ec.sign(v), - SomePrivateKey::Rsa(rsa) => rsa.sign(v), + #[cfg(feature = "openssl")] + SomePrivateKey::Ed25519(ed) => ed.sign(_v), + #[cfg(feature = "openssl")] + SomePrivateKey::Ecdsa(ec) => ec.sign(_v), + #[cfg(any(feature = "rsa", feature = "openssl"))] + SomePrivateKey::Rsa(rsa) => rsa.sign(_v), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), } } } impl VerificationKey for SomePrivateKey { - fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> crate::Result<()> { + fn verify(&self, _v: &[u8], _sig: &[u8], _alg: &str) -> crate::Result<()> { match self { - SomePrivateKey::Ed25519(ed) => ed.verify(v, sig, alg), - SomePrivateKey::Ecdsa(ec) => ec.verify(v, sig, alg), - SomePrivateKey::Rsa(rsa) => rsa.verify(v, sig, alg), + #[cfg(feature = "openssl")] + SomePrivateKey::Ed25519(ed) => ed.verify(_v, _sig, _alg), + #[cfg(feature = "openssl")] + SomePrivateKey::Ecdsa(ec) => ec.verify(_v, _sig, _alg), + #[cfg(any(feature = "rsa", feature = "openssl"))] + SomePrivateKey::Rsa(rsa) => rsa.verify(_v, _sig, _alg), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), } } } impl VerificationKey for SomePublicKey { - fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> crate::Result<()> { + fn verify(&self, _v: &[u8], _sig: &[u8], _alg: &str) -> crate::Result<()> { match self { - SomePublicKey::Ed25519(ed) => ed.verify(v, sig, alg), - SomePublicKey::Ecdsa(ec) => ec.verify(v, sig, alg), - SomePublicKey::Rsa(rsa) => rsa.verify(v, sig, alg), + #[cfg(feature = "openssl")] + SomePublicKey::Ed25519(ed) => ed.verify(_v, _sig, _alg), + #[cfg(feature = "openssl")] + SomePublicKey::Ecdsa(ec) => ec.verify(_v, _sig, _alg), + #[cfg(any(feature = "rsa", feature = "openssl"))] + SomePublicKey::Rsa(rsa) => rsa.verify(_v, _sig, _alg), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), } } } @@ -217,9 +260,14 @@ impl VerificationKey for SomePublicKey { impl PublicKeyToJwk for SomePublicKey { fn public_key_to_jwk(&self) -> Result { match self { + #[cfg(feature = "openssl")] SomePublicKey::Ed25519(ed) => ed.public_key_to_jwk(), + #[cfg(feature = "openssl")] SomePublicKey::Ecdsa(ec) => ec.public_key_to_jwk(), + #[cfg(any(feature = "rsa", feature = "openssl"))] SomePublicKey::Rsa(rsa) => rsa.public_key_to_jwk(), + #[cfg(not(any(feature = "rsa", feature = "openssl")))] + _ => unreachable!(), } } }