diff --git a/common/src/main/java/org/conscrypt/OpenSSLAeadCipher.java b/common/src/main/java/org/conscrypt/OpenSSLAeadCipher.java index 1b201cc2d..843ab5362 100644 --- a/common/src/main/java/org/conscrypt/OpenSSLAeadCipher.java +++ b/common/src/main/java/org/conscrypt/OpenSSLAeadCipher.java @@ -241,7 +241,7 @@ protected int engineDoFinal(ByteBuffer input, ByteBuffer output) throws ShortBuf throw new IllegalArgumentException("Cannot write to Read Only ByteBuffer"); } if (bufCount != 0) { - return super.engineDoFinal(input, output);// traditional case + return super.engineDoFinal(input, output); // traditional case } int bytesWritten; if (!input.isDirect()) { @@ -268,20 +268,71 @@ protected int engineDoFinal(ByteBuffer input, ByteBuffer output) throws ShortBuf return bytesWritten; } + @Override + protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) + throws IllegalBlockSizeException, BadPaddingException { + final int maximumLen = getOutputSizeForFinal(inputLen); + /* Assume that we'll output exactly on a byte boundary. */ + final byte[] output = new byte[maximumLen]; + + int bytesWritten; + if (inputLen > 0) { + try { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, 0, maximumLen); + } catch (ShortBufferException e) { + /* This should not happen since we sized our own buffer. */ + throw new RuntimeException("our calculated buffer was too small", e); + } + } else { + bytesWritten = 0; + } + + try { + bytesWritten += doFinalInternal(output, bytesWritten, maximumLen - bytesWritten); + } catch (ShortBufferException e) { + /* This should not happen since we sized our own buffer. */ + throw new RuntimeException("our calculated buffer was too small", e); + } + + if (bytesWritten == output.length) { + return output; + } else if (bytesWritten == 0) { + return EmptyArray.BYTE; + } else { + return Arrays.copyOfRange(output, 0, bytesWritten); + } + } + @Override protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset) throws ShortBufferException, IllegalBlockSizeException, BadPaddingException { // Because the EVP_AEAD updateInternal processes input but doesn't create any output - // (and thus can't check the output buffer), we need to add this check before the - // superclass' processing to ensure that updateInternal is never called if the + // (and thus can't check the output buffer), we need to add this check + // to ensure that updateInternal is never called if the // output buffer isn't large enough. if (output != null) { if (getOutputSizeForFinal(inputLen) > output.length - outputOffset) { throw new ShortBufferWithoutStackTraceException("Insufficient output space"); } } - return super.engineDoFinal(input, inputOffset, inputLen, output, outputOffset); + if (output == null) { + throw new NullPointerException("output == null"); + } + + int maximumLen = getOutputSizeForFinal(inputLen); + + final int bytesWritten; + if (inputLen > 0) { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, outputOffset, + maximumLen); + outputOffset += bytesWritten; + maximumLen -= bytesWritten; + } else { + bytesWritten = 0; + } + + return bytesWritten + doFinalInternal(output, outputOffset, maximumLen); } @Override @@ -351,7 +402,6 @@ int doFinalInternal(ByteBuffer input, ByteBuffer output) return bytesWritten; } - @Override int doFinalInternal(byte[] output, int outputOffset, int maximumLen) throws ShortBufferException, IllegalBlockSizeException, BadPaddingException { checkInitialization(); diff --git a/common/src/main/java/org/conscrypt/OpenSSLCipher.java b/common/src/main/java/org/conscrypt/OpenSSLCipher.java index 3e3984806..005a64aea 100644 --- a/common/src/main/java/org/conscrypt/OpenSSLCipher.java +++ b/common/src/main/java/org/conscrypt/OpenSSLCipher.java @@ -151,16 +151,6 @@ abstract void engineInitInternal(byte[] encodedKey, AlgorithmParameterSpec param abstract int updateInternal(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset, int maximumLen) throws ShortBufferException; - /** - * API-specific implementation of the final block. The {@code maximumLen} - * will be the maximum length of the possible output as returned by - * {@link #getOutputSizeForFinal(int)}. The return value must be the number - * of bytes processed and placed into {@code output}. On error, an exception - * must be thrown. - */ - abstract int doFinalInternal(byte[] output, int outputOffset, int maximumLen) - throws IllegalBlockSizeException, BadPaddingException, ShortBufferException; - /** * Returns the standard name for the particular algorithm. */ @@ -349,64 +339,6 @@ protected int engineUpdate(byte[] input, int inputOffset, int inputLen, byte[] o return updateInternal(input, inputOffset, inputLen, output, outputOffset, maximumLen); } - @Override - protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) - throws IllegalBlockSizeException, BadPaddingException { - final int maximumLen = getOutputSizeForFinal(inputLen); - /* Assume that we'll output exactly on a byte boundary. */ - final byte[] output = new byte[maximumLen]; - - int bytesWritten; - if (inputLen > 0) { - try { - bytesWritten = updateInternal(input, inputOffset, inputLen, output, 0, maximumLen); - } catch (ShortBufferException e) { - /* This should not happen since we sized our own buffer. */ - throw new RuntimeException("our calculated buffer was too small", e); - } - } else { - bytesWritten = 0; - } - - try { - bytesWritten += doFinalInternal(output, bytesWritten, maximumLen - bytesWritten); - } catch (ShortBufferException e) { - /* This should not happen since we sized our own buffer. */ - throw new RuntimeException("our calculated buffer was too small", e); - } - - if (bytesWritten == output.length) { - return output; - } else if (bytesWritten == 0) { - return EmptyArray.BYTE; - } else { - return Arrays.copyOfRange(output, 0, bytesWritten); - } - } - - @Override - protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, - int outputOffset) throws ShortBufferException, IllegalBlockSizeException, - BadPaddingException { - if (output == null) { - throw new NullPointerException("output == null"); - } - - int maximumLen = getOutputSizeForFinal(inputLen); - - final int bytesWritten; - if (inputLen > 0) { - bytesWritten = updateInternal(input, inputOffset, inputLen, output, outputOffset, - maximumLen); - outputOffset += bytesWritten; - maximumLen -= bytesWritten; - } else { - bytesWritten = 0; - } - - return bytesWritten + doFinalInternal(output, outputOffset, maximumLen); - } - @Override protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException { try { diff --git a/common/src/main/java/org/conscrypt/OpenSSLCipherChaCha20.java b/common/src/main/java/org/conscrypt/OpenSSLCipherChaCha20.java index ddbd17004..22bbe60c8 100644 --- a/common/src/main/java/org/conscrypt/OpenSSLCipherChaCha20.java +++ b/common/src/main/java/org/conscrypt/OpenSSLCipherChaCha20.java @@ -21,6 +21,9 @@ import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.spec.AlgorithmParameterSpec; +import java.util.Arrays; +import javax.crypto.BadPaddingException; +import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; import javax.crypto.ShortBufferException; import javax.crypto.spec.IvParameterSpec; @@ -101,9 +104,58 @@ int updateInternal(byte[] input, int inputOffset, int inputLen, byte[] output, i } @Override - int doFinalInternal(byte[] output, int outputOffset, int maximumLen) { + protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) + throws IllegalBlockSizeException, BadPaddingException { + final int maximumLen = getOutputSizeForFinal(inputLen); + /* Assume that we'll output exactly on a byte boundary. */ + final byte[] output = new byte[maximumLen]; + + int bytesWritten; + if (inputLen > 0) { + try { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, 0, maximumLen); + } catch (ShortBufferException e) { + /* This should not happen since we sized our own buffer. */ + throw new RuntimeException("our calculated buffer was too small", e); + } + } else { + bytesWritten = 0; + } + reset(); - return 0; + + if (bytesWritten == output.length) { + return output; + } else if (bytesWritten == 0) { + return EmptyArray.BYTE; + } else { + return Arrays.copyOfRange(output, 0, bytesWritten); + } + } + + @Override + protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, + int outputOffset) throws ShortBufferException, IllegalBlockSizeException, + BadPaddingException { + if (output == null) { + throw new NullPointerException("output == null"); + } + + int maximumLen = getOutputSizeForFinal(inputLen); + + final int bytesWritten; + if (inputLen > 0) { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, outputOffset, + maximumLen); + outputOffset += bytesWritten; + maximumLen -= bytesWritten; + } else { + bytesWritten = 0; + } + + reset(); + + return bytesWritten; } private void reset() { diff --git a/common/src/main/java/org/conscrypt/OpenSSLEvpCipher.java b/common/src/main/java/org/conscrypt/OpenSSLEvpCipher.java index f52712711..b53efe83b 100644 --- a/common/src/main/java/org/conscrypt/OpenSSLEvpCipher.java +++ b/common/src/main/java/org/conscrypt/OpenSSLEvpCipher.java @@ -20,6 +20,7 @@ import java.security.InvalidKeyException; import java.security.SecureRandom; import java.security.spec.AlgorithmParameterSpec; +import java.util.Arrays; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.ShortBufferException; @@ -127,7 +128,6 @@ int updateInternal(byte[] input, int inputOffset, int inputLen, byte[] output, return outputOffset - intialOutputOffset; } - @Override int doFinalInternal(byte[] output, int outputOffset, int maximumLen) throws IllegalBlockSizeException, BadPaddingException, ShortBufferException { /* Remember this so we can tell how many characters were written. */ @@ -163,6 +163,64 @@ int doFinalInternal(byte[] output, int outputOffset, int maximumLen) return outputOffset - initialOutputOffset; } + @Override + protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) + throws IllegalBlockSizeException, BadPaddingException { + final int maximumLen = getOutputSizeForFinal(inputLen); + /* Assume that we'll output exactly on a byte boundary. */ + final byte[] output = new byte[maximumLen]; + + int bytesWritten; + if (inputLen > 0) { + try { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, 0, maximumLen); + } catch (ShortBufferException e) { + /* This should not happen since we sized our own buffer. */ + throw new RuntimeException("our calculated buffer was too small", e); + } + } else { + bytesWritten = 0; + } + + try { + bytesWritten += doFinalInternal(output, bytesWritten, maximumLen - bytesWritten); + } catch (ShortBufferException e) { + /* This should not happen since we sized our own buffer. */ + throw new RuntimeException("our calculated buffer was too small", e); + } + + if (bytesWritten == output.length) { + return output; + } else if (bytesWritten == 0) { + return EmptyArray.BYTE; + } else { + return Arrays.copyOfRange(output, 0, bytesWritten); + } + } + + @Override + protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, + int outputOffset) throws ShortBufferException, IllegalBlockSizeException, + BadPaddingException { + if (output == null) { + throw new NullPointerException("output == null"); + } + + int maximumLen = getOutputSizeForFinal(inputLen); + + final int bytesWritten; + if (inputLen > 0) { + bytesWritten = updateInternal(input, inputOffset, inputLen, output, outputOffset, + maximumLen); + outputOffset += bytesWritten; + maximumLen -= bytesWritten; + } else { + bytesWritten = 0; + } + + return bytesWritten + doFinalInternal(output, outputOffset, maximumLen); + } + @Override int getOutputSizeForFinal(int inputLen) { if (modeBlockSize == 1) { diff --git a/common/src/test/java/org/conscrypt/javax/crypto/CipherBasicsTest.java b/common/src/test/java/org/conscrypt/javax/crypto/CipherBasicsTest.java index 4aef5bb82..d3cdb8c71 100644 --- a/common/src/test/java/org/conscrypt/javax/crypto/CipherBasicsTest.java +++ b/common/src/test/java/org/conscrypt/javax/crypto/CipherBasicsTest.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.security.AlgorithmParameters; +import java.security.GeneralSecurityException; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.Key; @@ -85,6 +86,116 @@ public static void setUp() { TestUtils.assumeAllowsUnsignedCrypto(); } + private enum CallPattern { + DO_FINAL, + DO_FINAL_WITH_OFFSET, + UPDATE_DO_FINAL, + MULTIPLE_UPDATE_DO_FINAL, + UPDATE_DO_FINAL_WITH_OUTPUT_ARRAY, + UPDATE_DO_FINAL_WITH_OUTPUT_ARRAY_AND_OFFSET, + DO_FINAL_WITH_INPUT_OUTPUT_ARRAY, + DO_FINAL_WITH_INPUT_OUTPUT_ARRAY_AND_OFFSET, + UPDATE_DO_FINAL_WITH_INPUT_OUTPUT_ARRAY + } + + /** Concatenates the given arrays into a single array.*/ + byte[] concatArrays(byte[]... arrays) { + int length = 0; + for (byte[] array : arrays) { + if (array == null) { + continue; + } + length += array.length; + } + byte[] result = new byte[length]; + int pos = 0; + for (byte[] array : arrays) { + if (array == null) { + continue; + } + System.arraycopy(array, 0, result, pos, array.length); + pos += array.length; + } + return result; + } + + /** Calls an initialized cipher with different equivalent call patterns. */ + private byte[] callCipher( + Cipher cipher, byte[] input, int expectedOutputLength, CallPattern callPattern) + throws GeneralSecurityException { + switch (callPattern) { + case DO_FINAL: { + return cipher.doFinal(input); + } + case DO_FINAL_WITH_OFFSET: { + byte[] inputCopy = new byte[input.length + 100]; + int inputOffset = 42; + System.arraycopy(input, 0, inputCopy, inputOffset, input.length); + return cipher.doFinal(inputCopy, inputOffset, input.length); + } + case UPDATE_DO_FINAL: { + byte[] output1 = cipher.update(input); + byte[] output2 = cipher.doFinal(); + return concatArrays(output1, output2); + } + case MULTIPLE_UPDATE_DO_FINAL: { + int input1Length = input.length / 2; + int input2Length = input.length - input1Length; + byte[] output1 = cipher.update(input, /*inputOffset= */ 0, input1Length); + int input2Offset = input1Length; + byte[] output2 = cipher.update(input, input2Offset, input2Length); + byte[] output3 = cipher.update(new byte[0]); + byte[] output4 = cipher.doFinal(); + return concatArrays(output1, output2, output3, output4); + } + case UPDATE_DO_FINAL_WITH_OUTPUT_ARRAY: { + byte[] output1 = cipher.update(input); + byte[] output2 = new byte[expectedOutputLength - output1.length]; + int written = cipher.doFinal(output2, /*outputOffset= */ 0); + assertEquals(expectedOutputLength - output1.length, written); + return concatArrays(output1, output2); + } + case UPDATE_DO_FINAL_WITH_OUTPUT_ARRAY_AND_OFFSET: { + byte[] output1 = cipher.update(input); + byte[] output2WithOffset = new byte[expectedOutputLength + 100]; + int outputOffset = 42; + int written = cipher.doFinal(output2WithOffset, outputOffset); + assertEquals(expectedOutputLength - output1.length, written); + byte[] output2 = Arrays.copyOfRange(output2WithOffset, outputOffset, outputOffset + written); + return concatArrays(output1, output2); + } + case DO_FINAL_WITH_INPUT_OUTPUT_ARRAY: { + byte[] output = new byte[expectedOutputLength]; + int written = cipher.doFinal(input, /*inputOffset= */ 0, input.length, output); + assertEquals(expectedOutputLength, written); + return output; + } + case DO_FINAL_WITH_INPUT_OUTPUT_ARRAY_AND_OFFSET: { + byte[] inputWithOffset = new byte[input.length + 100]; + int inputOffset = 37; + System.arraycopy(input, 0, inputWithOffset, inputOffset, input.length); + byte[] outputWithOffset = new byte[expectedOutputLength + 100]; + int outputOffset = 21; + int written = cipher.doFinal( + inputWithOffset, inputOffset, input.length, outputWithOffset, outputOffset); + return Arrays.copyOfRange(outputWithOffset, outputOffset, outputOffset + written); + } + case UPDATE_DO_FINAL_WITH_INPUT_OUTPUT_ARRAY: { + int input1Length = input.length / 2; + byte[] output = new byte[expectedOutputLength]; + int written1 = cipher.update(input, /*inputOffset= */ 0, input1Length, output); + int input2Offset = input1Length; + int input2Length = input.length - input1Length; + int outputOffset = written1; + int written2 = cipher.doFinal( + input, input2Offset, input2Length, output, outputOffset); + assertEquals(expectedOutputLength, written1 + written2); + return output; + } + } + throw new IllegalArgumentException("Unsupported CallPattern: " + callPattern); + } + @Test public void testBasicEncryption() throws Exception { for (Provider p : Security.getProviders()) { @@ -132,25 +243,36 @@ public void testBasicEncryption() throws Exception { } try { - cipher.init(Cipher.ENCRYPT_MODE, key, params); - assertEquals("Provider " + p.getName() + for (CallPattern callPattern: CallPattern.values()) { + cipher.init(Cipher.ENCRYPT_MODE, key, params); + assertEquals("Provider " + p.getName() + ", algorithm " + transformation + " reported the wrong output size", ciphertext.length, cipher.getOutputSize(plaintext.length)); - assertArrayEquals("Provider " + p.getName() - + ", algorithm " + transformation - + " failed on encryption, data is " + Arrays.toString(line), - ciphertext, cipher.doFinal(plaintext)); - - cipher.init(Cipher.DECRYPT_MODE, key, params); - assertEquals("Provider " + p.getName() - + ", algorithm " + transformation - + " reported the wrong output size", - plaintext.length, cipher.getOutputSize(ciphertext.length)); - assertArrayEquals("Provider " + p.getName() - + ", algorithm " + transformation - + " failed on decryption, data is " + Arrays.toString(line), - plaintext, cipher.doFinal(ciphertext)); + byte[] encrypted = callCipher( + cipher, plaintext, ciphertext.length, callPattern); + assertArrayEquals( + "Provider " + p.getName() + ", algorithm " + transformation + + ", CallPattern " + callPattern + + " failed on encryption, data is " + Arrays.toString(line), + ciphertext, encrypted); + + cipher.init(Cipher.DECRYPT_MODE, key, params); + byte[] decrypted; + try { + decrypted = callCipher( + cipher, ciphertext, plaintext.length, callPattern); + } catch (GeneralSecurityException e) { + throw new GeneralSecurityException("Provider " + p.getName() + + ", algorithm " + transformation + ", CallPattern " + callPattern + + " failed on decryption, data is " + Arrays.toString(line), e); + } + assertArrayEquals( + "Provider " + p.getName() + ", algorithm " + transformation + + ", CallPattern " + callPattern + + " failed on decryption, data is " + Arrays.toString(line), + plaintext, decrypted); + } } catch (InvalidKeyException e) { // Some providers may not support raw SecretKeySpec keys, that's allowed } @@ -159,37 +281,53 @@ public void testBasicEncryption() throws Exception { } } + static final byte[] EMPTY_AAD = new byte[0]; + public void arrayBasedAssessment(Cipher cipher, byte[] aad, byte[] tag, byte[] plaintext, byte[] ciphertext, Key key, AlgorithmParameterSpec params, String transformation, Provider p, String[] line) throws Exception { - cipher.init(Cipher.ENCRYPT_MODE, key, params); - if (aad.length > 0) { - cipher.updateAAD(aad); - } - byte[] combinedOutput = new byte[ciphertext.length + tag.length]; - assertEquals("Provider " + p.getName() - + ", algorithm " + transformation - + " reported the wrong output size", - combinedOutput.length, cipher.getOutputSize(plaintext.length)); - System.arraycopy(ciphertext, 0, combinedOutput, 0, ciphertext.length); - System.arraycopy(tag, 0, combinedOutput, ciphertext.length, tag.length); - assertArrayEquals("Provider " + p.getName() - + ", algorithm " + transformation + byte[] combinedCiphertext = new byte[ciphertext.length + tag.length]; + System.arraycopy(ciphertext, 0, combinedCiphertext, 0, ciphertext.length); + System.arraycopy(tag, 0, combinedCiphertext, ciphertext.length, tag.length); + + for (CallPattern callPattern: CallPattern.values()) { + cipher.init(Cipher.ENCRYPT_MODE, key, params); + if (aad.length > 0) { + cipher.updateAAD(aad); + } + assertEquals("Provider " + p.getName() + + ", algorithm " + transformation + + " reported the wrong output size", + combinedCiphertext.length, cipher.getOutputSize(plaintext.length)); + byte[] encrypted = callCipher(cipher, plaintext, combinedCiphertext.length, callPattern); + assertArrayEquals("Provider " + p.getName() + + ", algorithm " + transformation + ", CallPattern " + callPattern + " failed on encryption, data is " + Arrays.toString(line), - combinedOutput, cipher.doFinal(plaintext)); - - cipher.init(Cipher.DECRYPT_MODE, key, params); - if (aad.length > 0) { - cipher.updateAAD(aad); + combinedCiphertext, encrypted); } - assertEquals("Provider " + p.getName() - + ", algorithm " + transformation - + " reported the wrong output size", - plaintext.length, cipher.getOutputSize(combinedOutput.length)); - assertArrayEquals("Provider " + p.getName() - + ", algorithm " + transformation + + for (CallPattern callPattern: CallPattern.values()) { + cipher.init(Cipher.DECRYPT_MODE, key, params); + if (aad.length > 0) { + cipher.updateAAD(aad); + } + assertEquals("Provider " + p.getName() + + ", algorithm " + transformation + + " reported the wrong output size", + plaintext.length, cipher.getOutputSize(combinedCiphertext.length)); + byte[] decrypted; + try { + decrypted = callCipher(cipher, combinedCiphertext, plaintext.length, callPattern); + } catch (GeneralSecurityException e) { + throw new GeneralSecurityException("Provider " + p.getName() + + ", algorithm " + transformation + ", CallPattern " + callPattern + + " failed on decryption, data is " + Arrays.toString(line), e); + } + assertArrayEquals("Provider " + p.getName() + + ", algorithm " + transformation + ", CallPattern " + callPattern + " failed on decryption, data is " + Arrays.toString(line), - plaintext, cipher.doFinal(combinedOutput)); + plaintext, decrypted); + } } @Test @@ -486,3 +624,4 @@ public void testByteBufferShiftedAlias() throws Exception { } } } +