Skip to content

Race condition on context setting #209

@stefanv

Description

@stefanv

This is to bring to your attention the race condition identified in nominal-io/nominal-client#620.

The issue is with the following code (on macOS and Windows only):

        with contextlib.ExitStack() as stack:
            with self._ctx_lock:
                stack.enter_context(_configure_context(self._ctx))

On macOS, for example, _configure_context looks like this:

@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
    """Temporarily disable OpenSSL-level verification on *ctx*.

    Sets ``check_hostname=False`` and ``verify_mode=CERT_NONE`` for the
    duration of the ``with`` block, then restores the values read on entry.
    """
    # Snapshot of the current settings, restored in the finally block below.
    # NOTE(review): nothing in this function guards these reads against
    # another thread that is still inside this same context manager on the
    # same ctx. If one is, the snapshot captures that thread's temporary
    # False/CERT_NONE values, and the "restore" below re-applies them.
    check_hostname = ctx.check_hostname
    verify_mode = ctx.verify_mode
    ctx.check_hostname = False
    _set_ssl_context_verify_mode(ctx, ssl.CERT_NONE)
    try:
        yield
    finally:
        # Put back whatever was observed on entry (see NOTE above).
        ctx.check_hostname = check_hostname
        _set_ssl_context_verify_mode(ctx, verify_mode)

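So _set_ssl_context_verify_mode sets verify_mode to ssl.CERT_NONE (and check_hostname is set to False) before yielding. Then _ctx_lock is released, but the context manager stays active on the ExitStack: the original values are only restored when the ExitStack exits after wrap_socket returns. Meanwhile, other threads can access the shared context and see (or snapshot) the temporary CERT_NONE setting.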

Relevant from that thread:

When recycling connections, urllib3 calls load_verify_locations on the shared context, snapshots verify_mode before the call, and restores it after. If this snapshot is taken while truststore has temporarily set verify_mode=CERT_NONE, urllib3 "restores" the corrupted value, permanently leaving the context in an unverified state for that connection.

I asked Claude to come up with a reproducer, which I have attached here in case it is helpful. It reproduces two types of errors (Demo 1 and Demo 2); I was mainly trying to illustrate number 2. Because I'm testing on Linux, I had to swap out the Linux context manager for the macOS version. Copying that function into the script would also work, but I imported it directly from truststore (with some macOS-only calls mocked out) to show that the problem really lies in truststore's own code.

Run server.py in one terminal and client.py in another.

server.py:

Details
#!/usr/bin/env python3
"""
HTTPS server for the truststore race condition reproducer.

Generates a self-signed certificate for localhost and serves HTTPS on port 4443.
Run this first, then run client.py in another terminal.
"""

import datetime
import ipaddress
import pathlib
import ssl
from http.server import BaseHTTPRequestHandler, HTTPServer

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

# Certificate and key are written next to this script; client.py looks for
# server.crt in its own directory.
CERT_FILE = pathlib.Path(__file__).with_name("server.crt")
KEY_FILE = CERT_FILE.with_name("server.key")
# Bind to the loopback interface only.
HOST, PORT = "127.0.0.1", 4443


def generate_self_signed_cert() -> None:
    """Create a self-signed certificate and private key for localhost.

    The PEM certificate is written to CERT_FILE and the unencrypted PEM private
    key to KEY_FILE. The certificate is valid for 365 days starting now and
    lists both ``localhost`` and ``127.0.0.1`` as subject alternative names, so
    clients can connect using either.
    """
    # Local import: only needed for the ExtendedKeyUsage extension below.
    from cryptography.x509.oid import ExtendedKeyUsageOID

    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
    ])
    # Compute "now" once so the validity window is exactly 365 days.
    # datetime.timezone.utc is used instead of datetime.UTC, which only
    # exists on Python 3.11+.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .add_extension(
            x509.SubjectAlternativeName([
                x509.DNSName("localhost"),
                x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
            ]),
            critical=False,
        )
        # Apple's TLS server-certificate requirements (macOS 10.15+) demand an
        # EKU containing id-kp-serverAuth. Without it, a native macOS run can
        # fail verification for reasons unrelated to the race being shown.
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        .sign(key, hashes.SHA256())
    )

    CERT_FILE.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
    KEY_FILE.write_bytes(
        key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption(),
        )
    )
    print(f"Generated self-signed cert: {CERT_FILE}")


class QuietHandler(BaseHTTPRequestHandler):
    """Minimal handler: every GET gets status 200 and the body ``OK``.

    The per-request log line is suppressed so the server terminal stays
    readable while the client sends many concurrent requests.
    """

    def do_GET(self) -> None:
        """Answer with 200, the default headers from send_response, and ``OK``."""
        body = b"OK"
        self.send_response(200)
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format: str, *args: object) -> None:
        """Swallow per-request log lines (overrides the base-class logger)."""
        return None


def main() -> None:
    """Serve HTTPS on HOST:PORT until interrupted with Ctrl-C.

    Reuses CERT_FILE/KEY_FILE when both already exist; otherwise a fresh
    self-signed pair is generated first.
    """
    if not CERT_FILE.exists() or not KEY_FILE.exists():
        generate_self_signed_cert()
    else:
        print(f"Reusing existing cert: {CERT_FILE}")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(certfile=CERT_FILE, keyfile=KEY_FILE)

    # Using the server as a context manager guarantees the listening socket is
    # closed on exit, including when serve_forever() is interrupted.
    with HTTPServer((HOST, PORT), QuietHandler) as httpd:
        httpd.socket = ctx.wrap_socket(httpd.socket, server_side=True)

        print(f"Listening on https://{HOST}:{PORT}  (Ctrl-C to stop)")
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the advertised way to stop; exit without a traceback.
            print("\nStopped.")


if __name__ == "__main__":
    main()

client.py:

Details
#!/usr/bin/env python3
"""
Reproducer for the truststore race condition (GitHub PR #620).

The bug
-------
truststore.SSLContext.wrap_socket() uses an ExitStack to hold a
_configure_context() context manager.  On macOS/Windows that context manager
temporarily sets verify_mode=CERT_NONE (so that the OS trust store, not
OpenSSL, does the real cert check).  The critical mistake is that _ctx_lock
is released *before* the TLS handshake completes, while the context manager
– which will later restore verify_mode – is still live on the ExitStack.

    with contextlib.ExitStack() as stack:
        with self._ctx_lock:                         # (1) acquire lock
            stack.enter_context(_configure_context(self._ctx))  # sets CERT_NONE
        # (2) LOCK RELEASED – CERT_NONE is still set on self._ctx
        ssl_sock = self._ctx.wrap_socket(...)        # (3) handshake (slow)
    # (4) ExitStack exits → _configure_context restores the value it saved at (1)

Race window = between (2) and (4).

If Thread B enters wrap_socket while Thread A is between (2) and (4):
  - Thread B acquires the lock (fine, A released it)
  - _configure_context reads ctx.verify_mode → CERT_NONE  (corrupted!)
  - Thread B releases the lock, does its handshake
  - Thread A finishes: ExitStack restores CERT_REQUIRED (saved correctly in (1))
  - Thread B finishes: ExitStack "restores" CERT_NONE    (saved corrupted value)
  → ctx is now permanently CERT_NONE

How to run
----------
    # Terminal 1
    python server.py

    # Terminal 2
    python client.py

What to expect on macOS/Windows (native)
-----------------------------------------
Flood of InsecureRequestWarning lines + final verify_mode=CERT_NONE.

What this script does for Linux
--------------------------------
Linux's truststore._openssl._configure_context never sets CERT_NONE, so the
native path is immune.  This script monkey-patches _configure_context to the
macOS/Windows behaviour so the race can be demonstrated on any platform.

Two demos are run:
  1. Direct – calls wrap_socket() from many threads simultaneously and checks
     whether verify_mode is permanently corrupted afterward.
  2. urllib3 – makes concurrent HTTP requests through a shared SSLContext;
     threads that hit the race window get connections stamped is_verified=False
     and emit InsecureRequestWarning on every subsequent request.
"""

import socket
import ssl
import sys
import threading
import warnings
from pathlib import Path

import urllib3
from urllib3.exceptions import InsecureRequestWarning

# ── locate the server cert generated by server.py ───────────────────────────
CERT_FILE = Path(__file__).with_name("server.crt")  # written by server.py
HOST, PORT = "127.0.0.1", 4443  # must match the address server.py listens on

# ── patch truststore to simulate the macOS/Windows _configure_context ───────
#
# On Linux, _openssl._configure_context only loads CA certs (no CERT_NONE).
# We replace it inside _api with the macOS/Windows version so the race window
# exists on this platform.

import platform
from unittest.mock import MagicMock, patch

import truststore._api as _truststore_api
import truststore

if platform.system() not in ("Darwin", "Windows"):
    # On Linux, patch in the macOS _configure_context to get the same behaviour.
    # _macos.py can't be imported directly because its module-level code calls
    # platform.mac_ver() and loads macOS-only DLLs via ctypes.CDLL; mock those.
    # Evict any cached copy so the import below re-executes _macos.py's module
    # body while the mocks are active.
    sys.modules.pop("truststore._macos", None)
    # The patches apply only inside this `with`; only _configure_context is
    # kept afterwards. NOTE(review): assumes _configure_context itself never
    # touches the mocked ctypes objects when it is invoked later.
    with patch("platform.mac_ver", return_value=("10.15.0", ("", "", ""), "")), \
         patch("ctypes.CDLL", return_value=MagicMock()), \
         patch("ctypes.util.find_library", return_value="/fake/path"), \
         patch("ctypes.c_void_p.in_dll", return_value=MagicMock()):
        from truststore._macos import _configure_context as _buggy_configure_context

    # Rebind the name inside truststore._api so that wrap_socket uses the
    # macOS-style version. NOTE(review): relies on _api resolving
    # _configure_context as a module global at call time.
    _truststore_api._configure_context = _buggy_configure_context  # type: ignore[attr-defined]


# ── Demo 1: direct verify_mode corruption ───────────────────────────────────

def demo_direct_corruption(n_threads: int = 32) -> None:
    """
    Calls wrap_socket() from many threads sharing one SSLContext.
    With the macOS-like _configure_context, a corrupted verify_mode=CERT_NONE
    can be saved by a thread that enters the lock while another thread's
    CERT_NONE is still visible, permanently corrupting the context.

    check_hostname is reported too: _configure_context saves and restores it
    alongside verify_mode, so the same race can leave it stuck at False.
    """
    print("=" * 60)
    print("Demo 1: direct verify_mode corruption via wrap_socket()")
    print("=" * 60)

    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.load_verify_locations(cafile=str(CERT_FILE))
    # The wrapped stdlib context whose settings _configure_context mutates.
    inner = ctx._ctx
    print(f"  Before: verify_mode = {inner.verify_mode!r}")
    print(f"          check_hostname = {inner.check_hostname!r}")

    errors: list[str] = []
    barrier = threading.Barrier(n_threads)

    def connect_once() -> None:
        barrier.wait()  # all threads start simultaneously
        try:
            # `with` closes both sockets deterministically, even when the
            # TLS handshake or certificate verification raises.
            with socket.create_connection((HOST, PORT), timeout=5) as raw:
                with ctx.wrap_socket(raw, server_hostname="127.0.0.1"):
                    pass
        except Exception as exc:
            errors.append(str(exc))

    threads = [threading.Thread(target=connect_once) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    final = inner.verify_mode
    print(f"  After:  verify_mode = {final!r}")
    print(f"          check_hostname = {inner.check_hostname!r}")
    if errors:
        print(f"  (connection errors, expected with self-signed cert + CERT_REQUIRED: "
              f"{len(errors)} of {n_threads})")
        # Show one concrete error so unexpected failures (e.g. refused
        # connections) are distinguishable from verification failures.
        print(f"  (first error: {errors[0]})")
    if final == ssl.CERT_NONE:
        print("  ✗ BUG: verify_mode permanently corrupted to CERT_NONE\n")
    else:
        print("  ✓ verify_mode intact this run (race is non-deterministic; try again)\n")


# ── Demo 2: urllib3 InsecureRequestWarning flood ────────────────────────────

def demo_urllib3_warnings(n_threads: int = 16, requests_per_thread: int = 20) -> None:
    """
    Makes concurrent HTTP requests through a shared truststore SSLContext.
    Threads that open connections while another thread holds verify_mode=CERT_NONE
    get urllib3 connection objects stamped with is_verified=False, causing
    InsecureRequestWarning on every subsequent request through that connection.

    Warnings are captured by a single showwarning hook installed from the main
    thread. warnings.catch_warnings() swaps module-global state and is not
    thread-safe, so entering it separately in every worker would itself race
    and could misattribute or lose warnings.
    """
    print("=" * 60)
    print("Demo 2: InsecureRequestWarning flood via urllib3")
    print("=" * 60)

    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.load_verify_locations(cafile=str(CERT_FILE))

    # One shared pool → one shared SSLContext, mirroring requests' behaviour.
    # The `with` closes the pool's connections when the demo finishes.
    with urllib3.HTTPSConnectionPool(
        HOST, port=PORT, ssl_context=ctx, cert_reqs=None
    ) as pool:
        # Warm-up: verify server is reachable.
        try:
            resp = pool.request("GET", "/")
            print(f"  Warm-up: HTTP {resp.status}")
        except Exception as exc:
            print(f"  Cannot reach server: {exc}")
            print("  Is `python server.py` running?")
            return

        warning_count = 0
        error_count = 0
        lock = threading.Lock()
        barrier = threading.Barrier(n_threads)
        fallback_showwarning = warnings.showwarning

        def on_warning(
            message: object,
            category: type,
            filename: str,
            lineno: int,
            file: object = None,
            line: object = None,
        ) -> None:
            # Called in whichever thread issued the warning, so the thread
            # name identifies the worker whose connection is unverified.
            nonlocal warning_count
            if not issubclass(category, InsecureRequestWarning):
                fallback_showwarning(message, category, filename, lineno, file, line)
                return
            with lock:
                warning_count += 1
                tname = threading.current_thread().name
                print(f"  [{tname}] InsecureRequestWarning  "
                      f"(running total: {warning_count})")

        def worker() -> None:
            nonlocal error_count
            barrier.wait()
            for _ in range(requests_per_thread):
                try:
                    pool.request("GET", "/")
                except Exception:
                    # Best effort: keep going, but count failures so that a
                    # "no warnings" result is not silently caused by errors.
                    with lock:
                        error_count += 1

        print(f"  Launching {n_threads} threads × {requests_per_thread} requests …")
        threads = [
            threading.Thread(target=worker, name=f"T{i:02d}")
            for i in range(n_threads)
        ]
        # Entered once, from this thread only; restores the filters and
        # showwarning on exit.
        with warnings.catch_warnings():
            warnings.simplefilter("always", InsecureRequestWarning)
            warnings.showwarning = on_warning
            for t in threads:
                t.start()
            for t in threads:
                t.join()

    print(f"\n  Total InsecureRequestWarnings: {warning_count}")
    if error_count:
        print(f"  Failed requests: {error_count} of {n_threads * requests_per_thread}")
    if warning_count > 0:
        print(
            "  ✗ BUG: warnings fired because urllib3 saw verify_mode=CERT_NONE\n"
            "    during the race window and marked those connections is_verified=False.\n"
            "    All future requests on those connections continue to warn.\n"
        )
    else:
        print("  ✓ No warnings this run (race is non-deterministic; try again)\n")


# ── entry point ─────────────────────────────────────────────────────────────

def main() -> None:
    """Run both demos; exit with a hint if server.py has not created the cert yet."""
    if CERT_FILE.exists():
        demo_direct_corruption()
        demo_urllib3_warnings()
    else:
        sys.exit(f"Certificate not found: {CERT_FILE}\nRun `python server.py` first.")


if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions