diff --git a/auto_dev/contracts/contract_scafolder.py b/auto_dev/contracts/contract_scafolder.py index 0603961e..925dd1b1 100644 --- a/auto_dev/contracts/contract_scafolder.py +++ b/auto_dev/contracts/contract_scafolder.py @@ -3,11 +3,12 @@ import os import json import shutil +from pathlib import Path from dataclasses import dataclass from aea.configurations.base import PublicId, PackageId, PackageType -from auto_dev.utils import isolated_filesystem +from auto_dev.utils import rollback from auto_dev.constants import DEFAULT_ENCODING, DEFAULT_IPFS_HASH from auto_dev.cli_executor import CommandExecutor from auto_dev.contracts.contract import Contract @@ -61,7 +62,7 @@ def generate_openaea_contract(self, contract: Contract): msg = "Failed to initialise agent lib." raise ValueError(msg) - with isolated_filesystem(): + with rollback.on_exception(Path.cwd()): if not (output := CommandExecutor("aea create myagent".split(" "))).execute(verbose=verbose): msg = f"Failed to create agent.\n{output}" raise ValueError(msg) diff --git a/auto_dev/utils.py b/auto_dev/utils/__init__.py similarity index 94% rename from auto_dev/utils.py rename to auto_dev/utils/__init__.py index a4181989..d98c4ea9 100644 --- a/auto_dev/utils.py +++ b/auto_dev/utils/__init__.py @@ -188,29 +188,6 @@ def filter_protobuf_files(file_path: str) -> bool: return [f for f in python_files if not filter_protobuf_files(f)] -@contextmanager -def isolated_filesystem(copy_cwd: bool = False): - """Context manager to create an isolated file system. - And to navigate to it and then to clean it up. - """ - original_path = Path.cwd() - with tempfile.TemporaryDirectory(dir=tempfile.gettempdir()) as temp_dir: - temp_dir_path = Path(temp_dir).resolve() - os.chdir(temp_dir_path) - if copy_cwd: - # we copy the content of the original directory into the temporary one - for file_name in os.listdir(original_path): - if file_name == "__pycache__": - continue - file_path = Path(original_path, file_name) - if file_path.is_file(): - shutil.copy(file_path, temp_dir_path) - elif file_path.is_dir(): - shutil.copytree(file_path, Path(temp_dir, file_name)) - yield str(Path(temp_dir_path)) - os.chdir(original_path) - - @contextmanager def change_dir(target_path): """Temporarily change the working directory.""" diff --git a/auto_dev/utils/rollback.py b/auto_dev/utils/rollback.py new file mode 100644 index 00000000..3db25f90 --- /dev/null +++ b/auto_dev/utils/rollback.py @@ -0,0 +1,65 @@ +"""Filesystem utilities for temporary backups and rollback mechanisms.""" + +import shutil +import signal +import tempfile +from pathlib import Path +from contextlib import chdir, contextmanager + +from auto_dev.utils import signals + + +# https://www.youtube.com/watch?v=0GRLhpMao3I +# async-signal safe is the strongest concept of reentrancy. +# async-signal safe implies thread safe. + +# signal.SIGKILL cannot be intercepted +SIGNALS_TO_BLOCK = (signal.SIGINT, signal.SIGTERM) + + +def _restore_from_backup(directory: Path, backup: Path): + for item in directory.rglob("*"): + backup_item = backup / item.relative_to(directory) + if item.is_file() or item.is_symlink(): + item.unlink() + elif item.is_dir() and not backup_item.exists(): + shutil.rmtree(item) + + for item in backup.rglob("*"): + directory_item = directory / item.relative_to(backup) + if item.is_symlink(): + directory_item.symlink_to(item.readlink()) + elif item.is_dir(): + directory_item.mkdir(parents=True, exist_ok=True) + elif item.is_file(): + shutil.copy2(item, directory_item) + + +@contextmanager +def on_exit(directory: Path): + """Creates a temporary backup of the directory and restores it upon exit.""" + backup = Path(tempfile.mkdtemp(prefix="backup_")) / directory.name + shutil.copytree(directory, backup, symlinks=True) + with chdir(Path.cwd()): + try: + yield + finally: + with signals.mask(*SIGNALS_TO_BLOCK): + _restore_from_backup(directory, backup) + shutil.rmtree(backup) + + +@contextmanager +def on_exception(directory: Path): + """Creates a temporary backup of the directory and restores it only if an exception occurs.""" + backup = Path(tempfile.mkdtemp(prefix="backup_")) / directory.name + shutil.copytree(directory, backup, symlinks=True) + with chdir(Path.cwd()): + try: + yield + except BaseException: + with signals.mask(*SIGNALS_TO_BLOCK): + _restore_from_backup(directory, backup) + raise + finally: + shutil.rmtree(backup) diff --git a/auto_dev/utils/signals.py b/auto_dev/utils/signals.py new file mode 100644 index 00000000..ef08e243 --- /dev/null +++ b/auto_dev/utils/signals.py @@ -0,0 +1,46 @@ +"""Signal management utilities.""" + +import signal +from types import FrameType +from contextlib import contextmanager +from collections.abc import Callable + + +CallableSignalHandler = Callable[[int, FrameType], None] +SignalHandler = signal.Handlers | CallableSignalHandler + + +@contextmanager +def block(*signals: int): + """Context manager to globally block specified signals by replacing their handlers with signal.SIG_IGN.""" + original_handlers = {sig: signal.getsignal(sig) for sig in signals} + try: + for sig in signals: + signal.signal(sig, signal.SIG_IGN) + yield + finally: + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) + + +@contextmanager +def mask(*signals: int): + """Context manager to temporarily block specified signals for the current thread by modifying its signal mask.""" + original_mask = signal.pthread_sigmask(signal.SIG_BLOCK, signals) + try: + yield + finally: + signal.pthread_sigmask(signal.SIG_SETMASK, original_mask) + + +@contextmanager +def replace_handler(signal_handler: SignalHandler, *signals: int): + """Context manager to replace the signal handlers for specified signals with a custom handler.""" + original_handlers = {sig: signal.getsignal(sig) for sig in signals} + try: + for sig in signals: + signal.signal(sig, signal_handler) + yield + finally: + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) diff --git a/tests/conftest.py b/tests/conftest.py index fbde625a..d8d1b6ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,13 @@ # pylint: disable=W0135 import os +import tempfile from pathlib import Path +from contextlib import chdir import pytest -from auto_dev.utils import isolated_filesystem +from auto_dev.utils import rollback from auto_dev.constants import ( DEFAULT_PUBLIC_ID, ) @@ -68,15 +70,17 @@ def openapi_test_case(request): @pytest.fixture def test_filesystem(): """Fixture for invoking command-line interfaces.""" - with isolated_filesystem(copy_cwd=True) as directory: - yield directory + with rollback.on_exit(Path.cwd()): + yield Path.cwd() @pytest.fixture def test_clean_filesystem(): """Fixture for invoking command-line interfaces.""" - with isolated_filesystem() as directory: - yield directory + with tempfile.TemporaryDirectory(dir=tempfile.gettempdir()) as temp_dir: + temp_dir_path = Path(temp_dir).resolve() + with chdir(temp_dir_path): + yield temp_dir_path @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index 3dad69e1..c48ae003 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ def test_lint_fails(cli_runner, test_filesystem): """Test the lint command fails with no packages.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem cmd = ["adev", "-n", "0", "lint", "-p", "packages/fake"] runner = cli_runner(cmd) runner.execute() @@ -18,7 +18,7 @@ def test_lint_fails(cli_runner, test_filesystem): def test_lints_self(cli_runner, test_filesystem): """Test the lint command works with the current package.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem cmd = ["adev", "-v", "-n", "0", "lint", "-p", "."] runner = cli_runner(cmd) result = runner.execute() @@ -28,7 +28,7 @@ def test_lints_self(cli_runner, test_filesystem): def test_formats_self(cli_runner, test_filesystem): """Test the format command works with the current package.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem cmd = ["adev", "-n", "0", "-v", "fmt", "-p", "."] runner = cli_runner(cmd) result = runner.execute() @@ -38,7 +38,7 @@ def test_formats_self(cli_runner, test_filesystem): def test_create_invalid_name(test_filesystem): """Test the create command fails with invalid agent name.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem task = Task(command="adev create NEW_AGENT -t eightballer/base --no-clean-up") task.work() assert all([task.is_done, task.is_failed]), task.client.output @@ -51,7 +51,7 @@ def test_create_invalid_name(test_filesystem): def test_create_valid_names(test_packages_filesystem): """Test the create command succeeds with valid agent names.""" - assert str(Path.cwd()) == test_packages_filesystem + assert Path.cwd() == test_packages_filesystem valid_names = ["my_agent", "_test_agent", "agent123", "valid_agent_name_123"] for name in valid_names: @@ -64,7 +64,7 @@ def test_create_valid_names(test_packages_filesystem): def test_create_with_publish_no_packages(test_filesystem): """Test the create command succeeds when there is no local packages directory.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem task = Task( command=f"adev create {DEFAULT_PUBLIC_ID!s} -t eightballer/base --no-clean-up", ) diff --git a/tests/test_contracts.py b/tests/test_contracts.py index a1023031..a7155f84 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -108,7 +108,7 @@ def test_scaffolder_generate(scaffolder): @responses.activate def test_scaffolder_generate_openaea_contract(scaffolder, test_packages_filesystem): """Test the scaffolder.""" - del test_packages_filesystem + assert test_packages_filesystem responses.add( responses.GET, f"{BLOCK_EXPLORER_URL}/{KNOWN_ADDRESS}?network={NETWORK.value}", diff --git a/tests/test_eject.py b/tests/test_eject.py index b3847146..f41d0e80 100644 --- a/tests/test_eject.py +++ b/tests/test_eject.py @@ -8,7 +8,7 @@ def test_eject_metrics_skill_workflow(test_filesystem): """Test the complete workflow of creating an agent and ejecting the metrics skill.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem # 1. Create agent with eightballer/base template wf_manager: WorkflowManager = WorkflowManager().from_yaml( file_path=Path(AUTO_DEV_FOLDER) / "data" / "workflows" / "eject_component.yaml" @@ -20,7 +20,7 @@ def test_eject_metrics_skill_workflow(test_filesystem): assert all([task.is_done, not task.is_failed]), f"Task failed: {task.client.output}" task_2 = wf_manager.workflows[0].tasks[1].work() assert all([task_2.is_done, not task_2.is_failed]), f"Task failed: {task_2.client.output}" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem ejected_skill_path = Path(DEFAULT_AGENT_NAME) / "skills" / "simple_fsm" assert ejected_skill_path.exists(), "Ejected skill directory not found" # Verify the original vendor skill was removed @@ -30,7 +30,7 @@ def test_eject_metrics_skill_workflow(test_filesystem): def test_eject_metrics_skill_skip_deps(test_filesystem): """Test ejecting the metrics skill with skip-dependencies flag.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem # 1. Create agent with eightballer/base template wf_manager: WorkflowManager = WorkflowManager().from_yaml( file_path=Path(AUTO_DEV_FOLDER) / "data" / "workflows" / "eject_component.yaml" @@ -62,7 +62,7 @@ def test_eject_metrics_skill_skip_deps(test_filesystem): def test_eject_http_protocol(test_filesystem): """Test ejecting the metrics skill with skip-dependencies flag.""" - assert str(Path.cwd()) == test_filesystem + assert Path.cwd() == test_filesystem # 1. Create agent with eightballer/base template wf_manager: WorkflowManager = WorkflowManager().from_yaml( file_path=Path(AUTO_DEV_FOLDER) / "data" / "workflows" / "eject_component.yaml" diff --git a/tests/test_utils.py b/tests/test_utils.py index 194d5e15..6fbc869f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -125,13 +125,13 @@ def autonomy_fs(test_packages_filesystem): def test_get_paths_changed_only(test_packages_filesystem): """Test get_paths.""" - assert test_packages_filesystem == str(Path.cwd()) + assert Path.cwd() == test_packages_filesystem assert len(get_paths(changed_only=True)) == 0 def test_get_paths(test_packages_filesystem): """Test get_paths.""" - assert test_packages_filesystem == str(Path.cwd()) + assert Path.cwd() == test_packages_filesystem assert len(get_paths()) == 0