Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/truststore/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
if platform.system() == "Windows":
from ._windows import _configure_context, _verify_peercerts_impl
elif platform.system() == "Darwin":
from ._macos import _configure_context, _verify_peercerts_impl
from ._macos import _configure_context, _verify_peercerts_impl, _load_default_certs_impl
else:
from ._openssl import _configure_context, _verify_peercerts_impl

Expand Down Expand Up @@ -166,7 +166,10 @@ def load_cert_chain(
def load_default_certs(
self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
) -> None:
return self._ctx.load_default_certs(purpose)
if sys.platform == "darwin":
return _load_default_certs_impl(self._ctx)
else:
return self._ctx.load_default_certs(purpose)

def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_alpn_protocols(alpn_protocols)
Expand All @@ -189,19 +192,17 @@ def cert_store_stats(self) -> dict[str, int]:
def set_default_verify_paths(self) -> None:
self._ctx.set_default_verify_paths()

@typing.overload
def get_ca_certs(
self, binary_form: typing.Literal[False] = ...
) -> list[typing.Any]: ...

@typing.overload
def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]: ...

@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: ...

def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
raise NotImplementedError()
@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: ...

def get_ca_certs(self, binary_form: bool = False):
return self._ctx.get_ca_certs(binary_form=binary_form)

@property
def check_hostname(self) -> bool:
Expand Down
51 changes: 46 additions & 5 deletions src/truststore/_macos.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
CFArrayCallBacks = c_void_p
CFOptionFlags = c_uint32

SecCertificateRef = POINTER(c_void_p)
SecCertificate = c_void_p
SecCertificateRef = POINTER(SecCertificate)
SecPolicyRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SecTrustResultType = c_uint32
Expand All @@ -86,7 +87,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef

Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.argtypes = [SecCertificate]
Security.SecCertificateCopyData.restype = CFDataRef

Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Expand Down Expand Up @@ -123,6 +124,9 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
]
Security.SecTrustEvaluate.restype = OSStatus

Security.SecTrustCopyAnchorCertificates.argtypes = [CFArrayRef]
Security.SecTrustCopyAnchorCertificates.restype = OSStatus

Security.SecTrustRef = SecTrustRef # type: ignore[attr-defined]
Security.SecTrustResultType = SecTrustResultType # type: ignore[attr-defined]
Security.OSStatus = OSStatus # type: ignore[attr-defined]
Expand Down Expand Up @@ -181,10 +185,10 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CoreFoundation.CFArrayAppendValue.restype = None

CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CoreFoundation.CFArrayGetCount.argtypes = [CFArray]
CoreFoundation.CFArrayGetCount.restype = CFIndex

CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArray, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p

CoreFoundation.CFErrorGetCode.argtypes = [CFErrorRef]
Expand All @@ -200,6 +204,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
CoreFoundation, "kCFTypeArrayCallBacks"
)

CoreFoundation.CFArray = CFArray # type: ignore[attr-defined]
CoreFoundation.CFTypeRef = CFTypeRef # type: ignore[attr-defined]
CoreFoundation.CFArrayRef = CFArrayRef # type: ignore[attr-defined]
CoreFoundation.CFStringRef = CFStringRef # type: ignore[attr-defined]
Expand Down Expand Up @@ -278,7 +283,7 @@ def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typin
Security.SecTrustSetAnchorCertificatesOnly.errcheck = _handle_osstatus # type: ignore[assignment]
Security.SecTrustGetTrustResult.errcheck = _handle_osstatus # type: ignore[assignment]
Security.SecTrustEvaluate.errcheck = _handle_osstatus # type: ignore[assignment]

Security.SecTrustCopyAnchorCertificates.errcheck = _handle_osstatus # type: ignore[assignment]

class CFConst:
"""CoreFoundation constants"""
Expand Down Expand Up @@ -569,3 +574,39 @@ def _verify_peercerts_impl_macos_10_14(
finally:
if cf_error_string_ref:
CoreFoundation.CFRelease(cf_error_string_ref)

try:
CoreFoundation.CFRelease.argtypes = [CFType]
except AttributeError as e:
raise ImportError(f"Error initializing ctypes: {e}") from None

def _load_default_certs_impl(ssl_context: ssl.SSLContext) -> None:
"""
Loads the default system certificates into the SSLContext.
"""

certs_array = CFArray(None)
Security.SecTrustCopyAnchorCertificates(
ctypes.byref(certs_array)
)

count = CoreFoundation.CFArrayGetCount(certs_array)

for i in range(count):
# Get the certificate from the array
cert_ref = CoreFoundation.CFArrayGetValueAtIndex(certs_array, i)
data_ref = Security.SecCertificateCopyData(cert_ref)

length = CoreFoundation.CFDataGetLength(data_ref)
data_ptr = CoreFoundation.CFDataGetBytePtr(data_ref)

cert_bytes = ctypes.string_at(data_ptr, length)

# Load the certificate into the SSLContext
ssl_context.load_verify_locations(
cadata=cert_bytes
)

CoreFoundation.CFRelease(data_ref)

CoreFoundation.CFRelease(certs_array)
Loading