diff --git a/src/truststore/_api.py b/src/truststore/_api.py index 47b7a63..de087e0 100644 --- a/src/truststore/_api.py +++ b/src/truststore/_api.py @@ -17,7 +17,7 @@ if platform.system() == "Windows": from ._windows import _configure_context, _verify_peercerts_impl elif platform.system() == "Darwin": - from ._macos import _configure_context, _verify_peercerts_impl + from ._macos import _configure_context, _verify_peercerts_impl, _load_default_certs_impl else: from ._openssl import _configure_context, _verify_peercerts_impl @@ -166,7 +166,10 @@ def load_cert_chain( def load_default_certs( self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH ) -> None: - return self._ctx.load_default_certs(purpose) + if sys.platform == "darwin": + return _load_default_certs_impl(self._ctx) + else: + return self._ctx.load_default_certs(purpose) def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None: return self._ctx.set_alpn_protocols(alpn_protocols) @@ -189,19 +192,17 @@ def cert_store_stats(self) -> dict[str, int]: def set_default_verify_paths(self) -> None: self._ctx.set_default_verify_paths() - @typing.overload - def get_ca_certs( - self, binary_form: typing.Literal[False] = ... - ) -> list[typing.Any]: ... - @typing.overload def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]: ... @typing.overload def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: ... - def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]: - raise NotImplementedError() + @typing.overload + def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: ... + + def get_ca_certs(self, binary_form: bool = False): + return self._ctx.get_ca_certs(binary_form=binary_form) @property def check_hostname(self) -> bool: diff --git a/src/truststore/_macos.py b/src/truststore/_macos.py index 3450307..623cc35 100644 --- a/src/truststore/_macos.py +++ b/src/truststore/_macos.py @@ -76,7 +76,8 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: CFArrayCallBacks = c_void_p CFOptionFlags = c_uint32 -SecCertificateRef = POINTER(c_void_p) +SecCertificate = c_void_p +SecCertificateRef = POINTER(SecCertificate) SecPolicyRef = POINTER(c_void_p) SecTrustRef = POINTER(c_void_p) SecTrustResultType = c_uint32 @@ -86,7 +87,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef] Security.SecCertificateCreateWithData.restype = SecCertificateRef - Security.SecCertificateCopyData.argtypes = [SecCertificateRef] + Security.SecCertificateCopyData.argtypes = [SecCertificate] Security.SecCertificateCopyData.restype = CFDataRef Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p] @@ -123,6 +124,9 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: ] Security.SecTrustEvaluate.restype = OSStatus + Security.SecTrustCopyAnchorCertificates.argtypes = [CFArrayRef] + Security.SecTrustCopyAnchorCertificates.restype = OSStatus + Security.SecTrustRef = SecTrustRef # type: ignore[attr-defined] Security.SecTrustResultType = SecTrustResultType # type: ignore[attr-defined] Security.OSStatus = OSStatus # type: ignore[attr-defined] @@ -181,10 +185,10 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p] CoreFoundation.CFArrayAppendValue.restype = None - CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef] + CoreFoundation.CFArrayGetCount.argtypes = [CFArray] CoreFoundation.CFArrayGetCount.restype = CFIndex - CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex] + CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArray, CFIndex] CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p CoreFoundation.CFErrorGetCode.argtypes = [CFErrorRef] @@ -200,6 +204,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL: CoreFoundation, "kCFTypeArrayCallBacks" ) + CoreFoundation.CFArray = CFArray # type: ignore[attr-defined] CoreFoundation.CFTypeRef = CFTypeRef # type: ignore[attr-defined] CoreFoundation.CFArrayRef = CFArrayRef # type: ignore[attr-defined] CoreFoundation.CFStringRef = CFStringRef # type: ignore[attr-defined] @@ -278,7 +283,7 @@ def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typin Security.SecTrustSetAnchorCertificatesOnly.errcheck = _handle_osstatus # type: ignore[assignment] Security.SecTrustGetTrustResult.errcheck = _handle_osstatus # type: ignore[assignment] Security.SecTrustEvaluate.errcheck = _handle_osstatus # type: ignore[assignment] - +Security.SecTrustCopyAnchorCertificates.errcheck = _handle_osstatus # type: ignore[assignment] class CFConst: """CoreFoundation constants""" @@ -569,3 +574,39 @@ def _verify_peercerts_impl_macos_10_14( finally: if cf_error_string_ref: CoreFoundation.CFRelease(cf_error_string_ref) + +try: + CoreFoundation.CFRelease.argtypes = [CFType] +except AttributeError as e: + raise ImportError(f"Error initializing ctypes: {e}") from None + +def _load_default_certs_impl(ssl_context: ssl.SSLContext) -> None: + """ + Loads the default system certificates into the SSLContext. + """ + + certs_array = CFArray(None) + Security.SecTrustCopyAnchorCertificates( + ctypes.byref(certs_array) + ) + + count = CoreFoundation.CFArrayGetCount(certs_array) + + for i in range(count): + # Get the certificate from the array + cert_ref = CoreFoundation.CFArrayGetValueAtIndex(certs_array, i) + data_ref = Security.SecCertificateCopyData(cert_ref) + + length = CoreFoundation.CFDataGetLength(data_ref) + data_ptr = CoreFoundation.CFDataGetBytePtr(data_ref) + + cert_bytes = ctypes.string_at(data_ptr, length) + + # Load the certificate into the SSLContext + ssl_context.load_verify_locations( + cadata=cert_bytes + ) + + CoreFoundation.CFRelease(data_ref) + + CoreFoundation.CFRelease(certs_array) \ No newline at end of file