From d4f0e9e7dc15ccfdd7d7c27877e1ea5a583faffe Mon Sep 17 00:00:00 2001 From: juergw Date: Thu, 27 Nov 2025 08:40:59 +0000 Subject: [PATCH] Implement translateKey for ML-DSA. Because PKCS8 and X509 encodings are now supported, we can now implement translateKey by serializing and deserializing the key. --- .../org/conscrypt/OpenSslMlDsaKeyFactory.java | 33 ++++++-- .../test/java/org/conscrypt/MlDsaTest.java | 82 +++++++++++++++++++ 2 files changed, 109 insertions(+), 6 deletions(-) diff --git a/common/src/main/java/org/conscrypt/OpenSslMlDsaKeyFactory.java b/common/src/main/java/org/conscrypt/OpenSslMlDsaKeyFactory.java index 47485e01d..158ee8cce 100644 --- a/common/src/main/java/org/conscrypt/OpenSslMlDsaKeyFactory.java +++ b/common/src/main/java/org/conscrypt/OpenSslMlDsaKeyFactory.java @@ -251,13 +251,34 @@ protected T engineGetKeySpec(Key key, Class keySpec) @Override protected Key engineTranslateKey(Key key) throws InvalidKeyException { - if (key == null) { - throw new InvalidKeyException("key == null"); - } - if ((key instanceof OpenSslMlDsaPublicKey) || (key instanceof OpenSslMlDsaPrivateKey)) { + if ((key instanceof OpenSslMlDsaPublicKey)) { + OpenSslMlDsaPublicKey conscryptKey = (OpenSslMlDsaPublicKey) key; + if (!supportsAlgorithm(conscryptKey.getMlDsaAlgorithm())) { + throw new InvalidKeyException("Key algorithm mismatch"); + } + return conscryptKey; + } else if (key instanceof OpenSslMlDsaPrivateKey) { + OpenSslMlDsaPrivateKey conscryptKey = (OpenSslMlDsaPrivateKey) key; + if (!supportsAlgorithm(conscryptKey.getMlDsaAlgorithm())) { + throw new InvalidKeyException("Key algorithm mismatch"); + } return key; + } else if ((key instanceof PrivateKey) && key.getFormat().equals("PKCS#8")) { + byte[] encoded = key.getEncoded(); + try { + return engineGeneratePrivate(new PKCS8EncodedKeySpec(encoded)); + } catch (InvalidKeySpecException e) { + throw new InvalidKeyException(e); + } + } else if ((key instanceof PublicKey) && key.getFormat().equals("X.509")) { + byte[] encoded = key.getEncoded(); + try { + return engineGeneratePublic(new X509EncodedKeySpec(encoded)); + } catch (InvalidKeySpecException e) { + throw new InvalidKeyException(e); + } + } else { + throw new InvalidKeyException("Unable to translate key into ML-DSA key"); } - throw new InvalidKeyException( - "Key must be OpenSslMlDsaPublicKey or OpenSslMlDsaPrivateKey"); } } diff --git a/common/src/test/java/org/conscrypt/MlDsaTest.java b/common/src/test/java/org/conscrypt/MlDsaTest.java index 917632006..ba24e8232 100644 --- a/common/src/test/java/org/conscrypt/MlDsaTest.java +++ b/common/src/test/java/org/conscrypt/MlDsaTest.java @@ -296,6 +296,54 @@ public void generateFromInvalidRawKey_throws() throws Exception { } } + /** Helper class to test KeyFactory.translateKey. */ + static class TestPublicKey implements PublicKey { + public TestPublicKey(byte[] x509encoded) { + this.x509encoded = x509encoded; + } + + private final byte[] x509encoded; + + @Override + public String getAlgorithm() { + return "ML-DSA"; + } + + @Override + public String getFormat() { + return "X.509"; + } + + @Override + public byte[] getEncoded() { + return x509encoded; + } + } + + /** Helper class to test KeyFactory.translateKey. */ + static class TestPrivateKey implements PrivateKey { + public TestPrivateKey(byte[] pkcs8encoded) { + this.pkcs8encoded = pkcs8encoded; + } + + private final byte[] pkcs8encoded; + + @Override + public String getAlgorithm() { + return "ML-DSA"; + } + + @Override + public String getFormat() { + return "PKCS#8"; + } + + @Override + public byte[] getEncoded() { + return pkcs8encoded; + } + } + @Test public void mldsa65KeyPair_x509AndPkcs8() throws Exception { KeyPairGenerator keyGen = KeyPairGenerator.getInstance("ML-DSA-65", conscryptProvider); @@ -324,6 +372,13 @@ public void mldsa65KeyPair_x509AndPkcs8() throws Exception { assertEquals(privateKey, keyPair.getPrivate()); assertEquals(publicKey, keyPair.getPublic()); + + assertEquals(keyPair.getPrivate(), keyFactory.translateKey(keyPair.getPrivate())); + assertEquals(keyPair.getPrivate(), + keyFactory.translateKey(new TestPrivateKey(keyPair.getPrivate().getEncoded()))); + assertEquals(keyPair.getPublic(), keyFactory.translateKey(keyPair.getPublic())); + assertEquals(keyPair.getPublic(), + keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded()))); } KeyFactory keyFactory = KeyFactory.getInstance("ML-DSA-87", conscryptProvider); @@ -337,6 +392,16 @@ public void mldsa65KeyPair_x509AndPkcs8() throws Exception { new RawKeySpec(keyPair.getPrivate().getEncoded()))); assertThrows(InvalidKeySpecException.class, () -> keyFactory.generatePublic(new RawKeySpec(keyPair.getPublic().getEncoded()))); + + assertThrows( + InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPrivate())); + assertThrows(InvalidKeyException.class, + () + -> keyFactory.translateKey( + new TestPrivateKey(keyPair.getPrivate().getEncoded()))); + assertThrows(InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPublic())); + assertThrows(InvalidKeyException.class, + () -> keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded()))); } @Test @@ -367,6 +432,13 @@ public void mldsa87KeyPair_x509AndPkcs8() throws Exception { assertEquals(privateKey, keyPair.getPrivate()); assertEquals(publicKey, keyPair.getPublic()); + + assertEquals(keyPair.getPrivate(), keyFactory.translateKey(keyPair.getPrivate())); + assertEquals(keyPair.getPrivate(), + keyFactory.translateKey(new TestPrivateKey(keyPair.getPrivate().getEncoded()))); + assertEquals(keyPair.getPublic(), keyFactory.translateKey(keyPair.getPublic())); + assertEquals(keyPair.getPublic(), + keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded()))); } KeyFactory keyFactory = KeyFactory.getInstance("ML-DSA-65", conscryptProvider); @@ -380,6 +452,16 @@ public void mldsa87KeyPair_x509AndPkcs8() throws Exception { new RawKeySpec(keyPair.getPrivate().getEncoded()))); assertThrows(InvalidKeySpecException.class, () -> keyFactory.generatePublic(new RawKeySpec(keyPair.getPublic().getEncoded()))); + + assertThrows( + InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPrivate())); + assertThrows(InvalidKeyException.class, + () + -> keyFactory.translateKey( + new TestPrivateKey(keyPair.getPrivate().getEncoded()))); + assertThrows(InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPublic())); + assertThrows(InvalidKeyException.class, + () -> keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded()))); } @Test