Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ml-dsa/benches/ml_dsa.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use criterion::{Criterion, criterion_group, criterion_main};
use getrandom::SysRng;
use hybrid_array::{Array, ArraySize};
use ml_dsa::{B32, KeyGen, MlDsa65, Signature, SigningKey, VerifyingKey};
use ml_dsa::{B32, KeyGen, MlDsa65, Signature, SigningKey, VerifyingKey, signature::Keypair};
use rand_core::{CryptoRng, UnwrapErr};

pub fn rand<L: ArraySize, R: CryptoRng + ?Sized>(rng: &mut R) -> Array<u8, L> {
Expand Down
128 changes: 57 additions & 71 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use crate::hint::Hint;
use crate::ntt::{Ntt, NttInverse};
use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
use core::convert::{AsRef, TryFrom, TryInto};
use core::convert::{TryFrom, TryInto};
use core::fmt;
use hybrid_array::{
Array,
Expand Down Expand Up @@ -184,31 +184,23 @@ impl AsMut<Shake256> for MuBuilder {
}
}

/// An ML-DSA key pair
pub struct KeyPair<P: MlDsaParams> {
/// An ML-DSA signing key initialized through a seed
pub struct SeededSigningKey<P: MlDsaParams> {
/// The signing key of the key pair
signing_key: SigningKey<P>,

/// The verifying key of the key pair
verifying_key: VerifyingKey<P>,

/// The seed this signing key was derived from
seed: B32,
}

impl<P: MlDsaParams> KeyPair<P> {
impl<P: MlDsaParams> SeededSigningKey<P> {
/// The signing key of the key pair
pub fn signing_key(&self) -> &SigningKey<P> {
&self.signing_key
}

/// The verifying key of the key pair
pub fn verifying_key(&self) -> &VerifyingKey<P> {
&self.verifying_key
}

/// Serialize the [`Seed`] value: 32-bytes which can be used to reconstruct the
/// [`KeyPair`].
/// [`SeededSigningKey`].
///
/// # ⚠️ Warning!
///
Expand All @@ -219,43 +211,38 @@ impl<P: MlDsaParams> KeyPair<P> {
}
}

impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
fn as_ref(&self) -> &VerifyingKey<P> {
&self.verifying_key
}
}

impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
impl<P: MlDsaParams> fmt::Debug for SeededSigningKey<P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("KeyPair")
.field("verifying_key", &self.verifying_key)
.finish_non_exhaustive()
f.debug_struct("SeededSigningKey").finish_non_exhaustive()
}
}

impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
impl<P: MlDsaParams> signature::Keypair for SeededSigningKey<P> {
type VerifyingKey = VerifyingKey<P>;
fn verifying_key(&self) -> VerifyingKey<P> {
self.signing_key.verifying_key()
}
}

/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
/// The `Signer` implementation for `SeededSigningKey` uses the optional deterministic variant of ML-DSA, and
/// only supports signing with an empty context string.
impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
impl<P: MlDsaParams> Signer<Signature<P>> for SeededSigningKey<P> {
fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
self.try_multipart_sign(&[msg])
}
}

/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
/// The `Signer` implementation for `SeededSigningKey` uses the optional deterministic variant of ML-DSA, and
/// only supports signing with an empty context string.
impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SeededSigningKey<P> {
fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
self.signing_key.raw_sign_deterministic(msg, &[])
}
}

/// The `DigestSigner` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA
/// The `DigestSigner` implementation for `SeededSigningKey` uses the optional deterministic variant of ML-DSA
/// with a pre-computed μ, and only supports signing with an empty context string.
impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SeededSigningKey<P> {
fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
&self,
f: F,
Expand Down Expand Up @@ -547,9 +534,9 @@ impl<P: MlDsaParams> SigningKey<P> {

/// DEPRECATED: encode the key in a fixed-size byte array.
///
/// Note that this form is deprecated in practice; prefer to use [`KeyPair::to_seed`].
/// Note that this form is deprecated in practice; prefer to use [`SeededSigningKey:to_seed`].
// Algorithm 24 skEncode
#[deprecated(since = "0.1.0", note = "use `KeyPair::to_seed` instead")]
#[deprecated(since = "0.1.0", note = "use `SeededSigningKey::to_seed` instead")]
pub fn to_expanded(&self) -> ExpandedSigningKey<P>
where
P: MlDsaParams,
Expand Down Expand Up @@ -602,7 +589,7 @@ impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
}
}

/// The `KeyPair` implementation for `SigningKey` allows to derive a `VerifyingKey` from
/// The [`signature::KeyPair`] implementation for `SigningKey` allows to derive a `VerifyingKey` from
/// a bare `SigningKey` (even in the absence of the original seed).
impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
type VerifyingKey = VerifyingKey<P>;
Expand Down Expand Up @@ -902,12 +889,12 @@ impl<P> KeyGen for P
where
P: MlDsaParams,
{
type KeyPair = KeyPair<P>;
type KeyPair = SeededSigningKey<P>;

/// Generate a signing key pair from the specified RNG
// Algorithm 1 ML-DSA.KeyGen()
#[cfg(feature = "rand_core")]
fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> SeededSigningKey<P> {
let mut xi = B32::default();
rng.fill_bytes(&mut xi);
Self::from_seed(&xi)
Expand All @@ -917,7 +904,7 @@ where
///
/// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
// Algorithm 6 ML-DSA.KeyGen_internal
fn from_seed(xi: &Seed) -> KeyPair<P>
fn from_seed(xi: &Seed) -> SeededSigningKey<P>
where
P: MlDsaParams,
{
Expand All @@ -943,12 +930,12 @@ where
// Compress and encode
let (t1, t0) = t.power2round();

let verifying_key = VerifyingKey::new(rho, t1, A_hat.clone(), None);
let signing_key = SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, A_hat);
let enc = VerifyingKey::<P>::encode_internal(&rho, &t1);
let tr: B64 = H::default().absorb(&enc).squeeze_new();
let signing_key = SigningKey::new(rho, K, tr, s1, s2, t0, A_hat);

KeyPair {
SeededSigningKey {
signing_key,
verifying_key,
seed: xi.clone(),
}
}
Expand All @@ -958,6 +945,7 @@ where
mod test {
use super::*;
use crate::param::*;
use signature::Keypair;

#[test]
fn output_sizes() {
Expand All @@ -983,11 +971,11 @@ mod test {
P: MlDsaParams + PartialEq,
{
let seed = Array::default();
let kp = P::from_seed(&seed);
assert_eq!(kp.to_seed(), seed);
let ssk = P::from_seed(&seed);
assert_eq!(ssk.to_seed(), seed);

let sk = kp.signing_key;
let vk = kp.verifying_key;
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();

let vk_bytes = vk.encode();
let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
Expand All @@ -997,7 +985,7 @@ mod test {
{
let sk_bytes = sk.to_expanded();
let sk2 = SigningKey::<P>::from_expanded(&sk_bytes);
assert!(sk == sk2);
assert!(sk == &sk2);

let M = b"Hello world";
let rnd = Array([0u8; 32]);
Expand All @@ -1019,9 +1007,9 @@ mod test {
where
P: MlDsaParams + PartialEq,
{
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let ssk = P::from_seed(&Array::default());
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();
let vk_derived = sk.verifying_key();

assert!(vk == vk_derived);
Expand All @@ -1038,9 +1026,9 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let ssk = P::from_seed(&Array::default());
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();

let M = b"Hello world";
let rnd = Array([0u8; 32]);
Expand All @@ -1062,9 +1050,9 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let ssk = P::from_seed(&Array::default());
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();

let M = b"Hello world";
let rnd = Array([0u8; 32]);
Expand All @@ -1084,9 +1072,9 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let ssk = P::from_seed(&Array::default());
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();

let M = b"Hello world";
let rnd = Array([0u8; 32]);
Expand All @@ -1106,9 +1094,9 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let ssk = P::from_seed(&Array::default());
let sk = &ssk.signing_key;
let vk = ssk.verifying_key();

let M = b"Hello world";
let rnd = Array([0u8; 32]);
Expand All @@ -1129,11 +1117,9 @@ mod test {
P: MlDsaParams,
{
let seed = Seed::default();
let kp1 = P::from_seed(&seed);
let ssk = P::from_seed(&seed);
let sk1 = SigningKey::<P>::from_seed(&seed);
let vk1 = sk1.verifying_key();
assert_eq!(kp1.signing_key, sk1);
assert_eq!(kp1.verifying_key, vk1);
assert_eq!(ssk.signing_key, sk1);
}
assert_from_seed_equality::<MlDsa44>();
assert_from_seed_equality::<MlDsa65>();
Expand Down Expand Up @@ -1194,9 +1180,9 @@ mod test {
#[test]
fn context_length_validation() {
fn test_ctx_length<P: MlDsaParams>() {
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key();
let vk = kp.verifying_key();
let ssk = P::from_seed(&Array::default());
let sk = ssk.signing_key();
let vk = ssk.verifying_key();

let msg = b"Hello world";
let long_ctx = [0u8; 256];
Expand All @@ -1217,16 +1203,16 @@ mod test {
fn derived_verifying_key_validates_signatures() {
fn test_derived_vk<P: MlDsaParams>() {
let seed = Array([42u8; 32]);
let kp = P::from_seed(&seed);
let sk = kp.signing_key();
let ssk = P::from_seed(&seed);
let sk = ssk.signing_key();
let derived_vk = sk.verifying_key();

let msg = b"Test message for derived key";
let rnd = Array([0u8; 32]);
let sig = sk.sign_internal(&[msg], &rnd);

assert!(derived_vk.verify_internal(msg, &sig));
assert_eq!(derived_vk.encode(), kp.verifying_key().encode());
assert_eq!(derived_vk.encode(), ssk.verifying_key().encode());
}
test_derived_vk::<MlDsa44>();
test_derived_vk::<MlDsa65>();
Expand All @@ -1244,7 +1230,7 @@ mod test {

let mut kp_debug = alloc::string::String::new();
write!(&mut kp_debug, "{:?}", kp).unwrap();
assert!(kp_debug.contains("KeyPair"));
assert!(kp_debug.contains("SeededSigningKey"));

let mut sk_debug = alloc::string::String::new();
write!(&mut sk_debug, "{:?}", kp.signing_key()).unwrap();
Expand Down
12 changes: 6 additions & 6 deletions ml-dsa/src/pkcs8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#![cfg(feature = "pkcs8")]

use crate::{
EncodedVerifyingKey, KeyGen, KeyPair, MlDsa44, MlDsa65, MlDsa87, MlDsaParams, Signature,
SigningKey, VerifyingKey,
EncodedVerifyingKey, KeyGen, MlDsa44, MlDsa65, MlDsa87, MlDsaParams, SeededSigningKey,
Signature, SigningKey, VerifyingKey,
};
use ::pkcs8::{
AlgorithmIdentifierRef, PrivateKeyInfoRef,
Expand Down Expand Up @@ -79,7 +79,7 @@ impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
}
}

impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
impl<P> SignatureAlgorithmIdentifier for SeededSigningKey<P>
where
P: MlDsaParams,
P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
Expand All @@ -90,7 +90,7 @@ where
Signature::<P>::ALGORITHM_IDENTIFIER;
}

impl<P> TryFrom<PrivateKeyInfoRef<'_>> for KeyPair<P>
impl<P> TryFrom<PrivateKeyInfoRef<'_>> for SeededSigningKey<P>
where
P: MlDsaParams,
P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
Expand All @@ -117,7 +117,7 @@ where
}

#[cfg(feature = "alloc")]
impl<P> EncodePrivateKey for KeyPair<P>
impl<P> EncodePrivateKey for SeededSigningKey<P>
where
P: MlDsaParams,
P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
Expand Down Expand Up @@ -155,7 +155,7 @@ where
type Error = ::pkcs8::Error;

fn try_from(private_key_info: ::pkcs8::PrivateKeyInfoRef<'_>) -> ::pkcs8::Result<Self> {
let keypair = KeyPair::try_from(private_key_info)?;
let keypair = SeededSigningKey::try_from(private_key_info)?;

Ok(keypair.signing_key)
}
Expand Down
7 changes: 4 additions & 3 deletions ml-dsa/tests/key-gen.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ml_dsa::*;

use hybrid_array::Array;
use signature::Keypair;
use std::{fs::read_to_string, path::PathBuf};

#[test]
Expand Down Expand Up @@ -31,9 +32,9 @@ fn verify<P: MlDsaParams>(tc: &acvp::TestCase) {
let vk_bytes = EncodedVerifyingKey::<P>::try_from(tc.pk.as_slice()).unwrap();
let sk_bytes = ExpandedSigningKey::<P>::try_from(tc.sk.as_slice()).unwrap();

let kp = P::from_seed(&seed);
let sk = kp.signing_key().clone();
let vk = kp.verifying_key().clone();
let ssk = P::from_seed(&seed);
let sk = ssk.signing_key().clone();
let vk = ssk.verifying_key().clone();

assert_eq!(vk.encode(), vk_bytes);
assert!(vk == VerifyingKey::<P>::decode(&vk_bytes));
Expand Down
Loading
Loading