diff --git a/aenv/src/cli/cmds/push.py b/aenv/src/cli/cmds/push.py
index 6e192f66..819eac27 100644
--- a/aenv/src/cli/cmds/push.py
+++ b/aenv/src/cli/cmds/push.py
@@ -22,6 +22,7 @@
 from cli.client.aenv_hub_client import AEnvHubClient, AEnvHubError, EnvStatus
 from cli.extends.storage.storage_manager import StorageContext, load_storage
+from cli.utils.parallel import parallel_execute


 @click.command()
@@ -59,16 +60,33 @@ def push(work_dir, dry_run, force, version):
         err=False,
     )
     hub_client = AEnvHubClient.load_client()
-    exist = hub_client.check_env(name=env_name, version=version)
+
+    # Execute check_env and state_environment in parallel
+    tasks = [
+        ("check_env", lambda: hub_client.check_env(name=env_name, version=version)),
+        ("state_env", lambda: hub_client.state_environment(env_name, version)),
+    ]
+    results = parallel_execute(tasks)
+
+    # Process results
+    check_result = results.get("check_env")
+    state_result = results.get("state_env")
+
+    if check_result and not check_result.success:
+        if check_result.error:
+            raise check_result.error
+
+    exist = check_result.result if check_result and check_result.success else False
+
     if exist:
         click.echo(
             f"aenv:{env_name}:{version} already exists in remote aenv_hub", err=False
         )
-        state = hub_client.state_environment(env_name, version)
-        env_state = EnvStatus.parse_state(state)
-        if env_state.running() and not force:
-            click.echo("❌ Environment is being prepared, use --force to overwrite")
-            raise click.Abort()
+        if state_result and state_result.success:
+            env_state = EnvStatus.parse_state(state_result.result)
+            if env_state.running() and not force:
+                click.echo("❌ Environment is being prepared, use --force to overwrite")
+                raise click.Abort()

     storage = load_storage()
     infos = {"name": env_name, "version": version}
diff --git a/aenv/src/cli/extends/storage/storage_manager.py b/aenv/src/cli/extends/storage/storage_manager.py
index 274b8ac3..b17ac0ef 100644
--- a/aenv/src/cli/extends/storage/storage_manager.py
+++ b/aenv/src/cli/extends/storage/storage_manager.py
@@ -18,6 +18,7 @@
 import tempfile
 import zipfile
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, Optional
@@ -27,6 +28,8 @@
 from cli.client.aenv_hub_client import AEnvHubClient
 from cli.utils.archive_tool import TempArchive
 from cli.utils.cli_config import get_config_manager
+from cli.utils.compression import pack_directory_parallel
+from cli.utils.parallel import is_parallel_disabled


 @dataclass
@@ -403,18 +406,33 @@
     def upload(self, context: StorageContext) -> StorageStatus:
         work_dir = context.src_url
         infos = context.infos
-        if infos:
-            name = infos.get("name")
-            version = infos.get("version")
+        name = infos.get("name")
+        version = infos.get("version")
+
+        hub_client = AEnvHubClient.load_client()
+
+        if is_parallel_disabled():
+            return self._upload_sequential(
+                work_dir, name, version, prefix, hub_client
+            )

-        with TempArchive(str(work_dir)) as archive:
+        return self._upload_concurrent(
+            work_dir, name, version, prefix, hub_client
+        )
+
+    def _upload_sequential(
+        self,
+        work_dir: str,
+        name: str,
+        version: str,
+        prefix: str,
+        hub_client: AEnvHubClient,
+    ) -> StorageStatus:
+        """Sequential upload: compress first, then get URL and upload."""
+        with TempArchive(str(work_dir), use_parallel=True) as archive:
             print(f"🔄 Archive: {archive}")
-            infos = context.infos
-            name = infos.get("name")
-            version = infos.get("version")
+            oss_url = hub_client.apply_sign_url(name, version)
             with open(archive, "rb") as tar:
-                hub_client = AEnvHubClient.load_client()
-                oss_url = hub_client.apply_sign_url(name, version)
                 headers = {"x-oss-object-acl": "public-read-write"}
                 response = requests.put(oss_url, data=tar, headers=headers)
                 response.raise_for_status()
@@ -422,6 +440,47 @@
         dest = f"{prefix}/{name}-{version}.tar"
         return StorageStatus(state=True, dest_url=dest)

+    def _upload_concurrent(
+        self,
+        work_dir: str,
+        name: str,
+        version: str,
+        prefix: str,
+        hub_client: AEnvHubClient,
+    ) -> StorageStatus:
+        """Concurrent upload: compress and fetch URL in parallel."""
+        archive_path = None
+        oss_url = None
+
+        try:
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                compress_future = executor.submit(
+                    pack_directory_parallel,
+                    work_dir,
+                    None,
+                    ["__pycache__"],
+                    True,
+                )
+                url_future = executor.submit(
+                    hub_client.apply_sign_url, name, version
+                )
+
+                archive_path = compress_future.result()
+                print(f"🔄 Archive: {archive_path}")
+                oss_url = url_future.result()
+
+            with open(archive_path, "rb") as tar:
+                headers = {"x-oss-object-acl": "public-read-write"}
+                response = requests.put(oss_url, data=tar, headers=headers)
+                response.raise_for_status()
+
+            dest = f"{prefix}/{name}-{version}.tar"
+            return StorageStatus(state=True, dest_url=dest)
+
+        finally:
+            if archive_path and os.path.exists(archive_path):
+                os.unlink(archive_path)
+

 def load_storage():
     store_config = get_config_manager().get_storage_config()
diff --git a/aenv/src/cli/tests/test_compression.py b/aenv/src/cli/tests/test_compression.py
new file mode 100644
index 00000000..431ac14f
--- /dev/null
+++ b/aenv/src/cli/tests/test_compression.py
@@ -0,0 +1,196 @@
+# Copyright 2025.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for parallel compression utilities."""
+
+import os
+import tarfile
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from cli.utils.compression import (
+    get_pigz_path,
+    get_cpu_count,
+    pack_directory_parallel,
+    _pack_with_tarfile,
+)
+
+
+@pytest.fixture
+def temp_source_dir():
+    """Create a temporary directory with test files."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        source_dir = Path(tmpdir) / "test_source"
+        source_dir.mkdir()
+
+        (source_dir / "file1.txt").write_text("content1")
+        (source_dir / "file2.txt").write_text("content2")
+
+        subdir = source_dir / "subdir"
+        subdir.mkdir()
+        (subdir / "file3.txt").write_text("content3")
+
+        pycache = source_dir / "__pycache__"
+        pycache.mkdir()
+        (pycache / "cache.pyc").write_text("cached")
+
+        yield str(source_dir)
+
+
+class TestGetPigzPath:
+    def test_returns_path_when_available(self):
+        with patch("shutil.which") as mock_which:
+            mock_which.return_value = "/usr/bin/pigz"
+            result = get_pigz_path()
+            assert result == "/usr/bin/pigz"
+
+    def test_returns_none_when_not_available(self):
+        with patch("shutil.which") as mock_which:
+            mock_which.return_value = None
+            result = get_pigz_path()
+            assert result is None
+
+
+class TestGetCpuCount:
+    def test_returns_positive_number(self):
+        count = get_cpu_count()
+        assert count >= 1
+
+    def test_handles_exception(self):
+        with patch("os.cpu_count", side_effect=Exception("error")):
+            count = get_cpu_count()
+            assert count == 1
+
+
+class TestPackDirectoryParallel:
+    def test_creates_archive(self, temp_source_dir):
+        output_path = pack_directory_parallel(temp_source_dir, use_parallel=False)
+
+        try:
+            assert os.path.exists(output_path)
+            assert output_path.endswith(".tar.gz")
+
+            with tarfile.open(output_path, "r:gz") as tar:
+                names = tar.getnames()
+                assert any("file1.txt" in n for n in names)
+                assert any("file2.txt" in n for n in names)
+                assert any("file3.txt" in n for n in names)
+        finally:
+            if os.path.exists(output_path):
+                os.unlink(output_path)
+
+    def test_excludes_patterns(self, temp_source_dir):
+        output_path = pack_directory_parallel(
+            temp_source_dir,
+            exclude_patterns=["__pycache__"],
+            use_parallel=False,
+        )
+
+        try:
+            with tarfile.open(output_path, "r:gz") as tar:
+                names = tar.getnames()
+                assert not any("__pycache__" in n for n in names)
+                assert not any("cache.pyc" in n for n in names)
+        finally:
+            if os.path.exists(output_path):
+                os.unlink(output_path)
+
+    def test_custom_output_path(self, temp_source_dir):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            custom_path = os.path.join(tmpdir, "custom_archive.tar.gz")
+            output_path = pack_directory_parallel(
+                temp_source_dir,
+                output_path=custom_path,
+                use_parallel=False,
+            )
+
+            assert output_path == custom_path
+            assert os.path.exists(custom_path)
+
+    def test_raises_for_nonexistent_dir(self):
+        with pytest.raises(FileNotFoundError):
+            pack_directory_parallel("/nonexistent/path")
+
+    def test_falls_back_to_tarfile_when_parallel_disabled(self, temp_source_dir):
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "1"}):
+            output_path = pack_directory_parallel(temp_source_dir)
+
+        try:
+            assert os.path.exists(output_path)
+        finally:
+            if os.path.exists(output_path):
+                os.unlink(output_path)
+
+    def test_uses_pigz_when_available(self, temp_source_dir):
+        with (
+            patch("cli.utils.compression.get_pigz_path") as mock_pigz,
+            patch("cli.utils.compression._pack_with_pigz") as mock_pack_pigz,
+            patch.dict(os.environ, {}, clear=True),
+        ):
+            os.environ.pop("AENV_DISABLE_PARALLEL", None)
+            mock_pigz.return_value = "/usr/bin/pigz"
+            mock_pack_pigz.return_value = "/tmp/test.tar.gz"
+
+            pack_directory_parallel(temp_source_dir)
+
+            mock_pack_pigz.assert_called_once()
+
+    def test_falls_back_when_pigz_fails(self, temp_source_dir):
+        with (
+            patch("cli.utils.compression.get_pigz_path") as mock_pigz,
+            patch("cli.utils.compression._pack_with_pigz") as mock_pack_pigz,
+            patch.dict(os.environ, {}, clear=True),
+        ):
+            os.environ.pop("AENV_DISABLE_PARALLEL", None)
+            mock_pigz.return_value = "/usr/bin/pigz"
+            mock_pack_pigz.side_effect = Exception("pigz failed")
+
+            output_path = pack_directory_parallel(temp_source_dir)
+
+            try:
+                assert os.path.exists(output_path)
+            finally:
+                if os.path.exists(output_path):
+                    os.unlink(output_path)
+
+
+class TestPackWithTarfile:
+    def test_creates_valid_archive(self, temp_source_dir):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output_path = os.path.join(tmpdir, "test.tar.gz")
+            result = _pack_with_tarfile(temp_source_dir, output_path, None)
+
+            assert result == output_path
+            assert os.path.exists(output_path)
+
+            with tarfile.open(output_path, "r:gz") as tar:
+                assert len(tar.getnames()) > 0
+
+    def test_respects_exclude_patterns(self, temp_source_dir):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output_path = os.path.join(tmpdir, "test.tar.gz")
+            _pack_with_tarfile(
+                temp_source_dir,
+                output_path,
+                ["__pycache__", "file1"],
+            )
+
+            with tarfile.open(output_path, "r:gz") as tar:
+                names = tar.getnames()
+                assert not any("__pycache__" in n for n in names)
+                assert not any("file1" in n for n in names)
+                assert any("file2" in n for n in names)
diff --git a/aenv/src/cli/tests/test_parallel.py b/aenv/src/cli/tests/test_parallel.py
new file mode 100644
index 00000000..f49af93a
--- /dev/null
+++ b/aenv/src/cli/tests/test_parallel.py
@@ -0,0 +1,185 @@
+# Copyright 2025.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for parallel execution utilities."""
+
+import os
+import time
+from unittest.mock import patch
+
+import pytest
+
+from cli.utils.parallel import (
+    TaskResult,
+    is_parallel_disabled,
+    parallel_execute,
+    _execute_sequential,
+)
+
+
+class TestIsParallelDisabled:
+    def test_disabled_when_env_is_1(self):
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "1"}):
+            assert is_parallel_disabled() is True
+
+    def test_disabled_when_env_is_true(self):
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "true"}):
+            assert is_parallel_disabled() is True
+
+    def test_disabled_when_env_is_yes(self):
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "yes"}):
+            assert is_parallel_disabled() is True
+
+    def test_not_disabled_when_env_is_0(self):
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "0"}):
+            assert is_parallel_disabled() is False
+
+    def test_not_disabled_when_env_not_set(self):
+        with patch.dict(os.environ, {}, clear=True):
+            os.environ.pop("AENV_DISABLE_PARALLEL", None)
+            assert is_parallel_disabled() is False
+
+
+class TestParallelExecute:
+    def test_empty_tasks(self):
+        results = parallel_execute([])
+        assert results == {}
+
+    def test_single_task_success(self):
+        tasks = [("task1", lambda: "result1")]
+        results = parallel_execute(tasks)
+
+        assert "task1" in results
+        assert results["task1"].success is True
+        assert results["task1"].result == "result1"
+        assert results["task1"].error is None
+
+    def test_multiple_tasks_success(self):
+        tasks = [
+            ("task1", lambda: "result1"),
+            ("task2", lambda: "result2"),
+            ("task3", lambda: "result3"),
+        ]
+        results = parallel_execute(tasks)
+
+        assert len(results) == 3
+        for i in range(1, 4):
+            name = f"task{i}"
+            assert results[name].success is True
+            assert results[name].result == f"result{i}"
+
+    def test_task_with_exception(self):
+        def failing_task():
+            raise ValueError("test error")
+
+        tasks = [
+            ("success_task", lambda: "ok"),
+            ("failing_task", failing_task),
+        ]
+        results = parallel_execute(tasks)
+
+        assert results["success_task"].success is True
+        assert results["success_task"].result == "ok"
+
+        assert results["failing_task"].success is False
+        assert isinstance(results["failing_task"].error, ValueError)
+        assert str(results["failing_task"].error) == "test error"
+
+    def test_tasks_run_concurrently(self):
+        def slow_task(delay, result):
+            time.sleep(delay)
+            return result
+
+        tasks = [
+            ("task1", lambda: slow_task(0.1, "result1")),
+            ("task2", lambda: slow_task(0.1, "result2")),
+        ]
+
+        start_time = time.time()
+        results = parallel_execute(tasks)
+        elapsed = time.time() - start_time
+
+        assert results["task1"].success is True
+        assert results["task2"].success is True
+        assert elapsed < 0.2
+
+    def test_respects_disable_flag(self):
+        call_order = []
+
+        def task1():
+            call_order.append("task1")
+            return "result1"
+
+        def task2():
+            call_order.append("task2")
+            return "result2"
+
+        tasks = [("task1", task1), ("task2", task2)]
+
+        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "1"}):
+            results = parallel_execute(tasks)
+
+        assert results["task1"].success is True
+        assert results["task2"].success is True
+        assert call_order == ["task1", "task2"]
+
+
+class TestExecuteSequential:
+    def test_executes_in_order(self):
+        call_order = []
+
+        def task1():
+            call_order.append("task1")
+            return "result1"
+
+        def task2():
+            call_order.append("task2")
+            return "result2"
+
+        tasks = [("task1", task1), ("task2", task2)]
+        results = _execute_sequential(tasks)
+
+        assert call_order == ["task1", "task2"]
+        assert results["task1"].success is True
+        assert results["task2"].success is True
+
+    def test_continues_after_exception(self):
+        def failing_task():
+            raise RuntimeError("error")
+
+        tasks = [
+            ("fail", failing_task),
+            ("success", lambda: "ok"),
+        ]
+        results = _execute_sequential(tasks)
+
+        assert results["fail"].success is False
+        assert results["success"].success is True
+        assert results["success"].result == "ok"
+
+
+class TestTaskResult:
+    def test_task_result_creation(self):
+        result = TaskResult(name="test", success=True, result="data")
+        assert result.name == "test"
+        assert result.success is True
+        assert result.result == "data"
+        assert result.error is None
+
+    def test_task_result_with_error(self):
+        error = ValueError("test error")
+        result = TaskResult(name="test", success=False, error=error)
+        assert result.success is False
+        assert result.error is error
+        assert result.result is None
diff --git a/aenv/src/cli/tests/test_push_optimized.py b/aenv/src/cli/tests/test_push_optimized.py
new file mode 100644
index 00000000..a8e997f1
--- /dev/null
+++ b/aenv/src/cli/tests/test_push_optimized.py
@@ -0,0 +1,251 @@
+# Copyright 2025.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration tests for optimized push command."""
+
+import json
+import os
+import tempfile
+import time
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+from click.testing import CliRunner
+
+from cli.cmds.push import push
+from cli.extends.storage.storage_manager import (
+    AEnvHubStorage,
+    StorageContext,
+)
+from cli.utils.parallel import parallel_execute
+
+
+@pytest.fixture
+def mock_project_dir():
+    """Create a mock aenv project directory."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config = {
+            "name": "test-env",
+            "version": "1.0.0",
+            "tags": ["test"],
+        }
+        config_path = os.path.join(tmpdir, "config.json")
+        with open(config_path, "w") as f:
+            json.dump(config, f)
+
+        (Path(tmpdir) / "src").mkdir()
+        (Path(tmpdir) / "src" / "main.py").write_text("print('hello')")
+
+        yield tmpdir
+
+
+class TestParallelHttpRequests:
+    def test_check_env_and_state_run_concurrently(self):
+        call_times = {}
+
+        def mock_check_env():
+            call_times["check_env_start"] = time.time()
+            time.sleep(0.1)
+            call_times["check_env_end"] = time.time()
+            return True
+
+        def mock_state_env():
+            call_times["state_env_start"] = time.time()
+            time.sleep(0.1)
+            call_times["state_env_end"] = time.time()
+            return "pending"
+
+        tasks = [
+            ("check_env", mock_check_env),
+            ("state_env", mock_state_env),
+        ]
+
+        start = time.time()
+        results = parallel_execute(tasks)
+        elapsed = time.time() - start
+
+        assert results["check_env"].success is True
+        assert results["state_env"].success is True
+        assert elapsed < 0.2
+
+    def test_handles_check_env_failure(self):
+        def mock_check_env():
+            raise ConnectionError("Network error")
+
+        def mock_state_env():
+            return "pending"
+
+        tasks = [
+            ("check_env", mock_check_env),
+            ("state_env", mock_state_env),
+        ]
+
+        results = parallel_execute(tasks)
+
+        assert results["check_env"].success is False
+        assert isinstance(results["check_env"].error, ConnectionError)
+        assert results["state_env"].success is True
+
+
+class TestAEnvHubStorageOptimized:
+    def test_concurrent_upload(self, mock_project_dir):
+        storage = AEnvHubStorage()
+
+        mock_response = MagicMock()
+        mock_response.raise_for_status = MagicMock()
+
+        with (
+            patch("cli.extends.storage.storage_manager.get_config_manager") as mock_config,
+            patch("cli.extends.storage.storage_manager.AEnvHubClient") as mock_client_cls,
+            patch("requests.put", return_value=mock_response),
+            patch.dict(os.environ, {}, clear=True),
+        ):
+            os.environ.pop("AENV_DISABLE_PARALLEL", None)
+
+            mock_config.return_value.get_storage_config.return_value = {
+                "custom": {"prefix": "/test"}
+            }
+
+            mock_client = MagicMock()
+            mock_client.apply_sign_url.return_value = "https://oss.example.com/signed"
+            mock_client_cls.load_client.return_value = mock_client
+
+            ctx = StorageContext(
+                src_url=mock_project_dir,
+                infos={"name": "test-env", "version": "1.0.0"},
+            )
+
+            result = storage.upload(ctx)
+
+            assert result.state is True
+            assert "test-env" in result.dest_url
+            mock_client.apply_sign_url.assert_called_once_with("test-env", "1.0.0")
+
+    def test_sequential_upload_when_disabled(self, mock_project_dir):
+        storage = AEnvHubStorage()
+
+        mock_response = MagicMock()
+        mock_response.raise_for_status = MagicMock()
+
+        with (
+            patch("cli.extends.storage.storage_manager.get_config_manager") as mock_config,
+            patch("cli.extends.storage.storage_manager.AEnvHubClient") as mock_client_cls,
+            patch("requests.put", return_value=mock_response),
+            patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "1"}),
+        ):
+            mock_config.return_value.get_storage_config.return_value = {
+                "custom": {"prefix": "/test"}
+            }
+
+            mock_client = MagicMock()
+            mock_client.apply_sign_url.return_value = "https://oss.example.com/signed"
+            mock_client_cls.load_client.return_value = mock_client
+
+            ctx = StorageContext(
+                src_url=mock_project_dir,
+                infos={"name": "test-env", "version": "1.0.0"},
+            )
+
+            result = storage.upload(ctx)
+
+            assert result.state is True
+
+
+class TestPushCommandIntegration:
+    def test_push_with_parallel_checks(self, mock_project_dir):
+        runner = CliRunner()
+
+        with (
+            patch("cli.cmds.push.AEnvHubClient") as mock_client_cls,
+            patch("cli.cmds.push.load_storage") as mock_load_storage,
+        ):
+            mock_client = MagicMock()
+            mock_client.check_env.return_value = False
+            mock_client.state_environment.return_value = "completed"
+            mock_client.create_environment.return_value = {}
+            mock_client_cls.load_client.return_value = mock_client
+
+            mock_storage = MagicMock()
+            mock_storage.upload.return_value = MagicMock(
+                state=True, dest_url="/test/path"
+            )
+            mock_load_storage.return_value = mock_storage
+
+            result = runner.invoke(push, ["--work-dir", mock_project_dir])
+
+            assert result.exit_code == 0
+            assert "Push successfully" in result.output
+
+    def test_push_existing_env_not_running(self, mock_project_dir):
+        runner = CliRunner()
+
+        with (
+            patch("cli.cmds.push.AEnvHubClient") as mock_client_cls,
+            patch("cli.cmds.push.load_storage") as mock_load_storage,
+        ):
+            mock_client = MagicMock()
+            mock_client.check_env.return_value = True
+            mock_client.state_environment.return_value = "completed"
+            mock_client.update_environment.return_value = {}
+            mock_client_cls.load_client.return_value = mock_client
+
+            mock_storage = MagicMock()
+            mock_storage.upload.return_value = MagicMock(
+                state=True, dest_url="/test/path"
+ ) + mock_load_storage.return_value = mock_storage + + result = runner.invoke(push, ["--work-dir", mock_project_dir]) + + assert result.exit_code == 0 + mock_client.update_environment.assert_called_once() + + def test_push_existing_env_running_without_force(self, mock_project_dir): + runner = CliRunner() + + with patch("cli.cmds.push.AEnvHubClient") as mock_client_cls: + mock_client = MagicMock() + mock_client.check_env.return_value = True + mock_client.state_environment.return_value = "pending" + mock_client_cls.load_client.return_value = mock_client + + result = runner.invoke(push, ["--work-dir", mock_project_dir]) + + assert result.exit_code == 1 + assert "being prepared" in result.output + + def test_push_existing_env_running_with_force(self, mock_project_dir): + runner = CliRunner() + + with ( + patch("cli.cmds.push.AEnvHubClient") as mock_client_cls, + patch("cli.cmds.push.load_storage") as mock_load_storage, + ): + mock_client = MagicMock() + mock_client.check_env.return_value = True + mock_client.state_environment.return_value = "pending" + mock_client.update_environment.return_value = {} + mock_client_cls.load_client.return_value = mock_client + + mock_storage = MagicMock() + mock_storage.upload.return_value = MagicMock( + state=True, dest_url="/test/path" + ) + mock_load_storage.return_value = mock_storage + + result = runner.invoke(push, ["--work-dir", mock_project_dir, "--force"]) + + assert result.exit_code == 0 + mock_client.update_environment.assert_called_once() diff --git a/aenv/src/cli/utils/archive_tool.py b/aenv/src/cli/utils/archive_tool.py index 1b47f64f..cf535dc7 100644 --- a/aenv/src/cli/utils/archive_tool.py +++ b/aenv/src/cli/utils/archive_tool.py @@ -239,21 +239,38 @@ def quick_cleanup(*paths: str) -> bool: # Convenient class supporting with operations class TempArchive: - """Temporary archive context manager""" + """Temporary archive context manager with optional parallel compression""" - def __init__(self, source_dir: str, **pack_kwargs): + def __init__(self, source_dir: str, use_parallel: bool = True, **pack_kwargs): + """ + Initialize TempArchive. + + Args: + source_dir: Source directory to archive + use_parallel: Whether to use parallel compression (pigz) when available + **pack_kwargs: Additional arguments passed to pack function + """ self.source_dir = source_dir + self.use_parallel = use_parallel self.pack_kwargs = pack_kwargs self.archive_path = None - if pack_kwargs is None: - pack_kwargs = {"exclude_patterns": ["__pycache__"]} - self.pack_kwargs = pack_kwargs + if not pack_kwargs: + self.pack_kwargs = {"exclude_patterns": ["__pycache__"]} def __enter__(self): - self.archive_path = ArchiveTool.pack_directory( - self.source_dir, **self.pack_kwargs - ) + if self.use_parallel: + from cli.utils.compression import pack_directory_parallel + + self.archive_path = pack_directory_parallel( + self.source_dir, + exclude_patterns=self.pack_kwargs.get("exclude_patterns"), + use_parallel=True, + ) + else: + self.archive_path = ArchiveTool.pack_directory( + self.source_dir, **self.pack_kwargs + ) return self.archive_path def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/aenv/src/cli/utils/compression.py b/aenv/src/cli/utils/compression.py new file mode 100644 index 00000000..e596581c --- /dev/null +++ b/aenv/src/cli/utils/compression.py @@ -0,0 +1,197 @@ +# Copyright 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parallel compression utilities for aenv CLI + +Provides multi-threaded compression using pigz when available, +with automatic fallback to standard gzip. +""" + +import logging +import os +import shutil +import subprocess +import tarfile +import tempfile +import time +from pathlib import Path +from typing import List, Optional + +from cli.utils.parallel import is_parallel_disabled + +logger = logging.getLogger(__name__) + + +def get_pigz_path() -> Optional[str]: + """ + Check if pigz is available in the system. + + Returns: + Path to pigz executable or None if not available + """ + return shutil.which("pigz") + + +def get_cpu_count() -> int: + """Get the number of CPUs available for compression.""" + try: + return os.cpu_count() or 1 + except Exception: + return 1 + + +def pack_directory_parallel( + source_dir: str, + output_path: Optional[str] = None, + exclude_patterns: Optional[List[str]] = None, + use_parallel: bool = True, + compression_level: int = 6, +) -> str: + """ + Package directory as tar.gz file using parallel compression when available. + + Uses pigz for multi-threaded compression if available, otherwise falls back + to standard tarfile compression. + + Args: + source_dir: Source directory path + output_path: Output file path, generates temporary file if None + exclude_patterns: List of file patterns to exclude + use_parallel: Whether to use parallel compression (default: True) + compression_level: Compression level 1-9 (default: 6) + + Returns: + Path to the compressed archive file + """ + source_path = Path(source_dir) + if not source_path.exists(): + raise FileNotFoundError(f"Directory does not exist: {source_dir}") + + if output_path is None: + timestamp = int(time.time()) + filename = f"{source_path.name}_{timestamp}.tar.gz" + output_path = str(Path(tempfile.gettempdir()) / filename) + + pigz_path = get_pigz_path() + use_pigz = ( + use_parallel + and not is_parallel_disabled() + and pigz_path is not None + ) + + if use_pigz: + try: + return _pack_with_pigz( + source_dir=source_dir, + output_path=output_path, + exclude_patterns=exclude_patterns, + pigz_path=pigz_path, + compression_level=compression_level, + ) + except Exception as e: + logger.warning(f"Pigz compression failed, falling back to tarfile: {e}") + + return _pack_with_tarfile( + source_dir=source_dir, + output_path=output_path, + exclude_patterns=exclude_patterns, + ) + + +def _pack_with_pigz( + source_dir: str, + output_path: str, + exclude_patterns: Optional[List[str]], + pigz_path: str, + compression_level: int = 6, +) -> str: + """ + Create tar.gz archive using pigz for parallel compression. + + This creates an uncompressed tar first, then pipes it through pigz. 
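+
+    The result is roughly equivalent to the shell pipeline
+    "tar -cf - SOURCE | pigz -LEVEL -p NCPU > OUTPUT", except that the
+    tar stage applies exclude_patterns while walking the tree.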
+ """ + source_path = Path(source_dir) + exclude_set = set(exclude_patterns or []) + cpu_count = get_cpu_count() + + with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tmp_tar: + tmp_tar_path = tmp_tar.name + + try: + with tarfile.open(tmp_tar_path, "w") as tar: + for root, dirs, files in os.walk(source_dir): + dirs[:] = [ + d for d in dirs + if not any(exclude in d for exclude in exclude_set) + ] + files = [ + f for f in files + if not any(exclude in f for exclude in exclude_set) + ] + for file in files: + file_path = Path(root) / file + arc_path = file_path.relative_to(source_path.parent) + tar.add(file_path, arcname=arc_path) + + with open(tmp_tar_path, "rb") as tar_input: + with open(output_path, "wb") as gz_output: + process = subprocess.run( + [pigz_path, f"-{compression_level}", "-p", str(cpu_count)], + stdin=tar_input, + stdout=gz_output, + stderr=subprocess.PIPE, + check=True, + ) + + logger.info( + f"Parallel compression completed with pigz ({cpu_count} threads): " + f"{output_path} ({os.path.getsize(output_path)} bytes)" + ) + return output_path + + finally: + if os.path.exists(tmp_tar_path): + os.unlink(tmp_tar_path) + + +def _pack_with_tarfile( + source_dir: str, + output_path: str, + exclude_patterns: Optional[List[str]], +) -> str: + """Create tar.gz archive using standard tarfile module.""" + source_path = Path(source_dir) + exclude_set = set(exclude_patterns or []) + + with tarfile.open(output_path, "w:gz") as tar: + for root, dirs, files in os.walk(source_dir): + dirs[:] = [ + d for d in dirs + if not any(exclude in d for exclude in exclude_set) + ] + files = [ + f for f in files + if not any(exclude in f for exclude in exclude_set) + ] + for file in files: + file_path = Path(root) / file + arc_path = file_path.relative_to(source_path.parent) + tar.add(file_path, arcname=arc_path) + + logger.info( + f"Standard compression completed: {output_path} " + f"({os.path.getsize(output_path)} bytes)" + ) + return output_path diff --git a/aenv/src/cli/utils/parallel.py b/aenv/src/cli/utils/parallel.py new file mode 100644 index 00000000..26dda122 --- /dev/null +++ b/aenv/src/cli/utils/parallel.py @@ -0,0 +1,113 @@ +# Copyright 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parallel execution utilities for aenv CLI + +Provides concurrent task execution with graceful error handling and fallback. 
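+
+Parallelism can be disabled globally by setting the AENV_DISABLE_PARALLEL
+environment variable to "1", "true", or "yes".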
+""" + +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + + +@dataclass +class TaskResult: + """Result of a parallel task execution.""" + + name: str + success: bool + result: Any = None + error: Optional[Exception] = None + + +def is_parallel_disabled() -> bool: + """Check if parallel execution is disabled via environment variable.""" + return os.environ.get("AENV_DISABLE_PARALLEL", "").lower() in ("1", "true", "yes") + + +def parallel_execute( + tasks: List[Tuple[str, Callable[[], Any]]], + timeout: Optional[float] = None, + max_workers: Optional[int] = None, +) -> Dict[str, TaskResult]: + """ + Execute multiple tasks in parallel using ThreadPoolExecutor. + + Args: + tasks: List of (name, callable) tuples to execute + timeout: Optional timeout for each task in seconds + max_workers: Maximum number of worker threads (default: min(len(tasks), 4)) + + Returns: + Dictionary mapping task names to TaskResult objects + + Example: + tasks = [ + ("check_env", lambda: client.check_env(name, version)), + ("state_env", lambda: client.state_environment(name, version)), + ] + results = parallel_execute(tasks) + check_result = results["check_env"].result + """ + if not tasks: + return {} + + if is_parallel_disabled(): + return _execute_sequential(tasks) + + results: Dict[str, TaskResult] = {} + max_workers = max_workers or min(len(tasks), 4) + + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_name = { + executor.submit(task_fn): name for name, task_fn in tasks + } + + for future in as_completed(future_to_name, timeout=timeout): + name = future_to_name[future] + try: + result = future.result() + results[name] = TaskResult( + name=name, success=True, result=result + ) + except Exception as e: + results[name] = TaskResult( + name=name, success=False, error=e + ) + + except Exception: + # Fallback to sequential execution if parallel fails + return _execute_sequential(tasks) + + return results + + +def _execute_sequential( + tasks: List[Tuple[str, Callable[[], Any]]] +) -> Dict[str, TaskResult]: + """Execute tasks sequentially as fallback.""" + results: Dict[str, TaskResult] = {} + + for name, task_fn in tasks: + try: + result = task_fn() + results[name] = TaskResult(name=name, success=True, result=result) + except Exception as e: + results[name] = TaskResult(name=name, success=False, error=e) + + return results