Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions connect/eaas/core/egress_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@ def __init__(
certificates: EgressProxyCertificates,
):
self.proxy = proxy
self.cert_file = self._create_temp_cert_file(certificates.client_cert)
self.key_file = self._create_temp_cert_file(certificates.client_key)
self.ca_file = self._create_temp_cert_file(certificates.ca_cert)
self.cert_file = self._create_cert_file(
'client_cert.pem', certificates.client_cert,
)
self.key_file = self._create_cert_file(
'client_key.pem', certificates.client_key,
)
self.ca_file = self._create_cert_file(
'ca_cert.pem', certificates.ca_cert,
)

super().__init__(
endpoint=self.proxy.url,
Expand All @@ -39,16 +45,24 @@ def __init__(
)

@staticmethod
def _create_temp_cert_file(cert_content):
"""Create a temporary file with certificate content."""
temp_file = tempfile.NamedTemporaryFile(
mode='w',
delete=False,
suffix='.pem',
)
temp_file.write(cert_content)
temp_file.close()
return temp_file.name
def _create_cert_file(filename, cert_content):
"""
Create or reuse a certificate file at a fixed location.

Uses /tmp/<filename> as a stable path. Only writes if the
file doesn't exist or its content has changed.
"""
filepath = os.path.join(tempfile.gettempdir(), filename)

if os.path.exists(filepath):
with open(filepath, 'r') as f:
if f.read() == cert_content:
return filepath

with open(filepath, 'w') as f:
f.write(cert_content)

return filepath

@classmethod
def require_proxy(cls, account_id: str):
Expand Down
147 changes: 137 additions & 10 deletions tests/connect/eaas/core/test_egress_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,18 +596,20 @@ def test_validate_headers_success_with_empty_headers_list(
client._validate_headers(headers)


# Tests for _create_temp_cert_file static method
# Tests for _create_cert_file static method


def test_create_temp_cert_file_creates_file():
"""Test _create_temp_cert_file creates a temporary file."""
def test_create_cert_file_creates_file():
"""Test _create_cert_file creates a file at a fixed location."""
cert_content = (
'-----BEGIN CERTIFICATE-----\n'
'TEST\n'
'-----END CERTIFICATE-----'
)

filepath = EgressProxyClient._create_temp_cert_file(cert_content)
filepath = EgressProxyClient._create_cert_file(
'test_cert.pem', cert_content,
)

try:
assert os.path.exists(filepath)
Expand All @@ -620,13 +622,107 @@ def test_create_temp_cert_file_creates_file():
os.unlink(filepath)


def test_create_temp_cert_file_multiple_calls_create_different_files():
"""Test multiple calls create different temp files."""
cert1 = '-----BEGIN CERTIFICATE-----\nCERT1\n-----END CERTIFICATE-----'
cert2 = '-----BEGIN CERTIFICATE-----\nCERT2\n-----END CERTIFICATE-----'
def test_create_cert_file_uses_deterministic_path():
"""Test _create_cert_file returns the same path for same filename."""
cert_content = (
'-----BEGIN CERTIFICATE-----\n'
'CERT\n'
'-----END CERTIFICATE-----'
)

filepath1 = EgressProxyClient._create_cert_file(
'determ_cert.pem', cert_content,
)
filepath2 = EgressProxyClient._create_cert_file(
'determ_cert.pem', cert_content,
)

try:
assert filepath1 == filepath2
finally:
if os.path.exists(filepath1):
os.unlink(filepath1)


def test_create_cert_file_writes_only_once():
"""
Test _create_cert_file writes the file only on first call.
Subsequent calls with same content reuse existing file.
"""
filename = 'write_once_cert.pem'
cert_content = (
'-----BEGIN CERTIFICATE-----\n'
'ONCE\n'
'-----END CERTIFICATE-----'
)
filepath = EgressProxyClient._create_cert_file(
filename, cert_content,
)

try:
mtime_before = os.path.getmtime(filepath)

EgressProxyClient._create_cert_file(
filename, cert_content,
)

mtime_after = os.path.getmtime(filepath)
assert mtime_before == mtime_after
finally:
if os.path.exists(filepath):
os.unlink(filepath)


def test_create_cert_file_rewrites_when_content_changes():
"""
Test _create_cert_file rewrites the file when content changes.
"""
filename = 'rewrite_cert.pem'
old_content = (
'-----BEGIN CERTIFICATE-----\n'
'OLD\n'
'-----END CERTIFICATE-----'
)
new_content = (
'-----BEGIN CERTIFICATE-----\n'
'NEW\n'
'-----END CERTIFICATE-----'
)
filepath = EgressProxyClient._create_cert_file(
filename, old_content,
)

try:
EgressProxyClient._create_cert_file(
filename, new_content,
)

with open(filepath, 'r') as f:
assert f.read() == new_content
finally:
if os.path.exists(filepath):
os.unlink(filepath)


def test_create_cert_file_different_names_create_different_files():
"""Test different filenames create different files."""
cert1 = (
'-----BEGIN CERTIFICATE-----\n'
'CERT1\n'
'-----END CERTIFICATE-----'
)
cert2 = (
'-----BEGIN CERTIFICATE-----\n'
'CERT2\n'
'-----END CERTIFICATE-----'
)

filepath1 = EgressProxyClient._create_temp_cert_file(cert1)
filepath2 = EgressProxyClient._create_temp_cert_file(cert2)
filepath1 = EgressProxyClient._create_cert_file(
'diff_cert1.pem', cert1,
)
filepath2 = EgressProxyClient._create_cert_file(
'diff_cert2.pem', cert2,
)

try:
assert filepath1 != filepath2
Expand All @@ -645,6 +741,37 @@ def test_create_temp_cert_file_multiple_calls_create_different_files():
os.unlink(filepath2)


def test_cert_files_created_once_across_multiple_clients(
egress_proxy, certificates, cleanup_client,
):
"""
Test that creating multiple clients reuses existing cert files
without rewriting them.
"""
client1 = cleanup_client(
EgressProxyClient(
proxy=egress_proxy, certificates=certificates,
),
)

mtime_cert = os.path.getmtime(client1.cert_file)
mtime_key = os.path.getmtime(client1.key_file)
mtime_ca = os.path.getmtime(client1.ca_file)

client2 = cleanup_client(
EgressProxyClient(
proxy=egress_proxy, certificates=certificates,
),
)

assert client1.cert_file == client2.cert_file
assert client1.key_file == client2.key_file
assert client1.ca_file == client2.ca_file
assert os.path.getmtime(client2.cert_file) == mtime_cert
assert os.path.getmtime(client2.key_file) == mtime_key
assert os.path.getmtime(client2.ca_file) == mtime_ca


@responses.activate
def test_full_workflow_with_required_headers(
env_vars, proxy_config, cleanup_client,
Expand Down