diff --git a/connect/eaas/core/egress_proxy.py b/connect/eaas/core/egress_proxy.py index 7cf6050..2260768 100644 --- a/connect/eaas/core/egress_proxy.py +++ b/connect/eaas/core/egress_proxy.py @@ -27,9 +27,15 @@ def __init__( certificates: EgressProxyCertificates, ): self.proxy = proxy - self.cert_file = self._create_temp_cert_file(certificates.client_cert) - self.key_file = self._create_temp_cert_file(certificates.client_key) - self.ca_file = self._create_temp_cert_file(certificates.ca_cert) + self.cert_file = self._create_cert_file( + 'client_cert.pem', certificates.client_cert, + ) + self.key_file = self._create_cert_file( + 'client_key.pem', certificates.client_key, + ) + self.ca_file = self._create_cert_file( + 'ca_cert.pem', certificates.ca_cert, + ) super().__init__( endpoint=self.proxy.url, @@ -39,16 +45,24 @@ def __init__( ) @staticmethod - def _create_temp_cert_file(cert_content): - """Create a temporary file with certificate content.""" - temp_file = tempfile.NamedTemporaryFile( - mode='w', - delete=False, - suffix='.pem', - ) - temp_file.write(cert_content) - temp_file.close() - return temp_file.name + def _create_cert_file(filename, cert_content): + """ + Create or reuse a certificate file at a fixed location. + + Uses /tmp/ as a stable path. Only writes if the + file doesn't exist or its content has changed. + """ + filepath = os.path.join(tempfile.gettempdir(), filename) + + if os.path.exists(filepath): + with open(filepath, 'r') as f: + if f.read() == cert_content: + return filepath + + with open(filepath, 'w') as f: + f.write(cert_content) + + return filepath @classmethod def require_proxy(cls, account_id: str): diff --git a/tests/connect/eaas/core/test_egress_proxy.py b/tests/connect/eaas/core/test_egress_proxy.py index 73eacf2..9ac17ec 100644 --- a/tests/connect/eaas/core/test_egress_proxy.py +++ b/tests/connect/eaas/core/test_egress_proxy.py @@ -596,18 +596,20 @@ def test_validate_headers_success_with_empty_headers_list( client._validate_headers(headers) -# Tests for _create_temp_cert_file static method +# Tests for _create_cert_file static method -def test_create_temp_cert_file_creates_file(): - """Test _create_temp_cert_file creates a temporary file.""" +def test_create_cert_file_creates_file(): + """Test _create_cert_file creates a file at a fixed location.""" cert_content = ( '-----BEGIN CERTIFICATE-----\n' 'TEST\n' '-----END CERTIFICATE-----' ) - filepath = EgressProxyClient._create_temp_cert_file(cert_content) + filepath = EgressProxyClient._create_cert_file( + 'test_cert.pem', cert_content, + ) try: assert os.path.exists(filepath) @@ -620,13 +622,107 @@ def test_create_temp_cert_file_creates_file(): os.unlink(filepath) -def test_create_temp_cert_file_multiple_calls_create_different_files(): - """Test multiple calls create different temp files.""" - cert1 = '-----BEGIN CERTIFICATE-----\nCERT1\n-----END CERTIFICATE-----' - cert2 = '-----BEGIN CERTIFICATE-----\nCERT2\n-----END CERTIFICATE-----' +def test_create_cert_file_uses_deterministic_path(): + """Test _create_cert_file returns the same path for same filename.""" + cert_content = ( + '-----BEGIN CERTIFICATE-----\n' + 'CERT\n' + '-----END CERTIFICATE-----' + ) + + filepath1 = EgressProxyClient._create_cert_file( + 'determ_cert.pem', cert_content, + ) + filepath2 = EgressProxyClient._create_cert_file( + 'determ_cert.pem', cert_content, + ) + + try: + assert filepath1 == filepath2 + finally: + if os.path.exists(filepath1): + os.unlink(filepath1) + + +def test_create_cert_file_writes_only_once(): + """ + Test _create_cert_file writes the file only on first call. + Subsequent calls with same content reuse existing file. + """ + filename = 'write_once_cert.pem' + cert_content = ( + '-----BEGIN CERTIFICATE-----\n' + 'ONCE\n' + '-----END CERTIFICATE-----' + ) + filepath = EgressProxyClient._create_cert_file( + filename, cert_content, + ) + + try: + mtime_before = os.path.getmtime(filepath) + + EgressProxyClient._create_cert_file( + filename, cert_content, + ) + + mtime_after = os.path.getmtime(filepath) + assert mtime_before == mtime_after + finally: + if os.path.exists(filepath): + os.unlink(filepath) + + +def test_create_cert_file_rewrites_when_content_changes(): + """ + Test _create_cert_file rewrites the file when content changes. + """ + filename = 'rewrite_cert.pem' + old_content = ( + '-----BEGIN CERTIFICATE-----\n' + 'OLD\n' + '-----END CERTIFICATE-----' + ) + new_content = ( + '-----BEGIN CERTIFICATE-----\n' + 'NEW\n' + '-----END CERTIFICATE-----' + ) + filepath = EgressProxyClient._create_cert_file( + filename, old_content, + ) + + try: + EgressProxyClient._create_cert_file( + filename, new_content, + ) + + with open(filepath, 'r') as f: + assert f.read() == new_content + finally: + if os.path.exists(filepath): + os.unlink(filepath) + + +def test_create_cert_file_different_names_create_different_files(): + """Test different filenames create different files.""" + cert1 = ( + '-----BEGIN CERTIFICATE-----\n' + 'CERT1\n' + '-----END CERTIFICATE-----' + ) + cert2 = ( + '-----BEGIN CERTIFICATE-----\n' + 'CERT2\n' + '-----END CERTIFICATE-----' + ) - filepath1 = EgressProxyClient._create_temp_cert_file(cert1) - filepath2 = EgressProxyClient._create_temp_cert_file(cert2) + filepath1 = EgressProxyClient._create_cert_file( + 'diff_cert1.pem', cert1, + ) + filepath2 = EgressProxyClient._create_cert_file( + 'diff_cert2.pem', cert2, + ) try: assert filepath1 != filepath2 @@ -645,6 +741,37 @@ def test_create_temp_cert_file_multiple_calls_create_different_files(): os.unlink(filepath2) +def test_cert_files_created_once_across_multiple_clients( + egress_proxy, certificates, cleanup_client, +): + """ + Test that creating multiple clients reuses existing cert files + without rewriting them. + """ + client1 = cleanup_client( + EgressProxyClient( + proxy=egress_proxy, certificates=certificates, + ), + ) + + mtime_cert = os.path.getmtime(client1.cert_file) + mtime_key = os.path.getmtime(client1.key_file) + mtime_ca = os.path.getmtime(client1.ca_file) + + client2 = cleanup_client( + EgressProxyClient( + proxy=egress_proxy, certificates=certificates, + ), + ) + + assert client1.cert_file == client2.cert_file + assert client1.key_file == client2.key_file + assert client1.ca_file == client2.ca_file + assert os.path.getmtime(client2.cert_file) == mtime_cert + assert os.path.getmtime(client2.key_file) == mtime_key + assert os.path.getmtime(client2.ca_file) == mtime_ca + + @responses.activate def test_full_workflow_with_required_headers( env_vars, proxy_config, cleanup_client,