Skip to content
14 changes: 14 additions & 0 deletions common/src/main/java/org/conscrypt/OpenSslSlhDsaPrivateKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.conscrypt;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.security.PrivateKey;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
Expand Down Expand Up @@ -97,4 +100,15 @@ public boolean equals(Object o) {
public int hashCode() {
return Arrays.hashCode(raw);
}

private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
stream.defaultReadObject(); // reads "raw"
if (raw.length != PRIVATE_KEY_SIZE_BYTES) {
throw new IOException("Invalid key size");
}
}

private void writeObject(ObjectOutputStream stream) throws IOException {
stream.defaultWriteObject(); // writes "raw"
}
}
14 changes: 14 additions & 0 deletions common/src/main/java/org/conscrypt/OpenSslSlhDsaPublicKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.conscrypt;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.security.PublicKey;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
Expand Down Expand Up @@ -91,4 +94,15 @@ public int hashCode() {
}
return Arrays.hashCode(raw);
}

private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
stream.defaultReadObject(); // reads "raw"
if (raw.length != PUBLIC_KEY_SIZE_BYTES) {
throw new IOException("Invalid key size");
}
}

private void writeObject(ObjectOutputStream stream) throws IOException {
stream.defaultWriteObject(); // writes "raw"
}
}
131 changes: 131 additions & 0 deletions common/src/test/java/org/conscrypt/SlhDsaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
Expand Down Expand Up @@ -214,4 +220,129 @@ public void testVectors() throws Exception {
assertTrue(verifier.verify(signature));
}
}

@Test
public void serializeAndDeserialize_works() throws Exception {
KeyPairGenerator keyGen =
KeyPairGenerator.getInstance("SLH-DSA-SHA2-128S", conscryptProvider);
KeyPair keyPair = keyGen.generateKeyPair();

ByteArrayOutputStream baos = new ByteArrayOutputStream(16384);
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(keyPair.getPrivate());
oos.writeObject(keyPair.getPublic());
}

ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
ObjectInputStream ois = new ObjectInputStream(bais);
PrivateKey inflatedPrivateKey = (PrivateKey) ois.readObject();
PublicKey inflatedPublicKey = (PublicKey) ois.readObject();

assertEquals(inflatedPrivateKey, keyPair.getPrivate());
assertEquals(inflatedPublicKey, keyPair.getPublic());
}

@Test
public void serializePrivateKey_isEqualToTestVector() throws Exception {
byte[] rawPrivateKey = new byte[64];
KeyFactory keyFactory = KeyFactory.getInstance("SLH-DSA-SHA2-128S", conscryptProvider);
PrivateKey privateKey = keyFactory.generatePrivate(new RawKeySpec(rawPrivateKey));

ByteArrayOutputStream baos = new ByteArrayOutputStream(16384);
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(privateKey);
}

String classNameHex = TestUtils.encodeHex(
privateKey.getClass().getName().getBytes(StandardCharsets.UTF_8));
String expectedHexEncoding = "aced0005737200"
+ Integer.toHexString(privateKey.getClass().getName().length()) + classNameHex
+ "87e8776a4491fecb" // serialVersionUID
+ "0300015b00"
+ "03" // size of raw
+ "726177" // hex("raw")
+ "7400025b427870757200025b42acf317f8060854e00200007870000000"
+ "40" // size of raw key = 64
+ "0000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000000000000000000000000000000000000000000000000000"
+ "78";
assertEquals(expectedHexEncoding, TestUtils.encodeHex(baos.toByteArray()));
}

@Test
public void serializePublicKey_isEqualToTestVector() throws Exception {
byte[] rawPublicKey = new byte[32];
KeyFactory keyFactory = KeyFactory.getInstance("SLH-DSA-SHA2-128S", conscryptProvider);
PublicKey publicKey = keyFactory.generatePublic(new RawKeySpec(rawPublicKey));

ByteArrayOutputStream baos = new ByteArrayOutputStream(16384);
try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(publicKey);
}

String classNameHex = TestUtils.encodeHex(
publicKey.getClass().getName().getBytes(StandardCharsets.UTF_8));
String expectedHexEncoding = "aced0005737200"
+ Integer.toHexString(publicKey.getClass().getName().length()) + classNameHex
+ "4589aa00e279d127" // serialVersionUID
+ "0300015b00"
+ "03" // size of raw
+ "726177" // hex("raw")
+ "7400025b427870757200025b42acf317f8060854e00200007870000000"
+ "20" // size of raw key = 32
+ "0000000000000000000000000000000000000000000000000000000000000000"
+ "78";
assertEquals(expectedHexEncoding, TestUtils.encodeHex(baos.toByteArray()));
}

@Test
public void deserializeInvalidPrivateKey_fails() throws Exception {
KeyFactory keyFactory = KeyFactory.getInstance("SLH-DSA-SHA2-128S", conscryptProvider);
PrivateKey privateKey = keyFactory.generatePrivate(new RawKeySpec(new byte[64]));

String classNameHex = TestUtils.encodeHex(
privateKey.getClass().getName().getBytes(StandardCharsets.UTF_8));
String invalidPrivateKeySerialized = "aced0005737200"
+ Integer.toHexString(privateKey.getClass().getName().length()) + classNameHex
+ "87e8776a4491fecb" // serialVersionUID
+ "0300015b00"
+ "03" // length of string "raw"
+ "726177" // hex("raw")
+ "7400025b427870757200025b42acf317f8060854e00200007870000000"
+ "3f" // length of invalid raw key = 63
+ "0000000000000000000000000000000000000000000000000000000000000000"
+ "00000000000000000000000000000000000000000000000000000000000000"
+ "78";

ByteArrayInputStream bais =
new ByteArrayInputStream(TestUtils.decodeHex(invalidPrivateKeySerialized));
ObjectInputStream ois = new ObjectInputStream(bais);

assertThrows(IOException.class, () -> ois.readObject());
}

@Test
public void deserializeInvalidPublicKey_fails() throws Exception {
KeyFactory keyFactory = KeyFactory.getInstance("SLH-DSA-SHA2-128S", conscryptProvider);
PublicKey publicKey = keyFactory.generatePublic(new RawKeySpec(new byte[32]));

String classNameHex = TestUtils.encodeHex(
publicKey.getClass().getName().getBytes(StandardCharsets.UTF_8));
String invalidPublicKeySerialized = "aced0005737200"
+ Integer.toHexString(publicKey.getClass().getName().length()) + classNameHex
+ "4589aa00e279d127" // serialVersionUID
+ "0300015b00"
+ "03" // length of string "raw"
+ "726177" // hex("raw")
+ "7400025b427870757200025b42acf317f8060854e00200007870000000"
+ "1f" // length of invalid raw key = 31
+ "00000000000000000000000000000000000000000000000000000000000000"
+ "78";

ByteArrayInputStream bais =
new ByteArrayInputStream(TestUtils.decodeHex(invalidPublicKeySerialized));
ObjectInputStream ois = new ObjectInputStream(bais);

assertThrows(IOException.class, () -> ois.readObject());
}
}
Loading