30 changes: 24 additions & 6 deletions aenv/src/cli/cmds/push.py
@@ -22,6 +22,7 @@

 from cli.client.aenv_hub_client import AEnvHubClient, AEnvHubError, EnvStatus
 from cli.extends.storage.storage_manager import StorageContext, load_storage
+from cli.utils.parallel import parallel_execute


 @click.command()
@@ -59,16 +60,33 @@ def push(work_dir, dry_run, force, version):
         err=False,
     )
     hub_client = AEnvHubClient.load_client()
-    exist = hub_client.check_env(name=env_name, version=version)
+
+    # Execute check_env and state_environment in parallel
+    tasks = [
+        ("check_env", lambda: hub_client.check_env(name=env_name, version=version)),
+        ("state_env", lambda: hub_client.state_environment(env_name, version)),
+    ]
+    results = parallel_execute(tasks)
+
+    # Process results
+    check_result = results.get("check_env")
+    state_result = results.get("state_env")
+
+    if check_result and not check_result.success:
+        if check_result.error:
+            raise check_result.error
+
+    exist = check_result.result if check_result and check_result.success else False
+
     if exist:
         click.echo(
             f"aenv:{env_name}:{version} already exists in remote aenv_hub", err=False
         )
-        state = hub_client.state_environment(env_name, version)
-        env_state = EnvStatus.parse_state(state)
-        if env_state.running() and not force:
-            click.echo("❌ Environment is being prepared, use --force to overwrite")
-            raise click.Abort()
+        if state_result and state_result.success:
+            env_state = EnvStatus.parse_state(state_result.result)
+            if env_state.running() and not force:
+                click.echo("❌ Environment is being prepared, use --force to overwrite")
+                raise click.Abort()
Comment on lines +85 to +89

Contributor review (severity: high)

If the state_env task fails, the check for a running environment is skipped, and the push could unintentionally overwrite a running environment without the --force flag because the failure is handled silently. For safety, the operation should abort when the state of an existing environment cannot be determined.

Suggested change:

        if state_result and state_result.success:
            env_state = EnvStatus.parse_state(state_result.result)
            if env_state.running() and not force:
                click.echo("❌ Environment is being prepared, use --force to overwrite")
                raise click.Abort()
        else:
            click.echo("❌ Could not determine the state of the existing environment. Aborting to prevent accidental overwrite.", err=True)
            if state_result and state_result.error:
                raise state_result.error
            raise click.Abort()


     storage = load_storage()
     infos = {"name": env_name, "version": version}
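Note: `cli/utils/parallel.py` itself is not part of this diff. From the way `push` consumes it (a list of `(name, callable)` pairs in, a dict of per-task results with `success`, `result`, and `error` fields out), a minimal sketch of the module might look like the following. The `TaskResult` name, the thread-pool sizing, and the `is_parallel_disabled` body are assumptions, not the actual implementation; the tests further down only pin down the `AENV_DISABLE_PARALLEL` environment variable:

```python
# Hypothetical sketch of cli/utils/parallel.py, inferred from call sites.
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


def is_parallel_disabled() -> bool:
    """Opt-out switch; the test suite sets AENV_DISABLE_PARALLEL=1 to use it."""
    return os.environ.get("AENV_DISABLE_PARALLEL") == "1"


@dataclass
class TaskResult:
    """Assumed result shape; the diff only accesses .success/.result/.error."""
    success: bool
    result: Any = None
    error: Optional[BaseException] = None


def parallel_execute(tasks: List[Tuple[str, Callable[[], Any]]]) -> Dict[str, TaskResult]:
    """Run named zero-argument callables, capturing errors instead of raising."""
    results: Dict[str, TaskResult] = {}

    def run(fn: Callable[[], Any]) -> TaskResult:
        try:
            return TaskResult(success=True, result=fn())
        except Exception as exc:
            return TaskResult(success=False, error=exc)

    if is_parallel_disabled():
        for name, fn in tasks:
            results[name] = run(fn)
        return results

    with ThreadPoolExecutor(max_workers=len(tasks)) as pool:
        futures = {name: pool.submit(run, fn) for name, fn in tasks}
        for name, future in futures.items():
            results[name] = future.result()
    return results
```

This error-capturing contract is what makes the reviewer's concern above possible: a failed `state_env` task surfaces as `success=False` rather than as an exception, so the caller must check it explicitly.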
77 changes: 68 additions & 9 deletions aenv/src/cli/extends/storage/storage_manager.py
@@ -18,6 +18,7 @@
 import tempfile
 import zipfile
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, Optional
@@ -27,6 +28,8 @@
 from cli.client.aenv_hub_client import AEnvHubClient
 from cli.utils.archive_tool import TempArchive
 from cli.utils.cli_config import get_config_manager
+from cli.utils.compression import pack_directory_parallel
+from cli.utils.parallel import is_parallel_disabled


 @dataclass
@@ -403,25 +406,81 @@ def upload(self, context: StorageContext) -> StorageStatus:

         work_dir = context.src_url
         infos = context.infos
-        if infos:
-            name = infos.get("name")
-            version = infos.get("version")
+        name = infos.get("name")
+        version = infos.get("version")
+
+        hub_client = AEnvHubClient.load_client()
+
+        if is_parallel_disabled():
+            return self._upload_sequential(
+                work_dir, name, version, prefix, hub_client
+            )

-        with TempArchive(str(work_dir)) as archive:
+        return self._upload_concurrent(
+            work_dir, name, version, prefix, hub_client
+        )
+
+    def _upload_sequential(
+        self,
+        work_dir: str,
+        name: str,
+        version: str,
+        prefix: str,
+        hub_client: AEnvHubClient,
+    ) -> StorageStatus:
+        """Sequential upload: compress first, then get URL and upload."""
+        with TempArchive(str(work_dir), use_parallel=True) as archive:
             print(f"🔄 Archive: {archive}")
-            infos = context.infos
-            name = infos.get("name")
-            version = infos.get("version")
+            oss_url = hub_client.apply_sign_url(name, version)
             with open(archive, "rb") as tar:
-                hub_client = AEnvHubClient.load_client()
-                oss_url = hub_client.apply_sign_url(name, version)
                 headers = {"x-oss-object-acl": "public-read-write"}
                 response = requests.put(oss_url, data=tar, headers=headers)
                 response.raise_for_status()

         dest = f"{prefix}/{name}-{version}.tar"
         return StorageStatus(state=True, dest_url=dest)

+    def _upload_concurrent(
+        self,
+        work_dir: str,
+        name: str,
+        version: str,
+        prefix: str,
+        hub_client: AEnvHubClient,
+    ) -> StorageStatus:
+        """Concurrent upload: compress and fetch URL in parallel."""
+        archive_path = None
+        oss_url = None
+
+        try:
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                compress_future = executor.submit(
+                    pack_directory_parallel,
+                    work_dir,
+                    None,
+                    ["__pycache__"],
+                    True,
+                )
+                url_future = executor.submit(
+                    hub_client.apply_sign_url, name, version
+                )
+
+                archive_path = compress_future.result()
+                print(f"🔄 Archive: {archive_path}")
+                oss_url = url_future.result()
+
+            with open(archive_path, "rb") as tar:
+                headers = {"x-oss-object-acl": "public-read-write"}
+                response = requests.put(oss_url, data=tar, headers=headers)
+                response.raise_for_status()
+
+            dest = f"{prefix}/{name}-{version}.tar"
+            return StorageStatus(state=True, dest_url=dest)
+
+        finally:
+            if archive_path and os.path.exists(archive_path):
+                os.unlink(archive_path)
Comment on lines +481 to +482

Contributor review (severity: medium)

The finally block correctly cleans up the temporary archive file. However, os.unlink can itself raise (e.g., PermissionError). While unlikely in this context, wrapping it in a try...except block would make the cleanup more robust and prevent a cleanup failure from masking an exception raised in the try block.

Suggested change:

            if archive_path and os.path.exists(archive_path):
                try:
                    os.unlink(archive_path)
                except OSError:
                    pass

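An equivalent variant of the suggested cleanup (shown here for illustration, not part of the PR) uses `contextlib.suppress`, which avoids the empty `except` body:

```python
# Alternative to the suggested try/except; semantically the same cleanup.
import contextlib
import os

if archive_path and os.path.exists(archive_path):
    with contextlib.suppress(OSError):
        os.unlink(archive_path)
```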


 def load_storage():
     store_config = get_config_manager().get_storage_config()
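A side note on `TempArchive`: `_upload_sequential` now passes a `use_parallel=True` flag, but `cli/utils/archive_tool.py` is not part of this diff. A plausible reading, sketched under the assumption that the flag simply routes archiving through the new `pack_directory_parallel` helper and that `TempArchive` deletes the archive on exit:

```python
# Hypothetical sketch of the TempArchive change (cli/utils/archive_tool.py is
# not shown in this PR); the constructor flag and cleanup behavior are assumed.
import os

from cli.utils.compression import pack_directory_parallel


class TempArchive:
    """Context manager yielding the path of a temporary archive of work_dir."""

    def __init__(self, work_dir: str, use_parallel: bool = False):
        self.work_dir = work_dir
        self.use_parallel = use_parallel
        self.path = None

    def __enter__(self) -> str:
        # use_parallel presumably opts into the pigz-backed packer added in
        # this PR; the original archiving path is not visible in the diff.
        self.path = pack_directory_parallel(
            self.work_dir,
            exclude_patterns=["__pycache__"],
            use_parallel=self.use_parallel,
        )
        return self.path

    def __exit__(self, exc_type, exc, tb):
        if self.path and os.path.exists(self.path):
            os.unlink(self.path)
        return False
```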
196 changes: 196 additions & 0 deletions aenv/src/cli/tests/test_compression.py
@@ -0,0 +1,196 @@
# Copyright 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for parallel compression utilities."""

import os
import tarfile
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest

from cli.utils.compression import (
    get_pigz_path,
    get_cpu_count,
    pack_directory_parallel,
    _pack_with_tarfile,
)


@pytest.fixture
def temp_source_dir():
"""Create a temporary directory with test files."""
with tempfile.TemporaryDirectory() as tmpdir:
source_dir = Path(tmpdir) / "test_source"
source_dir.mkdir()

(source_dir / "file1.txt").write_text("content1")
(source_dir / "file2.txt").write_text("content2")

subdir = source_dir / "subdir"
subdir.mkdir()
(subdir / "file3.txt").write_text("content3")

pycache = source_dir / "__pycache__"
pycache.mkdir()
(pycache / "cache.pyc").write_text("cached")

yield str(source_dir)


class TestGetPigzPath:
    def test_returns_path_when_available(self):
        with patch("shutil.which") as mock_which:
            mock_which.return_value = "/usr/bin/pigz"
            result = get_pigz_path()
            assert result == "/usr/bin/pigz"

    def test_returns_none_when_not_available(self):
        with patch("shutil.which") as mock_which:
            mock_which.return_value = None
            result = get_pigz_path()
            assert result is None


class TestGetCpuCount:
    def test_returns_positive_number(self):
        count = get_cpu_count()
        assert count >= 1

    def test_handles_exception(self):
        with patch("os.cpu_count", side_effect=Exception("error")):
            count = get_cpu_count()
            assert count == 1


class TestPackDirectoryParallel:
    def test_creates_archive(self, temp_source_dir):
        output_path = pack_directory_parallel(temp_source_dir, use_parallel=False)

        try:
            assert os.path.exists(output_path)
            assert output_path.endswith(".tar.gz")

            with tarfile.open(output_path, "r:gz") as tar:
                names = tar.getnames()
                assert any("file1.txt" in n for n in names)
                assert any("file2.txt" in n for n in names)
                assert any("file3.txt" in n for n in names)
        finally:
            if os.path.exists(output_path):
                os.unlink(output_path)

    def test_excludes_patterns(self, temp_source_dir):
        output_path = pack_directory_parallel(
            temp_source_dir,
            exclude_patterns=["__pycache__"],
            use_parallel=False,
        )

        try:
            with tarfile.open(output_path, "r:gz") as tar:
                names = tar.getnames()
                assert not any("__pycache__" in n for n in names)
                assert not any("cache.pyc" in n for n in names)
        finally:
            if os.path.exists(output_path):
                os.unlink(output_path)

    def test_custom_output_path(self, temp_source_dir):
        with tempfile.TemporaryDirectory() as tmpdir:
            custom_path = os.path.join(tmpdir, "custom_archive.tar.gz")
            output_path = pack_directory_parallel(
                temp_source_dir,
                output_path=custom_path,
                use_parallel=False,
            )

            assert output_path == custom_path
            assert os.path.exists(custom_path)

    def test_raises_for_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            pack_directory_parallel("/nonexistent/path")

    def test_falls_back_to_tarfile_when_parallel_disabled(self, temp_source_dir):
        with patch.dict(os.environ, {"AENV_DISABLE_PARALLEL": "1"}):
            output_path = pack_directory_parallel(temp_source_dir)

        try:
            assert os.path.exists(output_path)
        finally:
            if os.path.exists(output_path):
                os.unlink(output_path)

    def test_uses_pigz_when_available(self, temp_source_dir):
        with (
            patch("cli.utils.compression.get_pigz_path") as mock_pigz,
            patch("cli.utils.compression._pack_with_pigz") as mock_pack_pigz,
            patch.dict(os.environ, {}, clear=True),
        ):
            os.environ.pop("AENV_DISABLE_PARALLEL", None)
            mock_pigz.return_value = "/usr/bin/pigz"
            mock_pack_pigz.return_value = "/tmp/test.tar.gz"

            pack_directory_parallel(temp_source_dir)

            mock_pack_pigz.assert_called_once()

    def test_falls_back_when_pigz_fails(self, temp_source_dir):
        with (
            patch("cli.utils.compression.get_pigz_path") as mock_pigz,
            patch("cli.utils.compression._pack_with_pigz") as mock_pack_pigz,
            patch.dict(os.environ, {}, clear=True),
        ):
            os.environ.pop("AENV_DISABLE_PARALLEL", None)
            mock_pigz.return_value = "/usr/bin/pigz"
            mock_pack_pigz.side_effect = Exception("pigz failed")

            output_path = pack_directory_parallel(temp_source_dir)

            try:
                assert os.path.exists(output_path)
            finally:
                if os.path.exists(output_path):
                    os.unlink(output_path)


class TestPackWithTarfile:
    def test_creates_valid_archive(self, temp_source_dir):
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, "test.tar.gz")
            result = _pack_with_tarfile(temp_source_dir, output_path, None)

            assert result == output_path
            assert os.path.exists(output_path)

            with tarfile.open(output_path, "r:gz") as tar:
                assert len(tar.getnames()) > 0

    def test_respects_exclude_patterns(self, temp_source_dir):
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, "test.tar.gz")
            _pack_with_tarfile(
                temp_source_dir,
                output_path,
                ["__pycache__", "file1"],
            )

            with tarfile.open(output_path, "r:gz") as tar:
                names = tar.getnames()
                assert not any("__pycache__" in n for n in names)
                assert not any("file1" in n for n in names)
                assert any("file2" in n for n in names)
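For reference, `cli/utils/compression.py` itself is also absent from the diff. From the import list and the behaviors these tests pin down (pigz discovery via shutil.which, a CPU-count fallback of 1, .tar.gz output, exclude patterns, the AENV_DISABLE_PARALLEL kill switch, and a tarfile fallback when pigz fails), a minimal sketch could look like this; the function bodies are assumptions, and only the names and observable behaviors come from the tests:

```python
# Hypothetical sketch of cli/utils/compression.py, reverse-engineered from the
# tests above; the real implementation may differ in flags and error handling.
import os
import shlex
import shutil
import subprocess
import tarfile
import tempfile
from typing import List, Optional


def get_pigz_path() -> Optional[str]:
    """Return the path of the pigz binary, or None if it is not installed."""
    return shutil.which("pigz")


def get_cpu_count() -> int:
    """Best-effort CPU count; the tests require at least 1 even on failure."""
    try:
        return os.cpu_count() or 1
    except Exception:
        return 1


def _pack_with_tarfile(source_dir: str, output_path: str,
                       exclude_patterns: Optional[List[str]]) -> str:
    """Pure-Python fallback: single-threaded gzip via the tarfile module."""
    patterns = exclude_patterns or []

    def _filter(info: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
        return None if any(p in info.name for p in patterns) else info

    with tarfile.open(output_path, "w:gz") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir), filter=_filter)
    return output_path


def _pack_with_pigz(source_dir: str, output_path: str,
                    exclude_patterns: Optional[List[str]]) -> str:
    """Parallel gzip: tar streams into pigz with one worker per CPU."""
    excludes = " ".join(f"--exclude={shlex.quote(p)}" for p in (exclude_patterns or []))
    parent, base = os.path.split(os.path.abspath(source_dir))
    cmd = (
        f"tar {excludes} -C {shlex.quote(parent)} -cf - {shlex.quote(base)}"
        f" | {shlex.quote(get_pigz_path())} -p {get_cpu_count()}"
        f" > {shlex.quote(output_path)}"
    )
    subprocess.run(cmd, shell=True, check=True)
    return output_path


def pack_directory_parallel(source_dir: str, output_path: Optional[str] = None,
                            exclude_patterns: Optional[List[str]] = None,
                            use_parallel: bool = True) -> str:
    """Pack source_dir into a .tar.gz, preferring pigz when available and allowed."""
    if not os.path.isdir(source_dir):
        raise FileNotFoundError(source_dir)
    if output_path is None:
        fd, output_path = tempfile.mkstemp(suffix=".tar.gz")
        os.close(fd)
    disabled = os.environ.get("AENV_DISABLE_PARALLEL") == "1"
    if use_parallel and not disabled and get_pigz_path():
        try:
            return _pack_with_pigz(source_dir, output_path, exclude_patterns)
        except Exception:
            pass  # fall back to tarfile, as the tests require
    return _pack_with_tarfile(source_dir, output_path, exclude_patterns)
```

One design note: pigz parallelizes only the compression stage; the tar stream feeding it remains single-threaded, which is usually acceptable because gzip compression dominates the cost.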