Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions common/src/main/java/org/conscrypt/OpenSslMlDsaKeyFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,34 @@ protected <T extends KeySpec> T engineGetKeySpec(Key key, Class<T> keySpec)

@Override
protected Key engineTranslateKey(Key key) throws InvalidKeyException {
if (key == null) {
throw new InvalidKeyException("key == null");
}
if ((key instanceof OpenSslMlDsaPublicKey) || (key instanceof OpenSslMlDsaPrivateKey)) {
if ((key instanceof OpenSslMlDsaPublicKey)) {
OpenSslMlDsaPublicKey conscryptKey = (OpenSslMlDsaPublicKey) key;
if (!supportsAlgorithm(conscryptKey.getMlDsaAlgorithm())) {
throw new InvalidKeyException("Key algorithm mismatch");
}
return conscryptKey;
} else if (key instanceof OpenSslMlDsaPrivateKey) {
OpenSslMlDsaPrivateKey conscryptKey = (OpenSslMlDsaPrivateKey) key;
if (!supportsAlgorithm(conscryptKey.getMlDsaAlgorithm())) {
throw new InvalidKeyException("Key algorithm mismatch");
}
return key;
} else if ((key instanceof PrivateKey) && key.getFormat().equals("PKCS#8")) {
byte[] encoded = key.getEncoded();
try {
return engineGeneratePrivate(new PKCS8EncodedKeySpec(encoded));
} catch (InvalidKeySpecException e) {
throw new InvalidKeyException(e);
}
} else if ((key instanceof PublicKey) && key.getFormat().equals("X.509")) {
byte[] encoded = key.getEncoded();
try {
return engineGeneratePublic(new X509EncodedKeySpec(encoded));
} catch (InvalidKeySpecException e) {
throw new InvalidKeyException(e);
}
} else {
throw new InvalidKeyException("Unable to translate key into ML-DSA key");
}
throw new InvalidKeyException(
"Key must be OpenSslMlDsaPublicKey or OpenSslMlDsaPrivateKey");
}
}
82 changes: 82 additions & 0 deletions common/src/test/java/org/conscrypt/MlDsaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,54 @@ public void generateFromInvalidRawKey_throws() throws Exception {
}
}

/** Helper class to test KeyFactory.translateKey. */
static class TestPublicKey implements PublicKey {
public TestPublicKey(byte[] x509encoded) {
this.x509encoded = x509encoded;
}

private final byte[] x509encoded;

@Override
public String getAlgorithm() {
return "ML-DSA";
}

@Override
public String getFormat() {
return "X.509";
}

@Override
public byte[] getEncoded() {
return x509encoded;
}
}

/** Helper class to test KeyFactory.translateKey. */
static class TestPrivateKey implements PrivateKey {
public TestPrivateKey(byte[] pkcs8encoded) {
this.pkcs8encoded = pkcs8encoded;
}

private final byte[] pkcs8encoded;

@Override
public String getAlgorithm() {
return "ML-DSA";
}

@Override
public String getFormat() {
return "PKCS#8";
}

@Override
public byte[] getEncoded() {
return pkcs8encoded;
}
}

@Test
public void mldsa65KeyPair_x509AndPkcs8() throws Exception {
KeyPairGenerator keyGen = KeyPairGenerator.getInstance("ML-DSA-65", conscryptProvider);
Expand Down Expand Up @@ -324,6 +372,13 @@ public void mldsa65KeyPair_x509AndPkcs8() throws Exception {

assertEquals(privateKey, keyPair.getPrivate());
assertEquals(publicKey, keyPair.getPublic());

assertEquals(keyPair.getPrivate(), keyFactory.translateKey(keyPair.getPrivate()));
assertEquals(keyPair.getPrivate(),
keyFactory.translateKey(new TestPrivateKey(keyPair.getPrivate().getEncoded())));
assertEquals(keyPair.getPublic(), keyFactory.translateKey(keyPair.getPublic()));
assertEquals(keyPair.getPublic(),
keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded())));
}

KeyFactory keyFactory = KeyFactory.getInstance("ML-DSA-87", conscryptProvider);
Expand All @@ -337,6 +392,16 @@ public void mldsa65KeyPair_x509AndPkcs8() throws Exception {
new RawKeySpec(keyPair.getPrivate().getEncoded())));
assertThrows(InvalidKeySpecException.class,
() -> keyFactory.generatePublic(new RawKeySpec(keyPair.getPublic().getEncoded())));

assertThrows(
InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPrivate()));
assertThrows(InvalidKeyException.class,
()
-> keyFactory.translateKey(
new TestPrivateKey(keyPair.getPrivate().getEncoded())));
assertThrows(InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPublic()));
assertThrows(InvalidKeyException.class,
() -> keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded())));
}

@Test
Expand Down Expand Up @@ -367,6 +432,13 @@ public void mldsa87KeyPair_x509AndPkcs8() throws Exception {

assertEquals(privateKey, keyPair.getPrivate());
assertEquals(publicKey, keyPair.getPublic());

assertEquals(keyPair.getPrivate(), keyFactory.translateKey(keyPair.getPrivate()));
assertEquals(keyPair.getPrivate(),
keyFactory.translateKey(new TestPrivateKey(keyPair.getPrivate().getEncoded())));
assertEquals(keyPair.getPublic(), keyFactory.translateKey(keyPair.getPublic()));
assertEquals(keyPair.getPublic(),
keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded())));
}

KeyFactory keyFactory = KeyFactory.getInstance("ML-DSA-65", conscryptProvider);
Expand All @@ -380,6 +452,16 @@ public void mldsa87KeyPair_x509AndPkcs8() throws Exception {
new RawKeySpec(keyPair.getPrivate().getEncoded())));
assertThrows(InvalidKeySpecException.class,
() -> keyFactory.generatePublic(new RawKeySpec(keyPair.getPublic().getEncoded())));

assertThrows(
InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPrivate()));
assertThrows(InvalidKeyException.class,
()
-> keyFactory.translateKey(
new TestPrivateKey(keyPair.getPrivate().getEncoded())));
assertThrows(InvalidKeyException.class, () -> keyFactory.translateKey(keyPair.getPublic()));
assertThrows(InvalidKeyException.class,
() -> keyFactory.translateKey(new TestPublicKey(keyPair.getPublic().getEncoded())));
}

@Test
Expand Down
Loading