diff --git a/java/shared/src/main/java/com/squareup/subzero/shared/SubzeroUtils.java b/java/shared/src/main/java/com/squareup/subzero/shared/SubzeroUtils.java index 1fca1941..e8acba78 100644 --- a/java/shared/src/main/java/com/squareup/subzero/shared/SubzeroUtils.java +++ b/java/shared/src/main/java/com/squareup/subzero/shared/SubzeroUtils.java @@ -2,6 +2,7 @@ import com.google.common.collect.Lists; import com.google.common.io.BaseEncoding; +import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.zxing.BarcodeFormat; import com.google.zxing.common.BitMatrix; @@ -18,7 +19,9 @@ import java.security.MessageDigest; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Set; import org.bitcoinj.core.Address; import org.bitcoinj.core.Coin; import org.bitcoinj.core.ECKey; @@ -130,16 +133,19 @@ protected static List validateAndSort(List pubkeys, byte[] hash, // This needs to be the same sort that ScriptBuilder.createRedeemScript does. pubkeys.sort(ECKey.PUBKEY_COMPARATOR); List sortedSigs = new ArrayList<>(); + Set seenSigs = new HashSet<>(); // If this check fails, we've probably got invalid signatures that would fail below, or when // broadcast. However, this lets us distinguish between invalid signatures and signatures over // the wrong data. Useful primarily for debugging when making changes to either piece of code. + ByteString expectedHash = ByteString.copyFrom(hash); for (Signature sig : signatures) { if (sig.hasHash()) { - if (!Arrays.equals(sig.getHash().toByteArray(), hash)) { + ByteString sigHash = sig.getHash(); + if (!sigHash.equals(expectedHash)) { throw new RuntimeException(format( "Our calculated hash does not match the HSM provided sig: %s != %s", - Hex.toHexString(sig.getHash().toByteArray()), Hex.toHexString(hash))); + sigHash.toStringUtf8(), expectedHash.toStringUtf8())); } } } @@ -149,8 +155,14 @@ protected static List validateAndSort(List pubkeys, byte[] hash, for(ECKey pubkey: pubkeys) { for (Signature sig : signatures) { try { - if (pubkey.verify(hash, sig.getDer().toByteArray())) { - sortedSigs.add(sig.getDer().toByteArray()); + byte[] sigDerBytes = sig.getDer().toByteArray(); + if (pubkey.verify(hash, sigDerBytes)) { + ByteString sigDerByteString = ByteString.copyFrom(sigDerBytes); + if (!seenSigs.contains(sigDerByteString)) { + // Only add the signature if it is unique + sortedSigs.add(sigDerBytes); + seenSigs.add(sigDerByteString); + } } } catch (SignatureDecodeException e) { // keep going, we'll throw a RuntimeException later if we don't find the right number of diff --git a/java/shared/src/test/java/com/squareup/subzero/shared/SubzeroUtilsTest.java b/java/shared/src/test/java/com/squareup/subzero/shared/SubzeroUtilsTest.java index 77de1ff6..95345049 100644 --- a/java/shared/src/test/java/com/squareup/subzero/shared/SubzeroUtilsTest.java +++ b/java/shared/src/test/java/com/squareup/subzero/shared/SubzeroUtilsTest.java @@ -6,15 +6,22 @@ import com.squareup.subzero.proto.service.Common.EncryptedMasterSeed; import com.squareup.subzero.proto.service.Common.EncryptedPubKey; import com.squareup.subzero.proto.service.Common.Path; +import com.squareup.subzero.proto.service.Common.Signature; import com.squareup.subzero.proto.service.Common.TxInput; import com.squareup.subzero.proto.service.Common.TxOutput; import com.squareup.subzero.proto.service.Internal.InternalCommandRequest; import com.squareup.subzero.proto.service.Service.CommandRequest; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import org.bitcoinj.core.Address; +import org.bitcoinj.core.ECKey; +import org.bitcoinj.core.Sha256Hash; import org.bitcoinj.core.VerificationException; import org.bitcoinj.crypto.DeterministicKey; import org.bitcoinj.params.TestNet3Params; @@ -494,6 +501,135 @@ public void testValidateSignTxCommandRequest() { assertTrue(e3.getMessage().contains(ERROR_OUTPUTS_COUNT)); } + @Test + public void testValidateAndSortValidSignatures() { + List keys = Arrays.asList(new ECKey(), new ECKey(), new ECKey()); + keys.sort(ECKey.PUBKEY_COMPARATOR); + + byte[] hash = sha256Hash("mock input"); + List signatures = Arrays.asList( + createValidSignature(hash, keys.get(0)), + createValidSignature(hash, keys.get(1)) + ); + + List sortedSigs = SubzeroUtils.validateAndSort(keys, hash, signatures); + + assertEquals(Constants.MULTISIG_THRESHOLD, sortedSigs.size()); + assertArrayEquals(signatures.get(0).getDer().toByteArray(), sortedSigs.get(0)); + assertArrayEquals(signatures.get(1).getDer().toByteArray(), sortedSigs.get(1)); + } + + @Test + public void testValidateAndSortInvalidSignatures() { + List keys = Arrays.asList(new ECKey(), new ECKey()); + byte[] hash = sha256Hash("mock input"); + List signatures = Arrays.asList( + createInvalidEmptySignature(), + createInvalidEmptySignature() + ); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, signatures) + ); + + assertTrue(exception.getMessage().contains("Our calculated hash does not match the HSM provided sig")); + } + + @Test + public void testValidateAndSortValidSignaturesFlippedBit() { + List keys = Arrays.asList(new ECKey(), new ECKey(), new ECKey()); + keys.sort(ECKey.PUBKEY_COMPARATOR); + + byte[] hash = sha256Hash("mock input"); + List signatures = Arrays.asList( + createValidSignature(hash, keys.get(0)), + createValidSignature(hash, keys.get(1)) + ); + + List sortedSigs = SubzeroUtils.validateAndSort(keys, hash, signatures); + + assertEquals(Constants.MULTISIG_THRESHOLD, sortedSigs.size()); + assertArrayEquals(signatures.get(0).getDer().toByteArray(), sortedSigs.get(0)); + assertArrayEquals(signatures.get(1).getDer().toByteArray(), sortedSigs.get(1)); + + // Flip a bit in one signature and check for failure + byte[] corruptedSigBytes = signatures.get(0).getDer().toByteArray(); + corruptedSigBytes[0] ^= 0x01; + + // Use the same hash but use the corrupted signature + Signature corruptedSignature = Signature.newBuilder() + .setHash(signatures.get(0).getHash()) + .setDer(ByteString.copyFrom(corruptedSigBytes)) + .build(); + + List corruptedSignatures = Arrays.asList(corruptedSignature, signatures.get(1)); + + // Validate and expect a failure + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, corruptedSignatures) + ); + + assertTrue(exception.getMessage().contains("Failed validating signatures")); + } + + @Test + public void testValidateAndSortSignatureWithWrongHash() { + List keys = Arrays.asList(new ECKey(), new ECKey()); + byte[] hash = sha256Hash("input1"); + byte[] wrongHash = sha256Hash("input2"); + List signatures = Arrays.asList(createValidSignature(wrongHash, keys.get(0))); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, signatures) + ); + assertTrue(exception.getMessage().contains("Our calculated hash does not match the HSM provided sig")); + } + + @Test + public void testValidateAndSortDuplicateSignatures() { + List keys = Arrays.asList(new ECKey(), new ECKey()); + byte[] hash = sha256Hash("input"); + Signature validSig = createValidSignature(hash, keys.get(0)); + List signatures = Arrays.asList(validSig, validSig); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, signatures) + ); + assertTrue(exception.getMessage().contains("Failed validating signatures")); + } + + @Test + public void testValidateAndSortDuplicateKey() { + List keys = Arrays.asList(new ECKey(), new ECKey()); + byte[] hash = sha256Hash("input"); + + // Sign twice using the same key rather than using a copied reference to the same signature + // as in testValidateAndSortDuplicateSignatures + Signature validSig1 = createValidSignature(hash, keys.get(0)); + Signature validSig2 = createValidSignature(hash, keys.get(0)); + List signatures = Arrays.asList(validSig1, validSig2); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, signatures) + ); + assertTrue(exception.getMessage().contains("Failed validating signatures")); + } + + @Test + public void testValidateAndSortSignaturesBelowThreshold() { + List keys = Arrays.asList(new ECKey(), new ECKey()); + byte[] hash = sha256Hash("input"); + // Only supply 1 sig + List signatures = Arrays.asList( + createValidSignature(hash, keys.get(0)) + ); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> + SubzeroUtils.validateAndSort(keys, hash, signatures) + ); + assertTrue(exception.getMessage().contains("Failed validating signatures")); + } + private TxInput testInput() { return testInput(1000L, ByteString.copyFromUtf8("test prev hash")); } @@ -526,4 +662,30 @@ private TxOutput testOutput(long amount, Destination destination) { .setIndex(2)) .build(); } + + private Signature createValidSignature(byte[] hash, ECKey ecKey) { + Sha256Hash sha256Hash = Sha256Hash.wrap(hash); + ECKey.ECDSASignature ecdsaSignature = ecKey.sign(sha256Hash); + byte[] signatureBytes = ecdsaSignature.encodeToDER(); + return Signature.newBuilder() + .setHash(ByteString.copyFrom(hash)) + .setDer(ByteString.copyFrom(signatureBytes)) + .build(); + } + + private byte[] sha256Hash(String inputData) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + return digest.digest(inputData.getBytes(StandardCharsets.UTF_8)); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-256 not found", e); + } + } + + private Signature createInvalidEmptySignature() { + return Signature.newBuilder() + .setHash(ByteString.copyFrom(new byte[]{0})) + .setDer(ByteString.copyFrom(new byte[]{0})) + .build(); + } }