diff --git a/common/src/jni/main/cpp/conscrypt/native_crypto.cc b/common/src/jni/main/cpp/conscrypt/native_crypto.cc index 8a8a5f17b..52b6a9346 100644 --- a/common/src/jni/main/cpp/conscrypt/native_crypto.cc +++ b/common/src/jni/main/cpp/conscrypt/native_crypto.cc @@ -9861,6 +9861,39 @@ static jstring NativeCrypto_SSL_get_current_cipher(JNIEnv* env, jclass, jlong ss return env->NewStringUTF(name); } +static void NativeCrypto_SSL_set1_groups(JNIEnv* env, jclass, jlong sslAddress, + CONSCRYPT_UNUSED jobject sslHolder, jintArray groups) { + CHECK_ERROR_QUEUE_ON_RETURN; + SSL* ssl = to_SSL(env, sslAddress, /* throwIfNull= */ true); + JNI_TRACE("ssl=%p NativeCrypto_SSL_set1_groups groups=%p", ssl, groups); + if (ssl == nullptr) { + // to_SSL already called conscrypt::jniutil::throwNullPointerException + return; + } + if (groups == nullptr) { + conscrypt::jniutil::throwNullPointerException(env, "groups == null"); + return; + } + ScopedIntArrayRO groups_ro(env, groups); + if (groups_ro.get() == nullptr) { + JNI_TRACE("ssl=%p NativeCrypto_SSL_set1_groups => threw exception", ssl); + conscrypt::jniutil::throwOutOfMemory(env, "Unable to allocate buffer for groups"); + return; + } + std::vector groups_vector; + groups_vector.reserve(groups_ro.size()); + const jint* groups_ptr = groups_ro.get(); + for (int i = 0; i < groups_ro.size(); i++) { + groups_vector.push_back(groups_ptr[i]); + } + + if (!SSL_set1_groups(ssl, groups_vector.data(), groups_vector.size())) { + conscrypt::jniutil::throwSSLExceptionStr(env, "Error parsing groups"); + ERR_clear_error(); + return; + } +} + static jstring NativeCrypto_SSL_get_curve_name(JNIEnv* env, jclass, jlong sslAddress, CONSCRYPT_UNUSED jobject sslHolder) { CHECK_ERROR_QUEUE_ON_RETURN; @@ -12497,6 +12530,7 @@ static JNINativeMethod sNativeCryptoMethods[] = { CONSCRYPT_NATIVE_METHOD(SSL_get_servername, "(J" REF_SSL ")Ljava/lang/String;"), CONSCRYPT_NATIVE_METHOD(SSL_do_handshake, "(J" REF_SSL FILE_DESCRIPTOR SSL_CALLBACKS "I)V"), CONSCRYPT_NATIVE_METHOD(SSL_get_current_cipher, "(J" REF_SSL ")Ljava/lang/String;"), + CONSCRYPT_NATIVE_METHOD(SSL_set1_groups, "(J" REF_SSL "[I)V"), CONSCRYPT_NATIVE_METHOD(SSL_get_curve_name, "(J" REF_SSL ")Ljava/lang/String;"), CONSCRYPT_NATIVE_METHOD(SSL_get_version, "(J" REF_SSL ")Ljava/lang/String;"), CONSCRYPT_NATIVE_METHOD(SSL_get0_peer_certificates, "(J" REF_SSL ")[[B"), diff --git a/common/src/main/java/org/conscrypt/NativeCrypto.java b/common/src/main/java/org/conscrypt/NativeCrypto.java index 24afc900d..9269ba708 100644 --- a/common/src/main/java/org/conscrypt/NativeCrypto.java +++ b/common/src/main/java/org/conscrypt/NativeCrypto.java @@ -1343,6 +1343,8 @@ static native void SSL_do_handshake(long ssl, NativeSsl ssl_holder, FileDescript public static native String SSL_get_current_cipher(long ssl, NativeSsl ssl_holder); + public static native void SSL_set1_groups(long ssl, NativeSsl sslHolder, int[] groups); + public static native String SSL_get_curve_name(long ssl, NativeSsl sslHolder); public static native String SSL_get_version(long ssl, NativeSsl ssl_holder); diff --git a/constants/src/gen/cpp/generate_constants.cc b/constants/src/gen/cpp/generate_constants.cc index 874f32e45..27b5f0785 100644 --- a/constants/src/gen/cpp/generate_constants.cc +++ b/constants/src/gen/cpp/generate_constants.cc @@ -60,6 +60,14 @@ int main(int /* argc */, char ** /* argv */) { CONST(EVP_PKEY_ML_DSA_65); CONST(EVP_PKEY_ML_DSA_87); + CONST(NID_X25519); + CONST(NID_X9_62_prime256v1); + CONST(NID_secp384r1); + CONST(NID_secp521r1); + CONST(NID_X25519MLKEM768); + CONST(NID_X25519Kyber768Draft00); + CONST(NID_ML_KEM_1024); + CONST(RSA_PKCS1_PADDING); CONST(RSA_NO_PADDING); CONST(RSA_PKCS1_OAEP_PADDING); diff --git a/openjdk/src/test/java/org/conscrypt/NativeCryptoTest.java b/openjdk/src/test/java/org/conscrypt/NativeCryptoTest.java index ed3469d2c..d470bd38f 100644 --- a/openjdk/src/test/java/org/conscrypt/NativeCryptoTest.java +++ b/openjdk/src/test/java/org/conscrypt/NativeCryptoTest.java @@ -356,6 +356,47 @@ public void test_SSL_new() throws Exception { NativeCrypto.SSL_CTX_free(c, null); } + @Test + public void setGroupsList_validGroups_works() throws Exception { + long c = NativeCrypto.SSL_CTX_new(); + long s = NativeCrypto.SSL_new(c, null); + + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_X25519}); + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_X9_62_prime256v1}); + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_secp384r1}); + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_secp521r1}); + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_X25519MLKEM768}); + NativeCrypto.SSL_set1_groups( + s, null, new int[] {NativeConstants.NID_X25519Kyber768Draft00}); + NativeCrypto.SSL_set1_groups(s, null, new int[] {NativeConstants.NID_ML_KEM_1024}); + + NativeCrypto.SSL_set1_groups(s, null, + new int[] {NativeConstants.NID_X25519, NativeConstants.NID_X9_62_prime256v1, + NativeConstants.NID_secp384r1, NativeConstants.NID_secp521r1, + NativeConstants.NID_X25519MLKEM768, + NativeConstants.NID_X25519Kyber768Draft00, + NativeConstants.NID_ML_KEM_1024}); + + NativeCrypto.SSL_free(s, null); + NativeCrypto.SSL_CTX_free(c, null); + } + + @Test + public void setGroupsList_invalidInput_throws() throws Exception { + long c = NativeCrypto.SSL_CTX_new(); + long s = NativeCrypto.SSL_new(c, null); + + assertThrows(NullPointerException.class, () -> NativeCrypto.SSL_set1_groups(s, null, null)); + + assertThrows(SSLException.class, + () + -> NativeCrypto.SSL_set1_groups( + s, null, new int[] {NativeConstants.EVP_PKEY_RSA})); + + NativeCrypto.SSL_free(s, null); + NativeCrypto.SSL_CTX_free(c, null); + } + @Test public void setLocalCertsAndPrivateKey_withNullSSLShouldThrow() throws Exception { assertThrows(NullPointerException.class,